Generate All the Things
In this post, we’ll look at one technique from the property-based testing repertoire: full coverage / exhaustive testing. Specifically, we will learn how to conveniently enumerate any kind of combinatorial object without using recursion.
To start, let’s assume we have some algorithmic problem to solve. For example, we want to sort an array of numbers:
fn sort(xs: &mut [u32]) {
    ...
}
To test that the sort function works, we can write a bunch of example-based test cases. This approach has two flaws:
- Generating examples by hand is time-consuming.
- It might be hard to come up with interesting examples: any edge case we’ve thought about is probably already handled in the code. We want to find cases which we didn’t think of before.
A better approach is randomized testing: just generate a random array, sort it, and check that the result is sorted:
use rand::Rng; // brings `gen` and `gen_range` into scope

fn naive_randomized_testing() {
    let mut rng = rand::thread_rng();
    for _ in 0..100_000 {
        let n: usize = rng.gen_range(0..1_000);
        let mut xs: Vec<u32> =
            std::iter::repeat_with(|| rng.gen()).take(n).collect();
        sort(&mut xs);
        for i in 1..xs.len() {
            assert!(xs[i - 1] <= xs[i]);
        }
    }
}
Here, we generated one hundred thousand completely random test cases!
Sadly, the result might actually be worse than a small set of
hand-picked examples. The problem here is that, if you pick an array
completely at random (sample uniformly), it will be a rather ordinary
array. In particular, given that the elements are arbitrary u32
numbers, it’s highly unlikely that we generate an array
with at least some equal elements. And when I write quick sort, I
always have that nasty bug that it just loops infinitely when all elements are equal.
There are several fixes for the problem. The simplest one is to just make the sampling space smaller:
std::iter::repeat_with(|| rng.gen_range(0..10)).take(n).collect();
If we generate not an arbitrary u32, but a number in 0..10, we’ll get some short arrays where all elements are equal.
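To put rough numbers on this, here is a back-of-the-envelope sketch. The helper name prob_all_equal is made up for illustration, and it assumes elements are drawn uniformly and independently:

```rust
/// Probability that `len` values drawn uniformly and independently from
/// `space` distinct values are all equal to each other.
/// (Hypothetical helper, not part of the test code above.)
fn prob_all_equal(len: u32, space: f64) -> f64 {
    if len <= 1 {
        // Empty and one-element arrays are trivially "all equal".
        return 1.0;
    }
    // The first element is free; each of the other `len - 1` must match it.
    (1.0 / space).powi(len as i32 - 1)
}
```

For a three-element array, prob_all_equal(3, 10.0) is about one in a hundred, while with the full u32 range (a space of 4294967296 values) it is far too small to ever show up in a hundred thousand samples.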
Another trick is to use a property-based testing library, which comes
with some strategies for generating interesting sequences predefined.
Yet another approach is to combine property-based testing and coverage-guided fuzzing. When checking a particular example, we will collect
coverage information for this specific input. Given a set of inputs
with coverage info, we can apply targeted genetic algorithms to try to
cover more of the code. A particularly fruitful insight here is that
we don’t have to invent a novel structure-aware fuzzer for this. We
can take an existing fuzzer which emits a sequence of bytes, and use
those bytes as a sequence of random numbers to generate structured
input. Essentially, we say that the fuzzer is a random number
generator. That way, when the fuzzer flips bits in the raw bytes
array, it applies local semantically valid transformations to the
random data structure.
But this post isn’t about those techniques :) Instead, it is about the
idea of full coverage.
Most bugs involve small, tricky examples. If a sorting routine breaks on some array with ten thousand elements, it’s highly likely that there’s a much smaller array (a handful of elements) which exposes the same bug. So what we can do is just generate every array of length at most n with numbers up to m and exhaustively check them all:
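fn exhaustive_testing() {
    let n = 5;
    let m = 5;
    // `every_array(len, m)` yields arrays of exactly `len` elements,
    // so we go through every length up to `n`.
    for len in 0..=n {
        // `mut` is required because `sort` works in place.
        for mut xs in every_array(len, m) {
            sort(&mut xs);
            for i in 1..xs.len() {
                assert!(xs[i - 1] <= xs[i]);
            }
        }
    }
}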
The problem here is that implementing every_array is tricky. It is one of those puzzlers you know how to solve, but which are excruciatingly annoying to implement for the umpteenth time:
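// Returns every array of exactly `n` elements, each in `0..=m`.
fn every_array(n: usize, m: u32) -> Vec<Vec<u32>> {
    if n == 0 {
        return vec![Vec::new()];
    }

    let mut res = Vec::new();
    // Extend every shorter array with every possible last element.
    for xs in every_array(n - 1, m) {
        for x in 0..=m {
            let mut ys = xs.clone();
            ys.push(x);
            res.push(ys);
        }
    }

    res
}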
What’s more, for algorithms you often need to generate permutations, combinations and subsets, and they all have similar simple but tricky recursive solutions.
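For comparison, here is what the recursive version might look like for permutations. This is just one of several ways to write it, and every_permutation is a name picked for this sketch:

```rust
/// Returns every permutation of `items`, built recursively.
fn every_permutation(items: &[u32]) -> Vec<Vec<u32>> {
    if items.is_empty() {
        return vec![Vec::new()];
    }
    let mut res = Vec::new();
    for i in 0..items.len() {
        // Pick `items[i]` as the first element...
        let mut rest = items.to_vec();
        let first = rest.remove(i);
        // ...and prepend it to every permutation of the remaining items.
        for mut tail in every_permutation(&rest) {
            tail.insert(0, first);
            res.push(tail);
        }
    }
    res
}
```

It has the same shape as every_array (base case, loop, recursive call, glue the pieces together), but the details differ just enough that it has to be written from scratch every time.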
Yesterday I needed to generate a sequence of up to n segments with integer coordinates up to m, which finally pushed me to realize that there’s a relatively simple way to exhaustively enumerate arbitrary combinatorial objects. I don’t recall seeing it anywhere else, which is surprising, as the technique seems rather elegant.
Let’s look again at how we generate a random array:
let l: usize = rng.gen_range(0..n);
let mut xs: Vec<u32> =
    std::iter::repeat_with(|| rng.gen_range(0..m)).take(l).collect();
This is definitely much more straightforward than the every_array function above, although it does sort of the same thing. The trick is to take this “generate a random thing” code and just make it generate every thing instead. In the above code, we base decisions on random numbers. Specifically, an input sequence of random numbers generates one element in the search space. If we enumerate all sequences of random numbers, we then explore the whole space.
Essentially, we’ll rig the rng to not be random, but instead to emit all finite sequences of numbers. By writing a single generator of such sequences, we gain the ability to enumerate arbitrary objects. As we are interested in generating all “small” objects, we always pass an upper bound when asking for a “random” number. We can use the bounds to enumerate only the sequences which fit under them.
So, the end result will look like this:
fn for_every_array() {
    let n = 5;
    let m = 4;

    let mut g = Gen::new();
    while !g.done() {
        let l = g.gen(n) as usize;
        let xs: Vec<_> =
            std::iter::repeat_with(|| g.gen(m)).take(l).collect::<_>();
        // `xs` enumerates all arrays
    }
}
The implementation of Gen is relatively straightforward. On each iteration, we will remember the sequence of numbers we generated together with the bounds the user requested, something like this:
value: 3 1 4 4
bound: 5 4 4 4
To advance to the next iteration, we will find the smallest sequence of values which is larger than the current one, but still satisfies all the bounds. “Smallest” means that we’ll try to increment the rightmost number. In the above example, the last two fours already match the bound, so we can’t increment them. However, we can increment the one to get 3 2 4 4. This isn’t the smallest sequence though: 3 2 0 0 would be smaller. So, after incrementing the rightmost number we can increment, we zero the rest.
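The step described above can be sketched in isolation. This is a simplified illustration which assumes the bounds are fixed up front and uses a hypothetical free function advance; it spells out the zeroing explicitly:

```rust
/// Moves `values` to the next sequence that respects the inclusive
/// `bounds`, or returns `false` if `values` was already the last one.
fn advance(values: &mut [u32], bounds: &[u32]) -> bool {
    assert_eq!(values.len(), bounds.len());
    for i in (0..values.len()).rev() {
        if values[i] < bounds[i] {
            // Increment the rightmost number that still has room...
            values[i] += 1;
            // ...and zero everything to the right of it.
            for v in &mut values[i + 1..] {
                *v = 0;
            }
            return true;
        }
    }
    false
}
```

Calling it on the values 3 1 4 4 with the bounds 5 4 4 4 leaves 3 2 0 0 in place, matching the example.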
Here’s the full implementation:
struct Gen {
    started: bool,
    v: Vec<(u32, u32)>,
    p: usize,
}

impl Gen {
    fn new() -> Gen {
        Gen { started: false, v: Vec::new(), p: 0 }
    }

    fn done(&mut self) -> bool {
        if !self.started {
            self.started = true;
            return false;
        }

        // Increment the rightmost value that is still below its bound
        // and drop everything after it.
        for i in (0..self.v.len()).rev() {
            if self.v[i].0 < self.v[i].1 {
                self.v[i].0 += 1;
                self.v.truncate(i + 1);
                self.p = 0;
                return false;
            }
        }

        true
    }

    fn gen(&mut self, bound: u32) -> u32 {
        if self.p == self.v.len() {
            self.v.push((0, 0));
        }
        self.p += 1;
        self.v[self.p - 1].1 = bound;
        self.v[self.p - 1].0
    }
}
Some notes:
- We need the started field to track the first iteration, and to make the while !g.done() syntax work. It’s a bit more natural to remove started and use a do { } while !g.done() loop, but it’s not available in Rust.
- v stores (value, bound) pairs.
- p tracks the current position in the middle of the iteration.
- v is conceptually an infinite vector with a finite number of non-zero elements. So, when p gets past the end of v, we just materialize the implicit zero by pushing it onto v.
- As we store zeros implicitly anyway, we can just truncate the vector in done instead of zeroing out the elements after the incremented one.
- Somewhat unusually, the bounds are treated inclusively. This removes the panic when bound is zero, and allows generating a full set of numbers via gen(u32::MAX).
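For the first note, here is a sketch of what the do-while shape could look like when emulated with loop and break. DoWhileGen, advance, pick and all_pairs are names invented for this sketch; the number-producing method is called pick rather than gen only because gen is a reserved keyword in the Rust 2024 edition:

```rust
/// Variant of the generator without the first-iteration flag.
struct DoWhileGen {
    v: Vec<(u32, u32)>, // (value, inclusive bound) pairs
    p: usize,           // position within the current iteration
}

impl DoWhileGen {
    fn new() -> DoWhileGen {
        DoWhileGen { v: Vec::new(), p: 0 }
    }

    /// Moves to the next sequence; returns `false` when none is left.
    fn advance(&mut self) -> bool {
        for i in (0..self.v.len()).rev() {
            if self.v[i].0 < self.v[i].1 {
                self.v[i].0 += 1;
                self.v.truncate(i + 1);
                self.p = 0;
                return true;
            }
        }
        false
    }

    fn pick(&mut self, bound: u32) -> u32 {
        if self.p == self.v.len() {
            self.v.push((0, 0));
        }
        self.p += 1;
        self.v[self.p - 1].1 = bound;
        self.v[self.p - 1].0
    }
}

/// Collects every pair `(a, b)` with `a, b <= n`.
fn all_pairs(n: u32) -> Vec<(u32, u32)> {
    let mut res = Vec::new();
    let mut g = DoWhileGen::new();
    // The body runs at least once, like a `do { } while` loop.
    loop {
        let a = g.pick(n);
        let b = g.pick(n);
        res.push((a, b));
        if !g.advance() {
            break;
        }
    }
    res
}
```

The trade-off is that the exit check moves to the bottom of the loop body, which is easier to forget than a while condition at the top.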
Let’s see how our gen fares for generating all arrays of length at most n. We’ll count how many distinct cases were covered:
use std::collections::HashSet;

fn gen_arrays() {
    // Explicit types, so that `.pow` below is called on a known integer type.
    let n: u32 = 5;
    let m: u32 = 4;
    let expected_total =
        (0..=n).map(|l| (m + 1).pow(l)).sum::<u32>();
    let mut total = 0;
    let mut all = HashSet::new();

    let mut g = Gen::new();
    while !g.done() {
        let l = g.gen(n) as usize;
        let xs: Vec<_> =
            std::iter::repeat_with(|| g.gen(m)).take(l).collect::<_>();

        all.insert(xs);
        total += 1
    }

    assert_eq!(all.len(), total);
    assert_eq!(expected_total, total as u32)
}
This test passes. That is, the gen approach for this case is both exhaustive (it generates all arrays) and efficient (each array is generated once).
As promised in the post’s title, let’s now generate all the things. First case: there should be only one nothing (that’s the reason why we need started):
fn gen_nothing() {
    let expected_total = 1;
    let mut total = 0;

    let mut g = Gen::new();
    while !g.done() {
        total += 1;
    }

    assert_eq!(expected_total, total)
}
Second case: as the bounds are inclusive, we expect to see n + 1 numbers and (n + 1) * (n + 1) ordered pairs of numbers.
fn gen_number() {
    let n = 5;
    let expected_total = n + 1;
    let mut total = 0;
    let mut all = HashSet::new();

    let mut g = Gen::new();
    while !g.done() {
        let a = g.gen(n);

        all.insert(a);
        total += 1;
    }

    assert_eq!(expected_total, total);
    assert_eq!(expected_total, all.len() as u32);
}

fn gen_number_pair() {
    let n = 5;
    let expected_total = (n + 1) * (n + 1);
    let mut total = 0;
    let mut all = HashSet::new();

    let mut g = Gen::new();
    while !g.done() {
        let a = g.gen(n);
        let b = g.gen(n);

        all.insert((a, b));
        total += 1;
    }

    assert_eq!(expected_total, total);
    assert_eq!(expected_total, all.len() as u32);
}
Third case: we expect to see n * (n + 1) / 2 unordered pairs of distinct numbers from 0..=n. This one is interesting: here, our second decision is based on the first one, but we still enumerate all the cases efficiently (without duplicates). (Aside: did you ever realise that the number of ways to pick two objects out of n + 1 is equal to the sum of the first n natural numbers?)
fn gen_number_combination() {
    let n = 5;
    let expected_total = n * (n + 1) / 2;
    let mut total = 0;
    let mut all = HashSet::new();

    let mut g = Gen::new();
    while !g.done() {
        let a = g.gen(n - 1);
        let b = a + 1 + g.gen(n - a - 1);

        all.insert((a, b));
        total += 1;
    }

    assert_eq!(expected_total, total);
    assert_eq!(expected_total, all.len() as u32);
}
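The aside is easy to sanity-check with a couple of throwaway helpers (choose_two and sum_first are names made up for this sketch). The test above draws from the six numbers 0..=5, so it expects 15 pairs, which is also 1 + 2 + 3 + 4 + 5:

```rust
/// Number of ways to pick two distinct objects out of `k`.
fn choose_two(k: u32) -> u32 {
    k * k.saturating_sub(1) / 2
}

/// Sum of the first `n` natural numbers, computed the long way.
fn sum_first(n: u32) -> u32 {
    (1..=n).sum()
}
```

Checking choose_two(n + 1) == sum_first(n) for a range of small n is itself a tiny instance of exhaustive testing.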
We’ve already generated all arrays, so let’s try to create all permutations. Still efficient:
fn gen_permutations() {
    let n = 5;
    let expected_total = (1..=n).product::<u32>();
    let mut total = 0;
    let mut all = HashSet::new();

    let mut g = Gen::new();
    while !g.done() {
        // The element type has to match `n`, which is a `u32` here.
        let mut candidates: Vec<u32> = (1..=n).collect();
        let mut permutation = Vec::new();
        for _ in 0..n {
            let idx = g.gen(candidates.len() as u32 - 1);
            permutation.push(candidates.remove(idx as usize));
        }

        all.insert(permutation);
        total += 1;
    }

    assert_eq!(expected_total, total);
    assert_eq!(expected_total, all.len() as u32);
}
Subsets:
fn gen_subset() {
    let n = 5;
    let expected_total = 1 << n;
    let mut total = 0;
    let mut all = HashSet::new();

    let mut g = Gen::new();
    while !g.done() {
        let s: Vec<_> = (0..n).map(|_| g.gen(1) == 1).collect();

        all.insert(s);
        total += 1;
    }

    assert_eq!(expected_total, total);
    assert_eq!(expected_total, all.len() as u32);
}
Combinations:
use std::collections::BTreeSet;

fn gen_combinations() {
    let n = 5;
    let m = 3;
    let fact = |n: u32| -> u32 { (1..=n).product() };
    let expected_total = fact(n) / (fact(m) * fact(n - m));
    let mut total = 0;
    let mut all = HashSet::new();

    let mut g = Gen::new();
    while !g.done() {
        let mut candidates: Vec<u32> = (1..=n).collect();
        let mut combination = BTreeSet::new();
        for _ in 0..m {
            let idx = g.gen(candidates.len() as u32 - 1);
            combination.insert(candidates.remove(idx as usize));
        }

        all.insert(combination);
        total += 1;
    }

    assert_eq!(expected_total, total);
    assert_eq!(expected_total, all.len() as u32);
}
Now, this one actually fails: while this code generates all combinations, some combinations are generated more than once (60 sequences for 10 distinct combinations). Specifically, what we are generating here are k-permutations (combinations with significant order of elements). While this is not efficient, this is OK for the purposes of exhaustive testing (as we still generate every combination). Nonetheless, there’s an efficient version as well:
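let mut combination = BTreeSet::new();
for c in 1..=n {
    // Already picked `m` elements: nothing left to decide.
    if combination.len() as u32 == m {
        break;
    }
    // Exactly as many candidates remain as we still need: take them all.
    if combination.len() as u32 + (n - c + 1) == m {
        combination.extend(c..=n);
        break;
    }
    // Otherwise, decide whether `c` is in the combination.
    if g.gen(1) == 1 {
        combination.insert(c);
    }
}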
I think this covers all standard combinatorial structures. What’s interesting is that this approach works for non-standard structures as well. For example, for https://cses.fi/problemset/task/2168, the problem which started all this, I need to generate sequences of segments:
fn gen_segments() {
    let n = 5;
    let m = 6;
    let mut total = 0;
    let mut all = HashSet::new();

    let mut g = Gen::new();
    while !g.done() {
        let l = g.gen(n);

        let mut xs = Vec::new();
        for _ in 0..l {
            if m > 0 {
                let l = g.gen(m - 1);
                let r = l + 1 + g.gen(m - l - 1);
                if !xs.contains(&(l, r)) {
                    xs.push((l, r))
                }
            }
        }

        all.insert(xs);
        total += 1;
    }

    assert_eq!(all.len(), 2_593_942);
    assert_eq!(total, 4_288_306);
}
Due to the .contains check there are some duplicates, but that’s not a problem as long as all sequences of segments are generated. Additionally, examples are strictly ordered by their complexity: earlier examples have fewer segments with smaller coordinates. That means that the first example which fails a property test is actually guaranteed to be the smallest counterexample! Nifty!
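To see the ordering at work, here is a small self-contained sketch with a deliberately broken sort. buggy_sort and first_counterexample are made up for this illustration; Gen is repeated so that the block compiles on its own, with gen renamed to pick because gen is a reserved keyword in the Rust 2024 edition:

```rust
struct Gen {
    started: bool,
    v: Vec<(u32, u32)>,
    p: usize,
}

impl Gen {
    fn new() -> Gen {
        Gen { started: false, v: Vec::new(), p: 0 }
    }

    fn done(&mut self) -> bool {
        if !self.started {
            self.started = true;
            return false;
        }
        for i in (0..self.v.len()).rev() {
            if self.v[i].0 < self.v[i].1 {
                self.v[i].0 += 1;
                self.v.truncate(i + 1);
                self.p = 0;
                return false;
            }
        }
        true
    }

    fn pick(&mut self, bound: u32) -> u32 {
        if self.p == self.v.len() {
            self.v.push((0, 0));
        }
        self.p += 1;
        self.v[self.p - 1].1 = bound;
        self.v[self.p - 1].0
    }
}

/// Deliberately buggy "sort": a single bubble-sort pass.
fn buggy_sort(xs: &mut [u32]) {
    for i in 1..xs.len() {
        if xs[i - 1] > xs[i] {
            xs.swap(i - 1, i);
        }
    }
}

/// Returns the first generated array that `buggy_sort` fails to sort.
fn first_counterexample(n: u32, m: u32) -> Option<Vec<u32>> {
    let mut g = Gen::new();
    while !g.done() {
        let l = g.pick(n) as usize;
        let xs: Vec<u32> =
            std::iter::repeat_with(|| g.pick(m)).take(l).collect();
        let mut sorted = xs.clone();
        buggy_sort(&mut sorted);
        if sorted.windows(2).any(|w| w[0] > w[1]) {
            return Some(xs);
        }
    }
    None
}
```

With n = 5 and m = 4, the first failure is the three-element array [1, 1, 0]: shorter arrays are always fixed by a single pass, and every three-element array generated before it happens to come out sorted.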
That’s all! Next time you need to test something, consider if you can just exhaustively enumerate all “sufficiently small” inputs. If that’s feasible, you can either write the classical recursive enumerator, or use this imperative Gen thing.
Update(2021-11-28):
There are now Rust (crates.io link) and C++ (GitHub link) implementations. “Capturing the Future by Replaying the Past” is a related paper which includes the above technique as a special case of the “simulate any monad by simulating delimited continuations via exceptions and replay” trick.
Balanced parentheses sequences:
fn gen_parenthesis() {
    let n = 5;
    let expected_total = 1 + 1 + 2 + 5 + 14 + 42;
    let mut total = 0;
    let mut all = HashSet::new();

    let mut g = Gen::new();
    while !g.done() {
        let l = g.gen(n);
        let mut s = String::new();
        let mut bra = 0;
        let mut ket = 0;
        while ket < l {
            if bra < l && (bra == ket || g.gen(1) == 1) {
                s.push('(');
                bra += 1;
            } else {
                s.push(')');
                ket += 1;
            }
        }

        all.insert(s);
        total += 1;
    }

    assert_eq!(expected_total, total);
    assert_eq!(expected_total, all.len() as u32);
}