Generate All the Things

In this post, we'll look at one technique from the property-based testing repertoire: full coverage / exhaustive testing. Specifically, we will learn how to conveniently enumerate any kind of combinatorial object without using recursion.

To start, let's assume we have some algorithmic problem to solve. For example, we want to sort an array of numbers:

fn sort(xs: &mut [u32]) {
    ...
}

To test that the sort function works, we can write a bunch of example-based test cases. This approach has two flaws:

A better approach is randomized testing: just generate a random array, sort it, and check that the result is sorted:

use rand::Rng; // brings `gen` and `gen_range` into scope

#[test]
fn naive_randomized_testing() {
  let mut rng = rand::thread_rng();
  for _ in 0..100_000 {
    let n: usize = rng.gen_range(0..1_000);
    let mut xs: Vec<u32> =
      std::iter::repeat_with(|| rng.gen()).take(n).collect();

    sort(&mut xs);

    for i in 1..xs.len() {
      assert!(xs[i - 1] <= xs[i]);
    }
  }
}

Here, we generated one hundred thousand completely random test cases!

Sadly, the result might actually be worse than a small set of hand-picked examples. The problem here is that, if you pick an array completely at random (sample uniformly), it will be a rather ordinary array. In particular, given that the elements are arbitrary u32 numbers, it's highly unlikely that we generate an array with at least some equal elements. And when I write quick sort, I always have that nasty bug where it just loops infinitely when all elements are equal.
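To put a number on "highly unlikely", here is a back-of-the-envelope sketch (the function name is made up for this illustration). It computes the birthday-problem probability that an array of uniformly random values contains at least one repeated element:

```rust
/// Probability that `n` values drawn uniformly from `space` possible values
/// contain at least one repeated value (the birthday problem).
fn duplicate_probability(n: u64, space: f64) -> f64 {
  // P(all distinct) is the product over i in 0..n of (1 - i / space).
  let all_distinct: f64 = (0..n).map(|i| 1.0 - i as f64 / space).product();
  1.0 - all_distinct
}
```

For the longest array the test above can produce (999 elements) over the full u32 range, `duplicate_probability(999, 4294967296.0)` comes out at roughly 0.01%, and shorter arrays are even less likely to contain a repeated element.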

There are several fixes for the problem. The simplest one is to just make the sampling space smaller:

std::iter::repeat_with(|| rng.gen_range(0..10)).take(n).collect();

If we generate not an arbitrary u32 but a number from 0 to 9, we'll get some short arrays where all elements are equal. Another trick is to use a property-based testing library, which comes with predefined strategies for generating interesting sequences.

Yet another approach is to combine property-based testing with coverage-guided fuzzing. When checking a particular example, we collect coverage information for this specific input. Given a set of inputs with coverage info, we can apply targeted genetic algorithms to try to cover more of the code.

A particularly fruitful insight here is that we don't have to invent a novel structure-aware fuzzer for this. We can take an existing fuzzer which emits a sequence of bytes, and use those bytes as a sequence of random numbers to generate structured input. Essentially, we say that the fuzzer is a random number generator. That way, when the fuzzer flips bits in the raw byte array, it applies local, semantically valid transformations to the random data structure.
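A minimal sketch of the "fuzzer is a random number generator" idea might look like this. `ByteSource` and `array_from_bytes` are hypothetical names invented for this illustration, not part of any fuzzing library, and the one-byte-per-decision encoding is just the simplest thing that works:

```rust
/// Hypothetical adapter: treats fuzzer-provided bytes as a stream of "random" numbers.
struct ByteSource<'a> {
  bytes: &'a [u8],
  pos: usize,
}

impl<'a> ByteSource<'a> {
  fn new(bytes: &'a [u8]) -> Self {
    ByteSource { bytes, pos: 0 }
  }

  /// Returns a number in `0..=bound`, consuming one input byte.
  /// Once the input is exhausted, every further draw is zero.
  fn gen(&mut self, bound: u32) -> u32 {
    let byte = self.bytes.get(self.pos).copied().unwrap_or(0);
    self.pos += 1;
    // Do the arithmetic in u64 so that `bound + 1` cannot overflow.
    (u64::from(byte) % (u64::from(bound) + 1)) as u32
  }
}

/// Builds an array of length at most `n` with elements at most `m`
/// out of raw fuzzer bytes.
fn array_from_bytes(bytes: &[u8], n: u32, m: u32) -> Vec<u32> {
  let mut src = ByteSource::new(bytes);
  let len = src.gen(n) as usize;
  (0..len).map(|_| src.gen(m)).collect()
}
```

With `n = 5` and `m = 4`, the bytes `[3, 7, 1, 9]` decode to the array `[2, 1, 4]`. Changing the third byte from 1 to 0 changes only the middle element, which is the kind of local transformation described above.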

But this post isn't about those techniques :) Instead, it is about the idea of full coverage. Most bugs involve small, tricky examples. If a sorting routine breaks on some array with ten thousand elements, it's highly likely that there's a much smaller array (a handful of elements) which exposes the same bug. So what we can do is just generate every array of length at most n with numbers up to m and exhaustively check them all:

#[test]
fn exhaustive_testing() {
  let n = 5;
  let m = 5;
  for mut xs in every_array(n, m) {
    sort(&mut xs);

    for i in 1..xs.len() {
      assert!(xs[i - 1] <= xs[i]);
    }
  }
}

The problem here is that implementing every_array is tricky. It is one of those puzzlers you know how to solve, but which are excruciatingly annoying to implement for the umpteenth time:

fn every_array(n: usize, m: u32) -> Vec<Vec<u32>> {
  if n == 0 {
    return vec![Vec::new()];
  }

  let mut res = Vec::new();
  for xs in every_array(n - 1, m) {
    for x in 0..=m {
      let mut ys = xs.clone();
      ys.push(x);
      res.push(ys)
    }
  }

  res
}
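One thing to note: as written, every_array(n, m) returns the arrays of length exactly n, while the test is meant to cover every length up to n. One possible way to close that gap is a thin wrapper over the recursive function. `every_array_up_to` is a hypothetical name for this sketch, and every_array is repeated so that the snippet compiles on its own:

```rust
// The recursive enumerator from above: arrays of length exactly `n`
// with elements in 0..=m.
fn every_array(n: usize, m: u32) -> Vec<Vec<u32>> {
  if n == 0 {
    return vec![Vec::new()];
  }

  let mut res = Vec::new();
  for xs in every_array(n - 1, m) {
    for x in 0..=m {
      let mut ys = xs.clone();
      ys.push(x);
      res.push(ys)
    }
  }

  res
}

// Hypothetical wrapper: arrays of every length from 0 up to and including `n`.
fn every_array_up_to(n: usize, m: u32) -> Vec<Vec<u32>> {
  (0..=n).flat_map(|len| every_array(len, m)).collect()
}
```

For example, `every_array_up_to(2, 1)` yields 1 + 2 + 4 = 7 arrays, starting with the empty one.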

What's more, for algorithms you often need to generate permutations, combinations and subsets, and they all have similar simple but tricky recursive solutions.

Yesterday I needed to generate a sequence of up to n segments with integer coordinates up to m, which finally pushed me to realize that there's a relatively simple way to exhaustively enumerate arbitrary combinatorial objects. I don't recall seeing it anywhere else, which is surprising, as the technique seems rather elegant.


Let's look again at how we generate a random array:

// `l` is the random length (below `n`); each element is a random number below `m`.
let l: usize = rng.gen_range(0..n);
let mut xs: Vec<u32> =
  std::iter::repeat_with(|| rng.gen_range(0..m)).take(l).collect();

This is definitely much more straightforward than the every_array function above, although it does sort of the same thing. The trick is to take this "generate a random thing" code and just make it "generate every thing" instead. In the above code, we base decisions on random numbers. Specifically, an input sequence of random numbers generates one element in the search space. If we enumerate all sequences of random numbers, we explore the whole space.

Essentially, we'll rig the rng to not be random, but instead to emit all finite sequences of numbers. By writing a single generator of such sequences, we gain the ability to enumerate arbitrary objects. As we are interested in generating all small objects, we always pass an upper bound when asking for a random number. We can use the bounds to enumerate only the sequences which fit under them.

So, the end result will look like this:

#[test]
fn for_every_array() {
  let n = 5;
  let m = 4;

  let mut g = Gen::new();
  while !g.done() {
    let l = g.gen(n) as usize;
    let xs: Vec<_> =
      std::iter::repeat_with(|| g.gen(m)).take(l).collect::<_>();
    // `xs` enumerates all arrays
  }
}

The implementation of Gen is relatively straightforward. On each iteration, we will remember the sequence of numbers we generated together with bounds the user requested, something like this:

value:  3 1 4 4
bound:  5 4 4 4

To advance to the next iteration, we find the smallest sequence of values which is larger than the current one but still satisfies all the bounds. "Smallest" means that we'll try to increment the rightmost number. In the above example, the last two fours already match the bound, so we can't increment them. However, we can increment the one to get 3 2 4 4. This isn't the smallest sequence though: 3 2 0 0 would be smaller. So, after incrementing the rightmost number we can increment, we zero the rest.
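The increment step on its own can be sketched as a function over two fixed-length slices. This is a simplified illustration, not the Gen implementation itself: the name `advance` and the separate `values` / `bounds` slices are assumptions made for this example, and bounds are treated as inclusive.

```rust
/// Advances `values` to the next sequence that fits under the inclusive `bounds`.
/// Returns false if `values` was already the last such sequence.
fn advance(values: &mut [u32], bounds: &[u32]) -> bool {
  assert_eq!(values.len(), bounds.len());
  // Scan from the right for the first number that can still grow.
  for i in (0..values.len()).rev() {
    if values[i] < bounds[i] {
      values[i] += 1;
      // Everything to the right restarts from zero.
      for v in &mut values[i + 1..] {
        *v = 0;
      }
      return true;
    }
  }
  false
}
```

Starting from the values 3 1 4 4 under the bounds 5 4 4 4, one call produces 3 2 0 0. Once the values reach 5 4 4 4, the function returns false, which signals the end of the enumeration.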

Here's the full implementation:

struct Gen {
  started: bool,
  v: Vec<(u32, u32)>,
  p: usize,
}

impl Gen {
  fn new() -> Gen {
    Gen { started: false, v: Vec::new(), p: 0 }
  }
  fn done(&mut self) -> bool {
    if !self.started {
      self.started = true;
      return false;
    }

    for i in (0..self.v.len()).rev() {
      if self.v[i].0 < self.v[i].1 {
        self.v[i].0 += 1;
        self.v.truncate(i + 1);
        self.p = 0;
        return false;
      }
    }

    true
  }
  fn gen(&mut self, bound: u32) -> u32 {
    if self.p == self.v.len() {
      self.v.push((0, 0));
    }
    self.p += 1;
    self.v[self.p - 1].1 = bound;
    self.v[self.p - 1].0
  }
}
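To see the order in which Gen walks the search space, here is a small trace. Gen is copied verbatim from above so that the snippet compiles on its own; `small_arrays` is a hypothetical helper written for this illustration.

```rust
// Copy of `Gen` from above.
struct Gen {
  started: bool,
  v: Vec<(u32, u32)>, // (value, inclusive bound) for each decision
  p: usize,           // index of the next decision in the current iteration
}

impl Gen {
  fn new() -> Gen {
    Gen { started: false, v: Vec::new(), p: 0 }
  }
  fn done(&mut self) -> bool {
    if !self.started {
      self.started = true;
      return false;
    }
    for i in (0..self.v.len()).rev() {
      if self.v[i].0 < self.v[i].1 {
        self.v[i].0 += 1;
        self.v.truncate(i + 1);
        self.p = 0;
        return false;
      }
    }
    true
  }
  fn gen(&mut self, bound: u32) -> u32 {
    if self.p == self.v.len() {
      self.v.push((0, 0));
    }
    self.p += 1;
    self.v[self.p - 1].1 = bound;
    self.v[self.p - 1].0
  }
}

// Hypothetical helper: every array of length at most 2 with elements in 0..=1,
// in the order in which `Gen` produces them.
fn small_arrays() -> Vec<Vec<u32>> {
  let mut out = Vec::new();
  let mut g = Gen::new();
  while !g.done() {
    let l = g.gen(2);
    let xs: Vec<u32> = (0..l).map(|_| g.gen(1)).collect();
    out.push(xs);
  }
  out
}
```

Calling `small_arrays()` yields seven arrays in this order: `[]`, `[0]`, `[1]`, `[0, 0]`, `[0, 1]`, `[1, 0]`, `[1, 1]`. Note how incrementing the length decision discards the element decisions that followed it, and how the new elements start from zero again.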

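Some notes:

The started flag tracks the first iteration, so that the while !g.done() loop runs its body at least once. v stores (value, bound) pairs, and p tracks the current position within an iteration. When p gets past the end of v, gen pushes a fresh zero, so decisions that were never requested are implicitly zero. For the same reason, done can truncate the vector instead of zeroing out the elements after the incremented one. Bounds are inclusive: gen(bound) returns a number in 0..=bound.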

Let's see how our gen fares at generating all arrays of length at most n. We'll count how many distinct cases were covered:

use std::collections::HashSet;

#[test]
fn gen_arrays() {
  let n = 5;
  let m = 4;
  let expected_total =
    (0..=n).map(|l| (m + 1).pow(l)).sum::<u32>();

  let mut total = 0;
  let mut all = HashSet::new();

  let mut g = Gen::new();
  while !g.done() {
    let l = g.gen(n) as usize;
    let xs: Vec<_> =
      std::iter::repeat_with(|| g.gen(m)).take(l).collect::<_>();

    all.insert(xs);
    total += 1
  }

  assert_eq!(all.len(), total);
  assert_eq!(expected_total, total as u32)
}

This test passes. That is, the gen approach for this case is both exhaustive (it generates all arrays) and efficient (each array is generated once).

As promised in the post's title, let's now generate all the things.

First case: there should be only one nothing (that's the reason why we need the started field):

#[test]
fn gen_nothing() {
  let expected_total = 1;

  let mut total = 0;
  let mut g = Gen::new();
  while !g.done() {
    total += 1;
  }
  assert_eq!(expected_total, total)
}

Second case: we expect to see n + 1 numbers (the bound is inclusive, so the values are 0..=n) and (n + 1)^2 ordered pairs of numbers.

#[test]
fn gen_number() {
  let n = 5;
  let expected_total = n + 1;

  let mut total = 0;
  let mut all = HashSet::new();
  let mut g = Gen::new();
  while !g.done() {
    let a = g.gen(n);

    all.insert(a);
    total += 1;
  }

  assert_eq!(expected_total, total);
  assert_eq!(expected_total, all.len() as u32);
}

#[test]
fn gen_number_pair() {
  let n = 5;
  let expected_total = (n + 1) * (n + 1);

  let mut total = 0;
  let mut all = HashSet::new();
  let mut g = Gen::new();
  while !g.done() {
    let a = g.gen(n);
    let b = g.gen(n);

    all.insert((a, b));
    total += 1;
  }

  assert_eq!(expected_total, total);
  assert_eq!(expected_total, all.len() as u32);
}

Third case: we expect to see n * (n + 1) / 2 unordered pairs of distinct numbers from 0..=n. This one is interesting: here, our second decision depends on the first one, but we still enumerate all the cases efficiently (without duplicates). (Aside: did you ever realise that the number of ways to pick two objects out of n + 1 is equal to the sum of the first n natural numbers?)

#[test]
fn gen_number_combination() {
  let n = 5;
  let expected_total = n * (n + 1) / 2;

  let mut total = 0;
  let mut all = HashSet::new();
  let mut g = Gen::new();
  while !g.done() {
    let a = g.gen(n - 1);
    let b = a + 1 + g.gen(n - a - 1);
    all.insert((a, b));
    total += 1;
  }

  assert_eq!(expected_total, total);
  assert_eq!(expected_total, all.len() as u32);
}

We've already generated all arrays, so let's try to create all permutations. Still efficient:

#[test]
fn gen_permutations() {
  let n = 5;
  let expected_total = (1..=n).product::<u32>();

  let mut total = 0;
  let mut all = HashSet::new();
  let mut g = Gen::new();
  while !g.done() {
    let mut candidates: Vec<u32> = (1..=n).collect();
    let mut permutation = Vec::new();
    for _ in 0..n {
      let idx = g.gen(candidates.len() as u32 - 1);
      permutation.push(candidates.remove(idx as usize));
    }

    all.insert(permutation);
    total += 1;
  }

  assert_eq!(expected_total, total);
  assert_eq!(expected_total, all.len() as u32);
}

Subsets:

#[test]
fn gen_subset() {
    let n = 5;
    let expected_total = 1 << n;

    let mut total = 0;
    let mut all = HashSet::new();
    let mut g = Gen::new();
    while !g.done() {
        let s: Vec<_> = (0..n).map(|_| g.gen(1) == 1).collect();

        all.insert(s);
        total += 1;
    }

    assert_eq!(expected_total, total);
    assert_eq!(expected_total, all.len() as u32);
}

Combinations:

#[test]
fn gen_combinations() {
    let n = 5;
    let m = 3;
    let fact = |n: u32| -> u32 { (1..=n).product() };
    let expected_total = fact(n) / (fact(m) * fact(n - m));

    let mut total = 0;
    let mut all = HashSet::new();
    let mut g = Gen::new();
    while !g.done() {
        let mut candidates: Vec<u32> = (1..=n).collect();
        let mut combination = BTreeSet::new();
        for _ in 0..m {
            let idx = g.gen(candidates.len() as u32 - 1);
            combination.insert(candidates.remove(idx as usize));
        }

        all.insert(combination);
        total += 1;
    }

    assert_eq!(expected_total, total);
    assert_eq!(expected_total, all.len() as u32);
}

Now, this one actually fails: while this code generates all combinations, some combinations are generated more than once. Specifically, what we are generating here are k-permutations (combinations where the order of elements is significant). While this is not efficient, it is OK for the purposes of exhaustive testing (as we still generate every combination). Nonetheless, there's an efficient version as well:

let mut combination = BTreeSet::new();
for c in 1..=n {
  if combination.len() as u32 == m {
    break;
  }
  if combination.len() as u32 + (n - c + 1) == m {
    combination.extend(c..=n);
    break;
  }
  if g.gen(1) == 1 {
    combination.insert(c);
  }
}
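For completeness, here is one way the loop above could be wired into a full enumerator and checked. The function name `combinations` and its return type are choices made for this sketch, and Gen is repeated from earlier so that the snippet compiles on its own.

```rust
use std::collections::BTreeSet;

// Copy of `Gen` from earlier in the post.
struct Gen {
  started: bool,
  v: Vec<(u32, u32)>,
  p: usize,
}

impl Gen {
  fn new() -> Gen {
    Gen { started: false, v: Vec::new(), p: 0 }
  }
  fn done(&mut self) -> bool {
    if !self.started {
      self.started = true;
      return false;
    }
    for i in (0..self.v.len()).rev() {
      if self.v[i].0 < self.v[i].1 {
        self.v[i].0 += 1;
        self.v.truncate(i + 1);
        self.p = 0;
        return false;
      }
    }
    true
  }
  fn gen(&mut self, bound: u32) -> u32 {
    if self.p == self.v.len() {
      self.v.push((0, 0));
    }
    self.p += 1;
    self.v[self.p - 1].1 = bound;
    self.v[self.p - 1].0
  }
}

// Hypothetical wrapper: all m-element subsets of 1..=n, each produced once.
fn combinations(n: u32, m: u32) -> Vec<BTreeSet<u32>> {
  assert!(m <= n);
  let mut out = Vec::new();
  let mut g = Gen::new();
  while !g.done() {
    let mut combination = BTreeSet::new();
    for c in 1..=n {
      // Already full: no more decisions to make.
      if combination.len() as u32 == m {
        break;
      }
      // Exactly enough candidates left: all of them are forced in.
      if combination.len() as u32 + (n - c + 1) == m {
        combination.extend(c..=n);
        break;
      }
      // Otherwise, make a free include/exclude decision for `c`.
      if g.gen(1) == 1 {
        combination.insert(c);
      }
    }
    out.push(combination);
  }
  out
}
```

For `n = 5` and `m = 3` this produces 10 sets with no repeats, which matches the expected_total from the failing test. The two early exits are what keep it efficient: a decision is requested from g only when both outcomes can still lead to a valid combination.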

I think this covers all standard combinatorial structures. What's interesting is that this approach works for non-standard structures as well. For example, for https://cses.fi/problemset/task/2168, the problem which started all this, I need to generate sequences of segments:

#[test]
fn gen_segments() {
  let n = 5;
  let m = 6;

  let mut total = 0;
  let mut all = HashSet::new();
  let mut g = Gen::new();
  while !g.done() {
    let l = g.gen(n);

    let mut xs = Vec::new();
    for _ in 0..l {
      if m > 0 {
        let l = g.gen(m - 1);
        let r = l + 1 + g.gen(m - l - 1);
        if !xs.contains(&(l, r)) {
          xs.push((l, r))
        }
      }
    }

    all.insert(xs);
    total += 1;
  }
  assert_eq!(all.len(), 2_593_942);
  assert_eq!(total, 4_288_306);
}

Due to the .contains check there are some duplicates, but that's not a problem as long as all sequences of segments are generated. Additionally, examples are strictly ordered by their complexity: earlier examples have fewer segments with smaller coordinates. That means that the first example which fails a property test is actually guaranteed to be the smallest counterexample! Nifty!
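Here is a small sketch of the "first failure is the smallest counterexample" property, using plain arrays rather than segments to keep it short. `buggy_sort` is a deliberately broken function invented for this illustration, `first_counterexample` is a hypothetical helper, and Gen is repeated from earlier so that the snippet compiles on its own.

```rust
use std::collections::BTreeSet;

// Copy of `Gen` from earlier in the post.
struct Gen {
  started: bool,
  v: Vec<(u32, u32)>,
  p: usize,
}

impl Gen {
  fn new() -> Gen {
    Gen { started: false, v: Vec::new(), p: 0 }
  }
  fn done(&mut self) -> bool {
    if !self.started {
      self.started = true;
      return false;
    }
    for i in (0..self.v.len()).rev() {
      if self.v[i].0 < self.v[i].1 {
        self.v[i].0 += 1;
        self.v.truncate(i + 1);
        self.p = 0;
        return false;
      }
    }
    true
  }
  fn gen(&mut self, bound: u32) -> u32 {
    if self.p == self.v.len() {
      self.v.push((0, 0));
    }
    self.p += 1;
    self.v[self.p - 1].1 = bound;
    self.v[self.p - 1].0
  }
}

// Deliberately buggy "sort", made up for this example: it loses duplicates.
fn buggy_sort(xs: &mut Vec<u32>) {
  let unique: BTreeSet<u32> = xs.iter().copied().collect();
  *xs = unique.into_iter().collect();
}

// Hypothetical helper: the first generated array (length at most `n`,
// elements at most `m`) on which `buggy_sort` disagrees with the standard sort.
fn first_counterexample(n: u32, m: u32) -> Option<Vec<u32>> {
  let mut g = Gen::new();
  while !g.done() {
    let l = g.gen(n);
    let xs: Vec<u32> = (0..l).map(|_| g.gen(m)).collect();

    let mut expected = xs.clone();
    expected.sort();
    let mut actual = xs.clone();
    buggy_sort(&mut actual);

    if actual != expected {
      return Some(xs);
    }
  }
  None
}
```

With `n = 3` and `m = 2`, the first failure reported is `[0, 0]`: the shortest array that contains a duplicate, with the smallest possible elements. No separate shrinking step is needed.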

That's all! Next time you need to test something, consider whether you can just exhaustively enumerate all sufficiently small inputs. If that's feasible, you can either write the classical recursive enumerator, or use this imperative Gen thing.

Update(2021-11-28):

There are now Rust (crates.io link) and C++ (GitHub link) implementations. "Capturing the Future by Replaying the Past" is a related paper which includes the above technique as a special case of the "simulate any monad by simulating delimited continuations via exceptions and replay" trick.

Balanced parentheses sequences:

#[test]
fn gen_parenthesis() {
  let n = 5;
  let expected_total = 1 + 1 + 2 + 5 + 14 + 42;

  let mut total = 0;
  let mut all = HashSet::new();
  let mut g = Gen::new();
  while !g.done() {
    let l = g.gen(n);
    let mut s = String::new();
    let mut bra = 0;
    let mut ket = 0;
    while ket < l {
      if bra < l && (bra == ket || g.gen(1) == 1) {
        s.push('(');
        bra += 1;
      } else {
        s.push(')');
        ket += 1;
      }
    }

    all.insert(s);
    total += 1;
  }

  assert_eq!(expected_total, total);
  assert_eq!(expected_total, all.len() as u32);
}