Properly Testing Concurrent Data Structures

There's a fascinating Rust library, loom, which can be used to thoroughly test lock-free data structures. I always wanted to learn how it works. I still do! But recently I accidentally implemented a small toy which, I think, contains some of loom's ideas, and it seems worthwhile to write about that. The goal here isn't to teach you what you should be using in practice (if you need that, go read loom's docs), but rather to derive a couple of neat ideas from first principles.

One, Two, Three, Two

As usual, we need the simplest possible model program to mess with. The example we use comes from this excellent article. Behold, a humble (and broken) concurrent counter:

use std::sync::atomic::{
  AtomicU32,
  Ordering::SeqCst,
};

#[derive(Default)]
pub struct Counter {
  value: AtomicU32,
}

impl Counter {
  pub fn increment(&self) {
    let value = self.value.load(SeqCst);
    self.value.store(value + 1, SeqCst);
  }

  pub fn get(&self) -> u32 {
    self.value.load(SeqCst)
  }
}

The bug is obvious here: the increment is not atomic. But what is the best test we can write to expose it?
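For concreteness, here is the interleaving that loses an update, and the one a good test should be able to provoke (a comment-only sketch; both threads start with the counter at zero):

// Thread A: load         -> reads 0
// Thread B: load         -> reads 0
// Thread A: store 0 + 1  -> counter is 1
// Thread B: store 0 + 1  -> counter is still 1

Two increments ran, but the counter advanced only by one.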

Trivial Test

The simplest idea that comes to mind is to just hammer the same counter from multiple threads and check the result at the end:

#[test]
fn threaded_test() {
  let counter = Counter::default();

  let thread_count = 100;
  let increment_count = 100;

  std::thread::scope(|scope| {
    for _ in 0..thread_count {
      scope.spawn(|| {
        for _ in 0..increment_count {
          counter.increment()
        }
      });
    }
  });

  assert_eq!(counter.get(), thread_count * increment_count);
}

This fails successfully:

thread 'counter::trivial' panicked:
assertion `left == right` failed
  left: 9598
 right: 10000

But I wouldn't call this test satisfactory: it very much depends on the timing, so you can't reproduce it deterministically and you can't debug it. You also can't minimize it: if you reduce the number of threads and increments, chances are the test passes by luck!

PBT

Of course the temptation is to apply property-based testing here! The problem almost fits: we have easy-to-generate input (the sequence of increments spread over several threads), a good property to check (the result of concurrent increments is identical to that of sequential execution), and the desire to minimize the test.

But just how can we plug threads into a property-based test?

PBTs are great for testing state machines. You can run your state machine through a series of steps where at each step a PBT selects an arbitrary next action to apply to the state:

#[test]
fn state_machine_test() {
  arbtest::arbtest(|rng| {
    // This is our state machine!
    let mut state: i32 = 0;

    // We'll run it for up to 100 steps.
    let step_count: usize = rng.int_in_range(0..=100)?;

    for _ in 0..step_count {
      // At each step, we flip a coin and
      // either increment or decrement.
      match *rng.choose(&["inc", "dec"])? {
        "inc" => state += 1,
        "dec" => state -= 1,
        _ => unreachable!(),
      }
    }
    Ok(())
  });
}

And it feels like we should be able to apply the same technique here. At every iteration, pick a random thread and make it do a single step. If you can step the threads manually, it should be easy to maneuver one thread in between the load and the store of a different thread.

But we cant step through threads! Or can we?

Simple Instrumentation

Ok, let's fake it until we make it! Let's take a look at the buggy increment method:

pub fn increment(&self) {
  let value = self.value.load(SeqCst);
  self.value.store(value + 1, SeqCst);
}

Ideally, we'd love to be able to somehow pause the thread in between atomic operations. Something like this:

pub fn increment(&self) {
  pause();
  let value = self.value.load(SeqCst);
  pause();
  self.value.store(value + 1, SeqCst);
  pause();
}

fn pause() {
    // ¯\_(ツ)_/¯
}

So let's start by implementing our own wrapper for AtomicU32 which includes calls to pause:

use std::sync::atomic::Ordering;

struct AtomicU32 {
  inner: std::sync::atomic::AtomicU32,
}

impl AtomicU32 {
  pub fn load(&self, ordering: Ordering) -> u32 {
    pause();
    let result = self.inner.load(ordering);
    pause();
    result
  }

  pub fn store(&self, value: u32, ordering: Ordering) {
    pause();
    self.inner.store(value, ordering);
    pause();
  }
}

fn pause() {
  // still no idea :(
}

Managed Threads API

One rule of great API design is that you start by implementing a single user of the API, to understand how the API should feel, and only then proceed to the actual implementation.

So, in the spirit of faking it, let's just write a PBT using these pausable, managed threads, even if we still have no idea how to actually implement pausing.

We start with creating a counter and two managed threads. And we probably want to pass a reference to the counter to each of the threads:

let counter = Counter::default();
let t1 = managed_thread::spawn(&counter);
let t2 = managed_thread::spawn(&counter);

Now, we want to step through the threads:

while !rng.is_empty() {
  let coin_flip: bool = rng.arbitrary()?;
  if t1.is_paused() {
    if coin_flip {
      t1.unpause();
    }
  }
  if t2.is_paused() {
    if coin_flip {
      t2.unpause();
    }
  }
}

Or, refactoring this a bit to compress it semantically:

let counter = Counter::default();
let t1 = managed_thread::spawn(&counter);
let t2 = managed_thread::spawn(&counter);
let mut threads = [t1, t2];

while !rng.is_empty() {
  for t in &mut threads {
    if t.is_paused() && rng.arbitrary()? {
      t.unpause()
    }
  }
}

That is, on each step of our state machine, we loop through all threads and unpause a random subset of them.

But besides pausing and unpausing, we need our threads to actually do something, to increment the counter. One idea is to mirror the std::thread::spawn API and pass a closure in:

let t1 = managed_thread::spawn({
  let counter = &counter;
  move || {
    for _ in 0..100 {
      counter.increment();
    }
  }
});

But as these are managed threads, and we want to control them from our tests, let's actually go all the way and give the controlling thread the ability to change the code running in a managed thread. That is, we'll start managed threads without a main function, and provide an API to execute arbitrary closures in the context of this by-default inert thread (universal server, anyone?):

let counter = Counter::default();

// We pass the state, &counter, in, but otherwise the thread is inert.
let t = managed_thread::spawn(&counter);

// But we can manually poke it:
t.submit(|thread_state: &Counter| thread_state.increment());
t.submit(|thread_state: &Counter| thread_state.increment());

Putting everything together, we get a nice-looking property test:

#[cfg(test)]
use managed_thread::AtomicU32;
#[cfg(not(test))]
use std::sync::atomic::AtomicU32;

#[derive(Default)]
pub struct Counter {
  value: AtomicU32,
}

impl Counter {
  // ...
}

#[test]
fn test_counter() {
  arbtest::arbtest(|rng| {
    // Our "Concurrent System Under Test".
    let counter = Counter::default();

    // The sequential model we'll compare the result against.
    let mut counter_model: u32 = 0;

    // Two managed threads which we will be stepping through
    // manually.
    let t1 = managed_thread::spawn(&counter);
    let t2 = managed_thread::spawn(&counter);
    let mut threads = [t1, t2];

    // Bulk of the test: in a loop, flip a coin and advance
    // one of the threads.
    while !rng.is_empty() {
      for t in &mut threads {
        if rng.arbitrary()? {
          if t.is_paused() {
            t.unpause()
          } else {
            // Standard "model equivalence" property: apply
            // isomorphic actions to the system and its model.
            t.submit(|c| c.increment());
            counter_model += 1;
          }
        }
      }
    }

    for t in threads {
      t.join();
    }

    assert_eq!(counter_model, counter.get());

    Ok(())
  });
}

Now, if only we could make this API work. Remember, our pause implementation is a shrug emoji!

At this point, you might be mightily annoyed at me for this rhetorical device where I pretend that I don't know the answer. No need for annoyance: when writing this code for the first time, I traced exactly these steps. I realized that I needed a pausing AtomicU32, so I implemented that (with dummy pause calls); then I played with the API I wanted to have, ending at roughly this spot, without yet knowing how I would make it work or, indeed, whether it was possible at all.

Well, if I am being honest, there is a bit of up-front knowledge here. I don't think we can avoid spawning real threads, unless we do something really cursed with inline assembly. When something calls that pause() function, and we want it to stay paused until further notice, that just has to happen in a thread which maintains a stack separate from the stack of our test. And, if we are going to spawn threads, we might as well spawn scoped threads, so that we can freely borrow stack-local data. And to spawn a scoped thread, you need a Scope parameter. So in reality we'll need one more level of indentation here:

    std::thread::scope(|scope| {
      let t1 = managed_thread::spawn(scope, &counter);
      let t2 = managed_thread::spawn(scope, &counter);
      let mut threads = [t1, t2];
      while !rng.is_empty() {
        for t in &mut threads {
          // ...
        }
      }
    });

Managed Threads Implementation

Now, the fun part: how the heck are we going to make pausing and unpausing work? For starters, there clearly needs to be some communication between the main thread (t.unpause()) and the managed thread (pause()). And, because we don't want to change the Counter API to thread through some kind of test-only context, the context needs to be smuggled. So thread_local! it is. And this context is going to be shared between two threads, so it must be wrapped in an Arc.

struct SharedContext {
  // 🤷
}

thread_local! {
  static INSTANCE: RefCell<Option<Arc<SharedContext>>> =
    RefCell::new(None);
}

impl SharedContext {
  fn set(ctx: Arc<SharedContext>) {
    INSTANCE.with(|it| *it.borrow_mut() = Some(ctx));
  }

  fn get() -> Option<Arc<SharedContext>> {
    INSTANCE.with(|it| it.borrow().clone())
  }
}

As usual when using thread_local! or lazy_static!, it is convenient to immediately wrap it into better-typed accessor functions. And, given that we are using an Arc here anyway, we can conveniently escape the thread local by cloning the Arc.

So now we can finally implement the global pause function (or at least kick the proverbial can a little bit farther):

fn pause() {
  if let Some(ctx) = SharedContext::get() {
    ctx.pause()
  }
}

impl SharedContext {
  fn pause(&self) {
    // 😕
  }
}

Ok, what to do next? We somehow need to coordinate the control thread and the managed thread. And we need some sort of notification mechanism, so that the managed thread knows when it can continue. The most brute-force solution here is a pair of a mutex protecting some state and a condition variable. The mutex guards the state that can be manipulated by either of the threads; the condition variable is used to signal changes.

struct SharedContext {
  state: Mutex<State>,
  cv: Condvar,
}

struct State {
  // 🤡
}

Okay, it looks like I am running out of emojis here. There are no more layers of indirection or infrastructure left; we need to write some real code that actually does the pausing. So let's say that the state tracks, well, the state of our managed thread, which can be either running or paused:

#[derive(PartialEq, Eq, Debug, Default)]
enum State {
  #[default]
  Running,
  Paused,
}

And then the logic of the pause function: flip the state from Running to Paused, notify the controlling thread that we are Paused, and wait until the controlling thread flips our state back to Running:

impl SharedContext {
  fn pause(&self) {
    let mut guard = self.state.lock().unwrap();
    assert_eq!(*guard, State::Running);
    *guard = State::Paused;
    self.cv.notify_all();
    while *guard == State::Paused {
      guard = self.cv.wait(guard).unwrap();
    }
    assert_eq!(*guard, State::Running);
  }
}

Aside: Rust's API for condition variables is beautiful. Condvars are tricky, and I didn't really understand them until I saw the signatures of the Rust functions. Notice how the wait function takes a mutex guard as an argument, and returns a mutex guard. This protects you from logical races and guides you towards the standard pattern of using condvars:

First, you lock the mutex around the shared state. Then you inspect whether the state is what you need. If that's the case, great: you do what you wanted to do and unlock the mutex. If not, then, while still holding the mutex, you wait on the condition variable. That means the mutex gets unlocked, and other threads get the chance to change the shared state. When they do change it, and notify the condvar, your thread wakes up, and it gets the locked mutex back (but the state now is different). Due to the possibility of spurious wake-ups, you need to double-check the state and be ready to loop back to waiting.
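Here is the same pattern spelled out as code, a minimal generic sketch (wait_until_true is made up for illustration; it is not part of the library we are building):

use std::sync::{Condvar, Mutex};

// Block the current thread until another thread sets the flag
// to true and notifies the condvar.
fn wait_until_true(flag: &Mutex<bool>, cv: &Condvar) {
  let mut guard = flag.lock().unwrap();
  while !*guard {
    // `wait` atomically unlocks the mutex and blocks; on wake-up,
    // it re-locks the mutex and hands the guard back. The loop
    // re-checks the flag to handle spurious wake-ups.
    guard = cv.wait(guard).unwrap();
  }
}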

Naturally, there's a helper that encapsulates this whole pattern:

impl SharedContext {
  fn pause(&self) {
    let mut guard = self.state.lock().unwrap();
    assert_eq!(*guard, State::Running);
    *guard = State::Paused;
    self.cv.notify_all();
    guard = self
      .cv
      .wait_while(guard, |state| *state == State::Paused)
      .unwrap();
    assert_eq!(*guard, State::Running)
  }
}

Ok, this actually does look like a reasonable implementation of pause. Let's move on to managed_thread::spawn:

fn spawn<'scope, T: 'scope + Send>(
  scope: &Scope<'scope, '_>,
  state: T,
) {
  // ? ? ?? ??? ?????
}

There's a bunch of stuff that needs to happen here:

  • As we have established, we are going to spawn a (scoped) thread, so we need the scope parameter with its three lifetimes. I don't know how it works, so I am just going by the docs here!
  • We are going to return some kind of handle, which we can use to pause and unpause our managed thread. And that handle is going to be parametrized over the same 'scope lifetime, because it'll hold onto the actual join handle.
  • We are going to pass the generic state to our new thread, and that state needs to be Send, and bounded by the same lifetime as our scoped thread.
  • Inside, we are going to spawn a thread for sure, and well need to setup the INSTANCE thread local on that thread.
  • And it would actually be a good idea to stuff a reference to that SharedContext into the handle we return.

A bunch of stuff, in other words. Let's do it:

struct ManagedHandle<'scope> {
  inner: std::thread::ScopedJoinHandle<'scope, ()>,
  ctx: Arc<SharedContext>,
}

fn spawn<'scope, T: 'scope + Send>(
  scope: &'scope Scope<'scope, '_>,
  state: T,
) -> ManagedHandle<'scope> {
  let ctx: Arc<SharedContext> = Default::default();
  let inner = scope.spawn({
    let ctx = Arc::clone(&ctx);
    move || {
      SharedContext::set(ctx);
      drop(state); // TODO: ¿
    }
  });
  ManagedHandle { inner, ctx }
}

The essentially no-op function we spawn looks sus. We'll fix it later! Let's try to implement is_paused and unpause first; they should be relatively straightforward. For is_paused, we just need to lock the mutex and check the state:

impl ManagedHandle<'_> {
  pub fn is_paused(&self) -> bool {
    let guard = self.ctx.state.lock().unwrap();
    *guard == State::Paused
  }
}

For unpause, we should additionally flip the state back to Running and notify the other thread:

impl ManagedHandle<'_> {
  pub fn unpause(&self) {
    let mut guard = self.ctx.state.lock().unwrap();
    assert_eq!(*guard, State::Paused);
    *guard = State::Running;
    self.ctx.cv.notify_all();
  }
}

But I think that's not quite correct. Can you see why?

With this implementation, after unpause, the controlling and the managed threads will be running concurrently. And that can lead to non-determinism, the very problem we are trying to avoid here! In particular, if you call is_paused right after you unpause the thread, you'll most likely get false back, as the other thread will still be running. But it might also have already hit the next pause call, so, depending on timing, you might also get true.
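Here is one possible timeline, as a comment sketch (which outcome you get depends on the scheduler):

// controlling thread          managed thread
// t.unpause()                 wakes up, resumes user code
// t.is_paused() -> false      still running between pauses
//                             reaches the next pause(), state = Paused
// t.is_paused() -> true       same call site, different answer!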

What we actually want is to completely eliminate all unmanaged concurrency. That means that at any given point in time, only one thread (controlling or managed) should be running. So the right semantics for unpause is to unblock the managed thread, and then block the controlling thread until the managed one hits the next pause!

impl ManagedHandle<'_> {
  pub fn unpause(&self) {
    let mut guard = self.ctx.state.lock().unwrap();
    assert_eq!(*guard, State::Paused);
    *guard = State::Running;
    self.ctx.cv.notify_all();
    guard = self
      .ctx
      .cv
      .wait_while(guard, |state| *state == State::Running)
      .unwrap();
  }
}

At this point we can spawn a managed thread, pause it, and resume it. But right now it doesn't do anything. The next step is implementing the idea where the controlling thread can directly send an arbitrary closure to the managed one to make it do something:

impl<'scope> ManagedHandle<'scope> {
  pub fn submit<F: FnSomething>(&self, f: F)
}

Let's figure out this FnSomething bound! We are going to yeet this f over to the managed thread and run it there once, so it is FnOnce. It is crossing a thread boundary, so it needs to be + Send. And, because we are using scoped threads, it doesn't have to be 'static; just 'scope is enough. Moreover, in the managed thread, f will have exclusive access to the thread's state, T. So we have:

impl<'scope> ManagedHandle<'scope> {
  pub fn submit<F: FnOnce(&mut T) + Send + 'scope>(&self, f: F)
}

Implementing this is a bit tricky. First, we'll need some sort of channel to actually move the function. Then, similarly to the unpause logic, we'll need synchronization to make sure that the control thread doesn't resume until the managed thread starts running f and hits a pause (or maybe completes f). And we'll also need a new state, Ready, because now there are two different reasons why a managed thread might be blocked: it might wait for an unpause event, or it might wait for the next f to execute. This is the new code:

#[derive(Default, PartialEq, Eq, Debug)]
enum State {
  #[default]
  Ready,
  Running,
  Paused,
}

struct ManagedHandle<'scope, T> {
  inner: std::thread::ScopedJoinHandle<'scope, ()>,
  ctx: Arc<SharedContext>,
  sender: mpsc::Sender<Box<dyn FnOnce(&mut T) + 'scope + Send>>,
}

pub fn spawn<'scope, T: 'scope + Send>(
  scope: &'scope Scope<'scope, '_>,
  mut state: T,
) -> ManagedHandle<'scope, T> {
  let ctx: Arc<SharedContext> = Default::default();
  let (sender, receiver) =
    mpsc::channel::<Box<dyn FnOnce(&mut T) + 'scope + Send>>();
  let inner = scope.spawn({
    let ctx = Arc::clone(&ctx);
    move || {
      SharedContext::set(Arc::clone(&ctx));

      for f in receiver {
        f(&mut state);

        let mut guard = ctx.state.lock().unwrap();
        assert_eq!(*guard, State::Running);
        *guard = State::Ready;
        ctx.cv.notify_all()
      }
    }
  });
  ManagedHandle { inner, ctx, sender }
}

impl<'scope, T> ManagedHandle<'scope, T> {
  pub fn submit<F: FnOnce(&mut T) + Send + 'scope>(&self, f: F) {
    let mut guard = self.ctx.state.lock().unwrap();
    assert_eq!(*guard, State::Ready);
    *guard = State::Running;
    self.sender.send(Box::new(f)).unwrap();
    guard = self
      .ctx
      .cv
      .wait_while(guard, |state| *state == State::Running)
      .unwrap();
  }
}

The last small piece of the puzzle is the join function. It's almost standard! First we close our side of the channel. This serves as a natural stop signal for the other thread, so it exits, which in turn allows us to join it. The small wrinkle here is that the thread might be paused when we try to join it, so we need to unpause it beforehand:

impl<'scope, T> ManagedHandle<'scope, T> {
  pub fn join(self) {
    while self.is_paused() {
      self.unpause();
    }
    drop(self.sender);
    self.inner.join().unwrap();
  }
}

That's it! Let's put everything together!

Helper library, managed_thread.rs:

use std::{
  cell::RefCell,
  sync::{atomic::Ordering, mpsc, Arc, Condvar, Mutex},
  thread::Scope,
};

#[derive(Default)]
pub struct AtomicU32 {
  inner: std::sync::atomic::AtomicU32,
}

impl AtomicU32 {
  pub fn load(&self, ordering: Ordering) -> u32 {
    pause();
    let result = self.inner.load(ordering);
    pause();
    result
  }

  pub fn store(&self, value: u32, ordering: Ordering) {
    pause();
    self.inner.store(value, ordering);
    pause();
  }
}

fn pause() {
  if let Some(ctx) = SharedContext::get() {
    ctx.pause()
  }
}

#[derive(Default)]
struct SharedContext {
  state: Mutex<State>,
  cv: Condvar,
}

#[derive(Default, PartialEq, Eq, Debug)]
enum State {
  #[default]
  Ready,
  Running,
  Paused,
}

thread_local! {
  static INSTANCE: RefCell<Option<Arc<SharedContext>>> =
    RefCell::new(None);
}

impl SharedContext {
  fn set(ctx: Arc<SharedContext>) {
    INSTANCE.with(|it| *it.borrow_mut() = Some(ctx));
  }

  fn get() -> Option<Arc<SharedContext>> {
    INSTANCE.with(|it| it.borrow().clone())
  }

  fn pause(&self) {
    let mut guard = self.state.lock().unwrap();
    assert_eq!(*guard, State::Running);
    *guard = State::Paused;
    self.cv.notify_all();
    guard = self
      .cv
      .wait_while(guard, |state| *state == State::Paused)
      .unwrap();
    assert_eq!(*guard, State::Running)
  }
}

pub struct ManagedHandle<'scope, T> {
  inner: std::thread::ScopedJoinHandle<'scope, ()>,
  sender: mpsc::Sender<Box<dyn FnOnce(&mut T) + 'scope + Send>>,
  ctx: Arc<SharedContext>,
}

pub fn spawn<'scope, T: 'scope + Send>(
  scope: &'scope Scope<'scope, '_>,
  mut state: T,
) -> ManagedHandle<'scope, T> {
  let ctx: Arc<SharedContext> = Default::default();
  let (sender, receiver) =
    mpsc::channel::<Box<dyn FnOnce(&mut T) + 'scope + Send>>();
  let inner = scope.spawn({
    let ctx = Arc::clone(&ctx);
    move || {
      SharedContext::set(Arc::clone(&ctx));
      for f in receiver {
        f(&mut state);
        let mut guard = ctx.state.lock().unwrap();
        assert_eq!(*guard, State::Running);
        *guard = State::Ready;
        ctx.cv.notify_all()
      }
    }
  });
  ManagedHandle { inner, ctx, sender }
}

impl<'scope, T> ManagedHandle<'scope, T> {
  pub fn is_paused(&self) -> bool {
    let guard = self.ctx.state.lock().unwrap();
    *guard == State::Paused
  }

  pub fn unpause(&self) {
    let mut guard = self.ctx.state.lock().unwrap();
    assert_eq!(*guard, State::Paused);
    *guard = State::Running;
    self.ctx.cv.notify_all();
    guard = self
      .ctx
      .cv
      .wait_while(guard, |state| *state == State::Running)
      .unwrap();
  }

  pub fn submit<F: FnOnce(&mut T) + Send + 'scope>(&self, f: F) {
    let mut guard = self.ctx.state.lock().unwrap();
    assert_eq!(*guard, State::Ready);
    *guard = State::Running;
    self.sender.send(Box::new(f)).unwrap();
    guard = self
      .ctx
      .cv
      .wait_while(guard, |state| *state == State::Running)
      .unwrap();
  }

  pub fn join(self) {
    while self.is_paused() {
      self.unpause();
    }
    drop(self.sender);
    self.inner.join().unwrap();
  }
}

System under test, not-exactly-atomic counter:

use std::sync::atomic::Ordering::SeqCst;

#[cfg(test)]
use managed_thread::AtomicU32;
#[cfg(not(test))]
use std::sync::atomic::AtomicU32;

#[derive(Default)]
pub struct Counter {
  value: AtomicU32,
}

impl Counter {
  pub fn increment(&self) {
    let value = self.value.load(SeqCst);
    self.value.store(value + 1, SeqCst);
  }

  pub fn get(&self) -> u32 {
    self.value.load(SeqCst)
  }
}

And the test itself:

#[test]
fn test_counter() {
  arbtest::arbtest(|rng| {
    eprintln!("begin trace");
    let counter = Counter::default();
    let mut counter_model: u32 = 0;

    std::thread::scope(|scope| {
      let t1 = managed_thread::spawn(scope, &counter);
      let t2 = managed_thread::spawn(scope, &counter);
      let mut threads = [t1, t2];

      while !rng.is_empty() {
        for (tid, t) in threads.iter_mut().enumerate() {
          if rng.arbitrary()? {
            if t.is_paused() {
              eprintln!("{tid}: unpause");
              t.unpause()
            } else {
              eprintln!("{tid}: increment");
              t.submit(|c| c.increment());
              counter_model += 1;
            }
          }
        }
      }

      for t in threads {
        t.join();
      }
      assert_eq!(counter_model, counter.get());

      Ok(())
    })
  });
}

Running it identifies a failure:

---- test_counter stdout ----
begin trace
0: increment
1: increment
0: unpause
1: unpause
1: unpause
0: unpause
0: unpause
1: unpause
0: unpause
0: increment
1: unpause
0: unpause
1: increment
0: unpause
0: unpause
1: unpause
0: unpause
thread 'test_counter' panicked at src/lib.rs:56:7:
assertion `left == right` failed
  left: 4
 right: 3

arbtest failed!
    Seed: 0x4fd7ddff00000020

Which is something we got like 5% into this article already, with normal threads! But there's more to this failure. First, it is reproducible. If I specify the same seed, I get the exact same interleaving:

#[test]
fn test_counter() {
  arbtest::arbtest(|rng| {
    eprintln!("begin trace");
    ...
  })
    .seed(0x4fd7ddff00000020);
}

And this is completely machine-independent! If you specify this seed, you'll get the exact same interleaving. So, if I am having trouble debugging this, I can DM you this hex in Zulip, and you'll be able to help out!

But there's more: we don't need to debug this failure, we can minimize it!

#[test]
fn test_counter() {
  arbtest::arbtest(|rng| {
    eprintln!("begin trace");
    ...
  })
    .seed(0x4fd7ddff00000020)
    .minimize();
}

This gives me the following minimization trace:

begin trace
0: increment
1: increment
0: unpause
1: unpause
1: unpause
0: unpause
0: unpause
1: unpause
0: unpause
0: increment
1: unpause
0: unpause
1: increment
0: unpause
0: unpause
1: unpause
0: unpause
seed 0x4fd7ddff00000020, seed size 32, search time 106.00ns

begin trace
0: increment
1: increment
0: unpause
0: unpause
1: unpause
0: unpause
1: unpause
0: unpause
1: unpause
1: unpause
1: increment
seed 0x540c0c1c00000010, seed size 16, search time 282.16µs

begin trace
0: increment
1: increment
0: unpause
1: unpause
1: unpause
1: unpause
seed 0x084ca71200000008, seed size 8, search time 805.74µs

begin trace
0: increment
1: increment
0: unpause
1: unpause
seed 0x5699b19400000004, seed size 4, search time 1.44ms

begin trace
0: increment
1: increment
0: unpause
1: unpause
seed 0x4bb0ea5c00000002, seed size 2, search time 4.03ms

begin trace
0: increment
1: increment
0: unpause
1: unpause
seed 0x9c2a13a600000001, seed size 1, search time 4.31ms

minimized
seed 0x9c2a13a600000001, seed size 1, search time 100.03ms

That is, we ended up with this tiny, minimal example:

#[test]
fn test_counter() {
  arbtest::arbtest(|rng| {
    eprintln!("begin trace");
    ...
  })
    .seed(0x9c2a13a600000001);
}

begin trace
0: increment
1: increment
0: unpause
1: unpause

And this is how you properly test concurrent data structures.

Postscript

Of course, this is just a toy. But you can see some ways to extend it. For example, right now our AtomicU32 just delegates to the real one. What you could do instead is, for each atomic, maintain a set of values written and, on read, return an arbitrary written value consistent with a weak memory model.
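Here is a rough sketch of that direction, living inside our managed_thread.rs next to its AtomicU32 (so it reuses the pause hook and the Mutex and Ordering imports). The WeakAtomicU32 name and the choose parameter are made up for illustration, and a real weak memory model would also constrain which writes remain visible to which loads, which this sketch completely ignores:

pub struct WeakAtomicU32 {
  // Every value ever stored. A real implementation would also record
  // ordering metadata to decide which of these a given load may observe.
  history: Mutex<Vec<u32>>,
}

impl WeakAtomicU32 {
  pub fn new(value: u32) -> WeakAtomicU32 {
    WeakAtomicU32 { history: Mutex::new(vec![value]) }
  }

  pub fn store(&self, value: u32, _ordering: Ordering) {
    pause();
    self.history.lock().unwrap().push(value);
    pause();
  }

  // `choose` stands in for the test's source of randomness: given the
  // number of recorded writes, it picks the index of the write to observe.
  pub fn load(
    &self,
    _ordering: Ordering,
    choose: impl FnOnce(usize) -> usize,
  ) -> u32 {
    pause();
    let history = self.history.lock().unwrap();
    let result = history[choose(history.len())];
    pause();
    result
  }
}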

You could also be smarter with exploring interleavings. Instead of interleaving threads at random, like we do here, you can try to apply model checking approaches and prove that you have considered all meaningfully different interleavings.

Or you can apply the approach from Generate All The Things and exhaustively enumerate all interleavings for up to, say, five increments. In fact, why don't we just do this?

$ cargo add exhaustigen

#[test]
fn exhaustytest() {
  let mut g = exhaustigen::Gen::new();
  let mut interleavings_count = 0;

  while !g.done() {
    interleavings_count += 1;
    let counter = Counter::default();
    let mut counter_model: u32 = 0;

    let increment_count = g.gen(5) as u32;
    std::thread::scope(|scope| {
      let t1 = managed_thread::spawn(scope, &counter);
      let t2 = managed_thread::spawn(scope, &counter);

      'outer: while t1.is_paused()
        || t2.is_paused()
        || counter_model < increment_count
      {
        for t in [&t1, &t2] {
          if g.flip() {
            if t.is_paused() {
              t.unpause();
              continue 'outer;
            }
            if counter_model < increment_count {
              t.submit(|c| c.increment());
              counter_model += 1;
              continue 'outer;
            }
          }
        }
        return for t in [t1, t2] {
          t.join()
        };
      }

      assert_eq!(counter_model, counter.get());
    });
  }
  eprintln!("interleavings_count = {:?}", interleavings_count);
}

The shape of the test is more or less the same, except that we need to make sure that there are no dummy iterations, and that we always either unpause a thread or submit an increment.

It finds the same bug, naturally:

thread 'exhaustytest' panicked at src/lib.rs:103:7:
assertion `left == right` failed
  left: 2
 right: 1

But the cool thing is, if we fix the issue by using an atomic increment,

impl AtomicU32 {
  pub fn fetch_add(
    &self,
    value: u32,
    ordering: Ordering,
  ) -> u32 {
    pause();
    let result = self.inner.fetch_add(value, ordering);
    pause();
    result
  }
}

impl Counter {
  pub fn increment(&self) {
    self.value.fetch_add(1, SeqCst);
  }
}

we can get a rather specific correctness statement out of our test: any sequence of at most five increments gives the correct result:

$ t cargo t -r -- exhaustytest --nocapture
running 1 test
all 81133 interleavings are fine!
test exhaustytest ... ok

real 8.65s
cpu  8.16s (2.22s user + 5.94s sys)
rss  63.91mb

And the last small thing. Recall that our PBT minimized the first sequence it found:

begin trace
0: increment
1: increment
0: unpause
1: unpause
1: unpause
0: unpause
0: unpause
1: unpause
0: unpause
0: increment
1: unpause
0: unpause
1: increment
0: unpause
0: unpause
1: unpause
0: unpause
thread 'test_counter' panicked at src/lib.rs:56:7:
assertion `left == right` failed
  left: 4
 right: 3

arbtest failed!
    Seed: 0x4fd7ddff00000020

down to just

begin trace
0: increment
1: increment
0: unpause
1: unpause
thread 'test_counter' panicked at src/lib.rs:57:7:
assertion `left == right` failed
  left: 2
 right: 1

arbtest failed!
    Seed: 0x9c2a13a600000001

But we never implemented shrinking! How is this possible? Well, strictly speaking, this is out of scope for this post. And I've already described this elsewhere. And, at 32k, this is the third-longest post on this blog. And it's 3AM here in Lisbon right now. But of course I'll explain!

The trick is the simplified hypothesis approach. The arbtest PBT library we use in this post is based on the familiar interface of a PRNG:

arbtest::arbtest(|rng| {
  let random_int: usize = rng.int_in_range(0..=100)?;
  let random_bool: bool = rng.arbitrary()?;
  Ok(())
});

But there's a twist! This is a finite PRNG. So, if you ask it to flip a coin, it can give you heads. And next time it might give you tails. But if you keep asking for more, at some point it'll give you Err(OutOfEntropy).

That's why all these ?s, and the outer loop of while !rng.is_empty().

In other words, as soon as the test runs out of entropy, it short-circuits and completes. And that means that reducing the amount of entropy available makes the test shorter, and this works irrespective of how complex the logic inside the test is!

And entropy is a big scary word here; what actually happens is that the PRNG is just an &mut &[u8] inside. That is, a slice of random bytes, which is shortened every time you ask for a random number. And the shorter the initial slice, the simpler the test gets. Minimization can be this simple!
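To make that concrete, here is a hand-rolled illustration of the same idea (the names are hypothetical; this is not arbtest's actual code):

struct OutOfEntropy;

// A "finite PRNG": a slice of bytes which shrinks on every request.
struct FinitePrng<'a> {
  bytes: &'a [u8],
}

impl<'a> FinitePrng<'a> {
  fn arbitrary_u8(&mut self) -> Result<u8, OutOfEntropy> {
    // Take one byte off the front, or fail if we have run dry.
    let (&first, rest) = self.bytes.split_first().ok_or(OutOfEntropy)?;
    self.bytes = rest;
    Ok(first)
  }

  fn arbitrary_bool(&mut self) -> Result<bool, OutOfEntropy> {
    Ok(self.arbitrary_u8()? & 1 == 1)
  }
}

A test driven by a shorter slice simply does less, so the minimizer can just re-run the test with fewer random bytes.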

You can find source code for this article at https://github.com/matklad/properly-concurrent