Properly Testing Concurrent Data Structures
There’s a fascinating Rust library, loom, which can be used to thoroughly test lock-free data structures. I have always wanted to learn how it works. I still do! But recently I accidentally implemented a small toy which, I think, contains some of loom’s ideas, and it seems worthwhile to write about that. The goal here isn’t to teach you what you should be using in practice (if you need that, go read loom’s docs), but rather to derive a couple of neat ideas from first principles.
One, Two, Three, Two
As usual, we need the simplest possible model program to mess with. The example we use comes from this excellent article. Behold, a humble (and broken) concurrent counter:
use std::sync::atomic::{
AtomicU32,
Ordering::SeqCst,
};
// `Default` lets the tests below call `Counter::default()`.
#[derive(Default)]
pub struct Counter {
value: AtomicU32,
}
impl Counter {
pub fn increment(&self) {
let value = self.value.load(SeqCst);
self.value.store(value + 1, SeqCst);
}
pub fn get(&self) -> u32 {
self.value.load(SeqCst)
}
}
The bug is obvious here — the increment is not atomic. But what is the best test we can write to expose it?
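To make the race concrete, here is a minimal sketch that replays the bad interleaving by hand on a single thread. The function name `lost_update` is mine and not part of the counter; the two "threads" are just two pairs of local variables.

```rust
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};

/// Replays the racy interleaving of two increments by hand
/// and returns the final counter value.
pub fn lost_update() -> u32 {
    let value = AtomicU32::new(0);

    // Both "threads" load before either of them stores.
    let seen_by_a = value.load(SeqCst); // A reads 0
    let seen_by_b = value.load(SeqCst); // B reads 0

    value.store(seen_by_a + 1, SeqCst); // A writes 1
    value.store(seen_by_b + 1, SeqCst); // B also writes 1: A's update is lost

    value.load(SeqCst)
}
```

Two increments happened, but the function returns 1. The question for the rest of the post is how to get a test to find this ordering on its own.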
Trivial Test
The simplest idea that comes to mind is to just hammer the same counter from multiple threads and check the result at the end:
fn threaded_test() {
let counter = Counter::default();
let thread_count = 100;
let increment_count = 100;
std::thread::scope(|scope| {
for _ in 0..thread_count {
scope.spawn(|| {
for _ in 0..increment_count {
counter.increment()
}
});
}
});
assert_eq!(counter.get(), thread_count * increment_count);
}
This fails successfully:
thread 'counter::trivial' panicked:
assertion `left == right` failed
left: 9598
right: 10000
But I wouldn’t call this test satisfactory — it very much depends on the timing, so you can’t reproduce it deterministically and you can’t debug it. You also can’t minimize it — if you reduce the number of threads and increments, chances are the test passes by luck!
PBT
Of course the temptation is to apply property based testing here! The problem almost fits: we have easy-to-generate input (the sequence of increments spread over several threads), a good property to check (result of concurrent increments is identical to that of sequential execution) and the desire to minimize the test.
But just how can we plug threads into a property-based test?
PBTs are great for testing state machines. You can run your state machine through a series of steps where at each step a PBT selects an arbitrary next action to apply to the state:
fn state_machine_test() {
arbtest::arbtest(|rng| {
// This is our state machine!
let mut state: i32 = 0;
// We'll run it for up to 100 steps.
let step_count: usize = rng.int_in_range(0..=100)?;
for _ in 0..step_count {
// At each step, we flip a coin and
// either increment or decrement.
match *rng.choose(&["inc", "dec"])? {
"inc" => state += 1,
"dec" => state -= 1,
_ => unreachable!(),
}
}
Ok(())
});
}
And it feels like we should be able to apply the same technique here. At every iteration, pick a random thread and make it do a single step. If you can step the threads manually, it should be easy to maneuver one thread in between load&store of a different thread.
But we can’t step through threads! Or can we?
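As a thought experiment, here is what stepping would look like if each thread were an explicit state machine instead of real code. This is only a sketch of the idea: `Step` and `run_schedule` are made-up names, the simulation hard-codes two threads doing one increment each, and nothing later in the post depends on it.

```rust
/// Where a simulated thread currently is inside `increment`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Step {
    /// About to load the shared value.
    Load,
    /// Has loaded and is about to store this new value.
    Store(u32),
    /// Finished its increment.
    Done,
}

/// Runs two simulated threads, one increment each.
/// Every entry of `schedule` (0 or 1) advances that thread by one step.
pub fn run_schedule(schedule: &[usize]) -> u32 {
    let mut shared: u32 = 0;
    let mut threads = [Step::Load; 2];
    for &tid in schedule {
        threads[tid] = match threads[tid] {
            Step::Load => Step::Store(shared + 1),
            Step::Store(value) => {
                shared = value;
                Step::Done
            }
            Step::Done => Step::Done,
        };
    }
    shared
}
```

With the schedule `[0, 0, 1, 1]` the threads run one after another and the result is 2; with `[0, 1, 0, 1]` both load before either stores and the result is 1. The hard part is getting the same control over real code that was not written as a state machine.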
Simple Instrumentation
Ok, let’s fake it until we make it! Let’s take a look at the buggy increment method:
pub fn increment(&self) {
let value = self.value.load(SeqCst);
self.value.store(value + 1, SeqCst);
}
Ideally, we’d love to be able to somehow “pause” the thread in-between atomic operations. Something like this:
pub fn increment(&self) {
pause();
let value = self.value.load(SeqCst);
pause();
self.value.store(value + 1, SeqCst);
pause();
}
fn pause() {
// ¯\_(ツ)_/¯
}
So let’s start with implementing our own wrapper for AtomicU32 which includes calls to pause.
use std::sync::atomic::Ordering;
struct AtomicU32 {
inner: std::sync::atomic::AtomicU32,
}
impl AtomicU32 {
pub fn load(&self, ordering: Ordering) -> u32 {
pause();
let result = self.inner.load(ordering);
pause();
result
}
pub fn store(&self, value: u32, ordering: Ordering) {
pause();
self.inner.store(value, ordering);
pause();
}
}
fn pause() {
// still no idea :(
}
Managed Threads API
One rule of great API design is that you start by implementing a single user of an API, to understand how the API should feel, and only then proceed to the actual implementation.
So, in the spirit of faking, let’s just write a PBT using these pausable, managed threads, even if we still have no idea how to actually implement pausing.
We start with creating a counter and two managed threads. And we probably want to pass a reference to the counter to each of the threads:
let counter = Counter::default();
let t1 = managed_thread::spawn(&counter);
let t2 = managed_thread::spawn(&counter);
Now, we want to step through the threads:
while !rng.is_empty() {
let coin_flip: bool = rng.arbitrary()?;
if t1.is_paused() {
if coin_flip {
t1.unpause();
}
}
if t2.is_paused() {
if coin_flip {
t2.unpause();
}
}
}
Or, refactoring this a bit to semantically compress:
let counter = Counter::default();
let t1 = managed_thread::spawn(&counter);
let t2 = managed_thread::spawn(&counter);
let mut threads = [t1, t2];
while !rng.is_empty() {
for t in &mut threads {
if t.is_paused() && rng.arbitrary()? {
t.unpause()
}
}
}
That is, on each step of our state machine, we loop through all threads and unpause a random subset of them.
But besides pausing and unpausing, we need our threads to actually do something, to increment the counter. One idea is to mirror the std::thread::spawn API and pass a closure in:
let t1 = managed_thread::spawn({
let counter = &counter;
move || {
for _ in 0..100 {
counter.increment();
}
}
});
But as these are managed threads, and we want to control them from our tests, let’s actually go all the way there and give the controlling thread the ability to change the code running in a managed thread. That is, we’ll start managed threads without a “main” function, and provide an API to execute arbitrary closures in the context of this by-default inert thread (universal server anyone?):
let counter = Counter::default();
// We pass the state, &counter, in, but otherwise the thread is inert.
let t = managed_thread::spawn(&counter);
// But we can manually poke it:
t.submit(|thread_state: &Counter| thread_state.increment());
t.submit(|thread_state: &Counter| thread_state.increment());
Putting everything together, we get a nice-looking property test:
// Our pausing wrapper replaces std::sync::atomic::AtomicU32.
use managed_thread::AtomicU32;
pub struct Counter {
value: AtomicU32,
}
impl Counter {
// ...
}
fn test_counter() {
arbtest::arbtest(|rng| {
// Our "Concurrent System Under Test".
let counter = Counter::default();
// The sequential model we'll compare the result against.
let mut counter_model: u32 = 0;
// Two managed threads which we will be stepping through
// manually.
let t1 = managed_thread::spawn(&counter);
let t2 = managed_thread::spawn(&counter);
let mut threads = [t1, t2];
// Bulk of the test: in a loop, flip a coin and advance
// one of the threads.
while !rng.is_empty() {
for t in &mut threads {
if rng.arbitrary()? {
if t.is_paused() {
t.unpause()
} else {
// Standard "model equivalence" property: apply
// isomorphic actions to the system and its model.
t.submit(|c| c.increment());
counter_model += 1;
}
}
}
}
for t in threads {
t.join();
}
assert_eq!(counter_model, counter.get());
Ok(())
});
}
Now, if only we could make this API work… Remember, our pause implementation is a shrug emoji!
At this point, you might be mightily annoyed at me for this rhetorical device where I pretend that I don’t know the answer. No need for annoyance: when writing this code for the first time, I traced exactly these steps. I realized that I needed a “pausing AtomicU32”, so I did that (with dummy pause calls), then I played with the API I wanted to have, ending at roughly this spot, without yet knowing how I would make it work or, indeed, whether it was possible at all.
Well, if I am being honest, there is a bit of up-front knowledge here. I don’t think we can avoid spawning real threads here, unless we do something really cursed with inline assembly. When something calls that pause() function, and we want it to stay paused until further notice, that just has to happen in a thread which maintains a stack separate from the stack of our test. And, if we are going to spawn threads, we might as well spawn scoped threads, so that we can freely borrow stack-local data. And to spawn a scoped thread, you need a Scope parameter. So in reality we’ll need one more level of indentation here:
std::thread::scope(|scope| {
let t1 = managed_thread::spawn(scope, &counter);
let t2 = managed_thread::spawn(scope, &counter);
let mut threads = [t1, t2];
while !rng.is_empty() {
for t in &mut threads {
// ...
}
}
});
Managed Threads Implementation
Now, the fun part: how the heck are we going to make pausing and unpausing work? For starters, there clearly needs to be some communication between the main thread (t.unpause()) and the managed thread (pause()). And, because we don’t want to change Counter API to thread some kind of test-only context, the context needs to be smuggled. So thread_local! it is. And this context is going to be shared between two threads, so it must be wrapped in an Arc.
struct SharedContext {
// 🤷
}
thread_local! {
static INSTANCE: RefCell<Option<Arc<SharedContext>>> =
RefCell::new(None);
}
impl SharedContext {
fn set(ctx: Arc<SharedContext>) {
INSTANCE.with(|it| *it.borrow_mut() = Some(ctx));
}
fn get() -> Option<Arc<SharedContext>> {
INSTANCE.with(|it| it.borrow().clone())
}
}
As usual when using thread_local! or lazy_static!, it is convenient to immediately wrap it into better typed accessor functions. And, given that we are using an Arc here anyway, we can conveniently escape thread_local’s with by cloning the Arc.
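One property of this setup is worth seeing in isolation: a thread-local is only visible on the thread that set it, so any thread that never called the setter sees None. Here is a small std-only sketch, with a String standing in for the shared context. The names `CONTEXT`, `set_context`, `get_context` and `context_is_per_thread` are hypothetical.

```rust
use std::cell::RefCell;
use std::sync::Arc;

thread_local! {
    // Each thread gets its own, initially empty, slot.
    static CONTEXT: RefCell<Option<Arc<String>>> = RefCell::new(None);
}

pub fn set_context(ctx: Arc<String>) {
    CONTEXT.with(|it| *it.borrow_mut() = Some(ctx));
}

pub fn get_context() -> Option<Arc<String>> {
    // Cloning the Arc lets the value outlive the `with` closure.
    CONTEXT.with(|it| it.borrow().clone())
}

/// Returns whether the context is visible on (this thread, a fresh thread).
pub fn context_is_per_thread() -> (bool, bool) {
    set_context(Arc::new("test context".to_string()));
    let here = get_context().is_some();
    let there = std::thread::spawn(|| get_context().is_some())
        .join()
        .unwrap();
    (here, there)
}
```

The calling thread sees the context it just set, while the freshly spawned thread does not. That is why each managed thread has to install the context itself.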
So now we finally can implement the global pause function (or at least can kick the proverbial can a little bit farther):
fn pause() {
if let Some(ctx) = SharedContext::get() {
ctx.pause()
}
}
impl SharedContext {
fn pause(&self) {
// 😕
}
}
Ok, what to do next? We somehow need to coordinate the control thread and the managed thread. And we need some sort of notification mechanism, so that the managed thread knows when it can continue. The most brute force solution here is a pair of a mutex protecting some state and a condition variable. Mutex guards the state that can be manipulated by either of the threads. Condition variable can be used to signal about the changes.
struct SharedContext {
state: Mutex<State>,
cv: Condvar,
}
struct State {
// 🤡
}
Okay, it looks like I am running out of emoji here. There are no more layers of indirection or infrastructure left; we need to write some real code that actually does the pausing. So let’s say that the state is tracking, well, the state of our managed thread, which can be either running or paused:
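// `PartialEq` and `Debug` are needed to use `assert_eq!` on the state.
#[derive(PartialEq, Eq, Debug)]
enum State {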
Running,
Paused,
}
And then the logic of the pause function: flip the state from Running to Paused, notify the controlling thread that we are Paused, and wait until the controlling thread flips our state back to Running:
impl SharedContext {
fn pause(&self) {
let mut guard = self.state.lock().unwrap();
assert_eq!(*guard, State::Running);
*guard = State::Paused;
self.cv.notify_all();
while *guard == State::Paused {
guard = self.cv.wait(guard).unwrap();
}
assert_eq!(*guard, State::Running);
}
}
Aside: Rust’s API for condition variables is beautiful. Condvars are tricky, and I didn’t really understand them until seeing the signatures of Rust functions. Notice how the wait function takes a mutex guard as an argument, and returns a mutex guard. This protects you from the logical races and guides you towards the standard pattern of using condvars:
First, you lock the mutex around the shared state. Then, you inspect whether the state is what you need. If that’s the case, great, you do what you wanted to do and unlock the mutex. If not, then, while still holding the mutex, you wait on the condition variable. Which means that the mutex gets unlocked, and other threads get the chance to change the shared state. When they do change it, and notify the condvar, your thread wakes up, and it gets the locked mutex back (but the state now is different). Due to the possibility of spurious wake-ups, you need to double check the state and be ready to loop back again to waiting.
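The pattern described above can be sketched with nothing but a boolean flag. This is an illustration separate from the managed-thread code, and `wait_for_flag` is a made-up name.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

/// Blocks until a helper thread sets the flag, then returns the flag's value.
pub fn wait_for_flag() -> bool {
    let pair = Arc::new((Mutex::new(false), Condvar::new()));

    let worker = {
        let pair = Arc::clone(&pair);
        thread::spawn(move || {
            let (lock, cv) = &*pair;
            // Change the shared state under the mutex, then notify.
            *lock.lock().unwrap() = true;
            cv.notify_all();
        })
    };

    let (lock, cv) = &*pair;
    // Lock the mutex around the shared state.
    let mut ready = lock.lock().unwrap();
    // Re-check in a loop: wake-ups can be spurious.
    while !*ready {
        // `wait` takes the guard (unlocking the mutex while we sleep)
        // and hands back a fresh guard once we are woken up.
        ready = cv.wait(ready).unwrap();
    }
    let result = *ready;
    drop(ready);

    worker.join().unwrap();
    result
}
```

It does not matter whether the worker sets the flag before or after the main thread takes the lock. Because the flag is checked before waiting, the notification cannot be missed.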
Naturally, there’s a helper that encapsulates this whole pattern:
impl SharedContext {
fn pause(&self) {
let mut guard = self.state.lock().unwrap();
assert_eq!(*guard, State::Running);
*guard = State::Paused;
self.cv.notify_all();
guard = self
.cv
.wait_while(guard, |state| *state == State::Paused)
.unwrap();
assert_eq!(*guard, State::Running)
}
}
Ok, this actually does look like a reasonable implementation of pause. Let’s move on to managed_thread::spawn:
fn spawn<'scope, T: 'scope + Send>(
scope: &Scope<'scope, '_>,
state: T,
) {
// ? ? ?? ??? ?????
}
There’s a bunch of stuff that needs to happen here:
- As we have established, we are going to spawn a (scoped) thread, so we need the scope parameter with its three lifetimes. I don’t know how it works, so I am just going by the docs here!
- We are going to return some kind of handle, which we can use to pause and unpause our managed thread. And that handle is going to be parametrized over the same 'scope lifetime, because it’ll hold onto the actual join handle.
- We are going to pass the generic state to our new thread, and that state needs to be Send, and bounded by the same lifetime as our scoped thread.
- Inside, we are going to spawn a thread for sure, and we’ll need to set up the INSTANCE thread local on that thread.
- And it would actually be a good idea to stuff a reference to that SharedContext into the handle we return.
A bunch of stuff, in other words. Let’s do it:
struct ManagedHandle<'scope> {
inner: std::thread::ScopedJoinHandle<'scope, ()>,
ctx: Arc<SharedContext>,
}
fn spawn<'scope, T: 'scope + Send>(
scope: &'scope Scope<'scope, '_>,
state: T,
) -> ManagedHandle<'scope> {
let ctx: Arc<SharedContext> = Default::default();
let inner = scope.spawn({
let ctx = Arc::clone(&ctx);
move || {
SharedContext::set(ctx);
drop(state); // TODO: ¿
}
});
ManagedHandle { inner, ctx }
}
The essentially no-op function we spawn looks sus. We’ll fix it later! Let’s try to implement is_paused and unpause first! They should be relatively straightforward. For is_paused, we just need to lock the mutex and check the state:
impl ManagedHandle<'_> {
pub fn is_paused(&self) -> bool {
let guard = self.ctx.state.lock().unwrap();
*guard == State::Paused
}
}
For unpause, we should additionally flip the state back to Running and notify the other thread:
impl ManagedHandle<'_> {
pub fn unpause(&self) {
let mut guard = self.ctx.state.lock().unwrap();
assert_eq!(*guard, State::Paused);
*guard = State::Running;
self.ctx.cv.notify_all();
}
}
But I think that’s not quite correct. Can you see why?
With this implementation, after unpause, the controlling and the managed threads will be running concurrently. And that can lead to non-determinism, the very problem we are trying to avoid here! In particular, if you call is_paused right after you unpause the thread, you’ll most likely get false back, as the other thread will still be running. But it might also hit the next pause call, so, depending on timing, you might also get true.
What we actually want is to completely eliminate all unmanaged concurrency. That means that at any given point in time, only one thread (controlling or managed) should be running. So the right semantics for unpause is to unblock the managed thread, and then block the controlling thread until the managed one hits the next pause!
impl ManagedHandle<'_> {
pub fn unpause(&self) {
let mut guard = self.ctx.state.lock().unwrap();
assert_eq!(*guard, State::Paused);
*guard = State::Running;
self.ctx.cv.notify_all();
guard = self
.ctx
.cv
.wait_while(guard, |state| *state == State::Running)
.unwrap();
}
}
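To see the "only one thread runs at a time" rule in isolation, here is a stripped-down sketch of the same hand-off, separate from the managed-thread code. The names `Turn`, `Baton` and `lock_step_trace` are invented for this example. Both threads append to a shared log, and the log order comes out the same on every run because each side blocks until the other hands the turn back.

```rust
use std::sync::{Arc, Condvar, Mutex};

#[derive(Clone, Copy, PartialEq, Eq)]
enum Turn {
    Controller,
    Worker,
}

struct Baton {
    turn: Mutex<Turn>,
    cv: Condvar,
    log: Mutex<Vec<String>>,
}

/// Runs `steps` hand-offs between a controller and a worker thread
/// and returns the order in which they logged their actions.
pub fn lock_step_trace(steps: usize) -> Vec<String> {
    let baton = Arc::new(Baton {
        turn: Mutex::new(Turn::Controller),
        cv: Condvar::new(),
        log: Mutex::new(Vec::new()),
    });

    let worker = {
        let baton = Arc::clone(&baton);
        std::thread::spawn(move || {
            for i in 0..steps {
                // Block until the controller gives us the turn.
                let guard = baton.turn.lock().unwrap();
                let mut guard = baton
                    .cv
                    .wait_while(guard, |turn| *turn != Turn::Worker)
                    .unwrap();
                baton.log.lock().unwrap().push(format!("worker step {i}"));
                // Hand the turn back, like hitting the next `pause`.
                *guard = Turn::Controller;
                baton.cv.notify_all();
            }
        })
    };

    for i in 0..steps {
        let mut guard = baton.turn.lock().unwrap();
        baton.log.lock().unwrap().push(format!("controller resumes {i}"));
        *guard = Turn::Worker;
        baton.cv.notify_all();
        // The key part: block until the worker hands the turn back.
        let _guard = baton
            .cv
            .wait_while(guard, |turn| *turn == Turn::Worker)
            .unwrap();
    }

    worker.join().unwrap();
    let log = baton.log.lock().unwrap().clone();
    log
}
```

Dropping the final `wait_while` in the controller loop would bring back exactly the race described above: the controller could start its next step while the worker is still running.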
At this point we can spawn a managed thread, pause it and resume. But right now it doesn’t do anything. Next step is implementing that idea where the controlling thread can directly send an arbitrary closure to the managed one to make it do something:
impl<'scope> ManagedHandle<'scope> {
pub fn submit<F: FnSomething>(&self, f: F)
}
Let’s figure out this FnSomething bound! We are going to yeet this f over to the managed thread and run it there once, so it is FnOnce. It is crossing a thread boundary, so it needs to be + Send. And, because we are using scoped threads, it doesn’t have to be 'static, just 'scope is enough. Moreover, in that managed thread the f will have exclusive access to the thread’s state, T. So we have:
impl<'scope, T> ManagedHandle<'scope, T> {
pub fn submit<F: FnOnce(&mut T) + Send + 'scope>(&self, f: F)
}
Implementing this is a bit tricky. First, we’ll need some sort of channel to actually move the function. Then, similarly to the unpause logic, we’ll need synchronization to make sure that the control thread doesn’t resume until the managed thread starts running f and hits a pause (or maybe completes f). And we’ll also need a new state, Ready, because now there are two different reasons why a managed thread might be blocked: it might wait for an unpause event, or it might wait for the next f to execute. This is the new code:
enum State {
Ready,
Running,
Paused,
}
struct ManagedHandle<'scope, T> {
inner: std::thread::ScopedJoinHandle<'scope, ()>,
ctx: Arc<SharedContext>,
sender: mpsc::Sender<Box<dyn FnOnce(&mut T) + 'scope + Send>>,
}
pub fn spawn<'scope, T: 'scope + Send>(
scope: &'scope Scope<'scope, '_>,
mut state: T,
) -> ManagedHandle<'scope, T> {
let ctx: Arc<SharedContext> = Default::default();
let (sender, receiver) =
mpsc::channel::<Box<dyn FnOnce(&mut T) + 'scope + Send>>();
let inner = scope.spawn({
let ctx = Arc::clone(&ctx);
move || {
SharedContext::set(Arc::clone(&ctx));
for f in receiver {
f(&mut state);
let mut guard = ctx.state.lock().unwrap();
assert_eq!(*guard, State::Running);
*guard = State::Ready;
ctx.cv.notify_all()
}
}
});
ManagedHandle { inner, ctx, sender }
}
impl<'scope, T> ManagedHandle<'scope, T> {
pub fn submit<F: FnOnce(&mut T) + Send + 'scope>(&self, f: F) {
let mut guard = self.ctx.state.lock().unwrap();
assert_eq!(*guard, State::Ready);
*guard = State::Running;
self.sender.send(Box::new(f)).unwrap();
guard = self
.ctx
.cv
.wait_while(guard, |state| *state == State::Running)
.unwrap();
}
}
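The channel-of-closures part can be looked at on its own, without any pausing. Below is a minimal sketch with a plain (non-scoped) thread and a `Vec<u32>` as the thread's state. `Job` and `inert_worker_demo` are names invented for this example.

```rust
use std::sync::mpsc;
use std::thread;

/// A unit of work: a one-shot closure that gets exclusive access
/// to the worker's state.
type Job = Box<dyn FnOnce(&mut Vec<u32>) + Send>;

/// Spawns an inert worker, pokes it twice, and returns its final state.
pub fn inert_worker_demo() -> Vec<u32> {
    let (sender, receiver) = mpsc::channel::<Job>();

    let worker = thread::spawn(move || {
        let mut state: Vec<u32> = Vec::new();
        // Iterating the receiver blocks until the next job arrives
        // and ends once every sender has been dropped.
        for job in receiver {
            job(&mut state);
        }
        state
    });

    sender
        .send(Box::new(|state: &mut Vec<u32>| state.push(1)))
        .unwrap();
    sender
        .send(Box::new(|state: &mut Vec<u32>| state.push(2)))
        .unwrap();

    // Closing our side of the channel is the stop signal.
    drop(sender);
    worker.join().unwrap()
}
```

Unlike submit above, this version does not wait for a job to finish before returning to the caller, so the controller and the worker run concurrently. The mutex and condvar in the real code are what remove that concurrency.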
The last small piece of the puzzle is the join function. It’s almost standard! First we close our side of the channel. This serves as a natural stop signal for the other thread, so it exits. Which in turn allows us to join it. The small wrinkle here is that the thread might be paused when we try to join it, so we need to unpause it beforehand:
impl<'scope, T> ManagedHandle<'scope, T> {
pub fn join(self) {
while self.is_paused() {
self.unpause();
}
drop(self.sender);
self.inner.join().unwrap();
}
}
That’s it! Let’s put everything together!
Helper library, managed_thread.rs:
use std::{
cell::RefCell,
sync::{atomic::Ordering, mpsc, Arc, Condvar, Mutex},
thread::Scope,
};
// `Default` is needed so that `Counter` can derive `Default` too.
#[derive(Default)]
pub struct AtomicU32 {
inner: std::sync::atomic::AtomicU32,
}
impl AtomicU32 {
pub fn load(&self, ordering: Ordering) -> u32 {
pause();
let result = self.inner.load(ordering);
pause();
result
}
pub fn store(&self, value: u32, ordering: Ordering) {
pause();
self.inner.store(value, ordering);
pause();
}
}
fn pause() {
if let Some(ctx) = SharedContext::get() {
ctx.pause()
}
}
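// `Default` is needed for `Default::default()` in `spawn`.
#[derive(Default)]
struct SharedContext {
state: Mutex<State>,
cv: Condvar,
}
// `PartialEq` and `Debug` are needed for `assert_eq!` on the state;
// a freshly spawned thread starts out `Ready`.
#[derive(Default, PartialEq, Eq, Debug)]
enum State {
#[default]
Ready,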
Running,
Paused,
}
thread_local! {
static INSTANCE: RefCell<Option<Arc<SharedContext>>> =
RefCell::new(None);
}
impl SharedContext {
fn set(ctx: Arc<SharedContext>) {
INSTANCE.with(|it| *it.borrow_mut() = Some(ctx));
}
fn get() -> Option<Arc<SharedContext>> {
INSTANCE.with(|it| it.borrow().clone())
}
fn pause(&self) {
let mut guard = self.state.lock().unwrap();
assert_eq!(*guard, State::Running);
*guard = State::Paused;
self.cv.notify_all();
guard = self
.cv
.wait_while(guard, |state| *state == State::Paused)
.unwrap();
assert_eq!(*guard, State::Running)
}
}
pub struct ManagedHandle<'scope, T> {
inner: std::thread::ScopedJoinHandle<'scope, ()>,
sender: mpsc::Sender<Box<dyn FnOnce(&mut T) + 'scope + Send>>,
ctx: Arc<SharedContext>,
}
pub fn spawn<'scope, T: 'scope + Send>(
scope: &'scope Scope<'scope, '_>,
mut state: T,
) -> ManagedHandle<'scope, T> {
let ctx: Arc<SharedContext> = Default::default();
let (sender, receiver) =
mpsc::channel::<Box<dyn FnOnce(&mut T) + 'scope + Send>>();
let inner = scope.spawn({
let ctx = Arc::clone(&ctx);
move || {
SharedContext::set(Arc::clone(&ctx));
for f in receiver {
f(&mut state);
let mut guard = ctx.state.lock().unwrap();
assert_eq!(*guard, State::Running);
*guard = State::Ready;
ctx.cv.notify_all()
}
}
});
ManagedHandle { inner, ctx, sender }
}
impl<'scope, T> ManagedHandle<'scope, T> {
pub fn is_paused(&self) -> bool {
let guard = self.ctx.state.lock().unwrap();
*guard == State::Paused
}
pub fn unpause(&self) {
let mut guard = self.ctx.state.lock().unwrap();
assert_eq!(*guard, State::Paused);
*guard = State::Running;
self.ctx.cv.notify_all();
guard = self
.ctx
.cv
.wait_while(guard, |state| *state == State::Running)
.unwrap();
}
pub fn submit<F: FnOnce(&mut T) + Send + 'scope>(&self, f: F) {
let mut guard = self.ctx.state.lock().unwrap();
assert_eq!(*guard, State::Ready);
*guard = State::Running;
self.sender.send(Box::new(f)).unwrap();
guard = self
.ctx
.cv
.wait_while(guard, |state| *state == State::Running)
.unwrap();
}
pub fn join(self) {
while self.is_paused() {
self.unpause();
}
drop(self.sender);
self.inner.join().unwrap();
}
}
System under test, not-exactly-atomic counter:
use std::sync::atomic::Ordering::SeqCst;
// Our pausing wrapper replaces std::sync::atomic::AtomicU32.
use managed_thread::AtomicU32;
#[derive(Default)]
pub struct Counter {
value: AtomicU32,
}
impl Counter {
pub fn increment(&self) {
let value = self.value.load(SeqCst);
self.value.store(value + 1, SeqCst);
}
pub fn get(&self) -> u32 {
self.value.load(SeqCst)
}
}
And the test itself:
fn test_counter() {
arbtest::arbtest(|rng| {
eprintln!("begin trace");
let counter = Counter::default();
let mut counter_model: u32 = 0;
std::thread::scope(|scope| {
let t1 = managed_thread::spawn(scope, &counter);
let t2 = managed_thread::spawn(scope, &counter);
let mut threads = [t1, t2];
while !rng.is_empty() {
for (tid, t) in threads.iter_mut().enumerate() {
if rng.arbitrary()? {
if t.is_paused() {
eprintln!("{tid}: unpause");
t.unpause()
} else {
eprintln!("{tid}: increment");
t.submit(|c| c.increment());
counter_model += 1;
}
}
}
}
for t in threads {
t.join();
}
assert_eq!(counter_model, counter.get());
Ok(())
})
});
}
Running it identifies a failure:
---- test_counter stdout ----
begin trace
0: increment
1: increment
0: unpause
1: unpause
1: unpause
0: unpause
0: unpause
1: unpause
0: unpause
0: increment
1: unpause
0: unpause
1: increment
0: unpause
0: unpause
1: unpause
0: unpause
thread 'test_counter' panicked at src/lib.rs:56:7:
assertion `left == right` failed
left: 4
right: 3
arbtest failed!
Seed: 0x4fd7ddff00000020
Which … is something we got like 5% into this article already, with normal threads! But there’s more to this failure. First, it is reproducible. If I specify the same seed, I get the exact same interleaving:
fn test_counter() {
arbtest::arbtest(|rng| {
eprintln!("begin trace");
...
})
.seed(0x4fd7ddff00000020);
}
And this is completely machine independent! If you specify this seed, you’ll get the exact same interleaving. So, if I am having trouble debugging this, I can DM you this hex in Zulip, and you’ll be able to help out!
But there’s more — we don’t need to debug this failure, we can minimize it!
fn test_counter() {
arbtest::arbtest(|rng| {
eprintln!("begin trace");
...
})
.seed(0x4fd7ddff00000020)
.minimize();
}
This gives me the following minimization trace:
begin trace
0: increment
1: increment
0: unpause
1: unpause
1: unpause
0: unpause
0: unpause
1: unpause
0: unpause
0: increment
1: unpause
0: unpause
1: increment
0: unpause
0: unpause
1: unpause
0: unpause
seed 0x4fd7ddff00000020, seed size 32, search time 106.00ns
begin trace
0: increment
1: increment
0: unpause
0: unpause
1: unpause
0: unpause
1: unpause
0: unpause
1: unpause
1: unpause
1: increment
seed 0x540c0c1c00000010, seed size 16, search time 282.16µs
begin trace
0: increment
1: increment
0: unpause
1: unpause
1: unpause
1: unpause
seed 0x084ca71200000008, seed size 8, search time 805.74µs
begin trace
0: increment
1: increment
0: unpause
1: unpause
seed 0x5699b19400000004, seed size 4, search time 1.44ms
begin trace
0: increment
1: increment
0: unpause
1: unpause
seed 0x4bb0ea5c00000002, seed size 2, search time 4.03ms
begin trace
0: increment
1: increment
0: unpause
1: unpause
seed 0x9c2a13a600000001, seed size 1, search time 4.31ms
minimized
seed 0x9c2a13a600000001, seed size 1, search time 100.03ms
That is, we ended up with this tiny, minimal example:
fn test_counter() {
arbtest::arbtest(|rng| {
eprintln!("begin trace");
...
})
.seed(0x9c2a13a600000001);
}
begin trace
0: increment
1: increment
0: unpause
1: unpause
And this is how you properly test concurrent data structures.
Postscript
Of course, this is just a toy. But you can see some ways to extend it. For example, right now our AtomicU32 just delegates to the real one. But what you could do instead is, for each atomic, to maintain a set of values written and, on read, return an arbitrary written value consistent with a weak memory model.
You could also be smarter with exploring interleavings. Instead of interleaving threads at random, like we do here, you can try to apply model checking approaches and prove that you have considered all meaningfully different interleavings.
Or you can apply the approach from Generate All The Things and exhaustively enumerate all interleavings for up to, say, five increments. In fact, why don’t we just do this?
$ cargo add exhaustigen
fn exhaustytest() {
let mut g = exhaustigen::Gen::new();
let mut interleavings_count = 0;
while !g.done() {
interleavings_count += 1;
let counter = Counter::default();
let mut counter_model: u32 = 0;
let increment_count = g.gen(5) as u32;
std::thread::scope(|scope| {
let t1 = managed_thread::spawn(scope, &counter);
let t2 = managed_thread::spawn(scope, &counter);
'outer: while t1.is_paused()
|| t2.is_paused()
|| counter_model < increment_count
{
for t in [&t1, &t2] {
if g.flip() {
if t.is_paused() {
t.unpause();
continue 'outer;
}
if counter_model < increment_count {
t.submit(|c| c.increment());
counter_model += 1;
continue 'outer;
}
}
}
return for t in [t1, t2] {
t.join()
};
}
assert_eq!(counter_model, counter.get());
});
}
eprintln!("interleavings_count = {:?}", interleavings_count);
}
The shape of the test is more or less the same, except that we need to make sure that there are no “dummy” iterations, and that we always either unpause a thread or submit an increment.
It finds the same bug, naturally:
thread 'exhaustytest' panicked at src/lib.rs:103:7:
assertion `left == right` failed
left: 2
right: 1
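The same exhaustive idea can be seen in miniature with a pure simulation, no threads and no exhaustigen involved. The sketch below enumerates every ordering of the load and store steps of two simulated threads doing one increment each, and counts how many orderings lose an update. All names here (`interleavings`, `final_value`, `count_lost_updates`) are made up for the illustration.

```rust
/// All distinct orderings of `a` steps of thread 0 and `b` steps of thread 1.
pub fn interleavings(a: usize, b: usize) -> Vec<Vec<usize>> {
    if a == 0 && b == 0 {
        return vec![Vec::new()];
    }
    let mut result = Vec::new();
    if a > 0 {
        for mut rest in interleavings(a - 1, b) {
            rest.insert(0, 0);
            result.push(rest);
        }
    }
    if b > 0 {
        for mut rest in interleavings(a, b - 1) {
            rest.insert(0, 1);
            result.push(rest);
        }
    }
    result
}

/// Simulates the non-atomic increment: a thread's first step loads the
/// shared value, its second step stores the loaded value plus one.
pub fn final_value(schedule: &[usize]) -> u32 {
    let mut shared: u32 = 0;
    // pending[tid] holds a value that was loaded but not yet stored.
    let mut pending: [Option<u32>; 2] = [None, None];
    for &tid in schedule {
        match pending[tid].take() {
            None => pending[tid] = Some(shared),
            Some(loaded) => shared = loaded + 1,
        }
    }
    shared
}

/// Returns (number of interleavings, number that lose an update).
pub fn count_lost_updates() -> (usize, usize) {
    let all = interleavings(2, 2);
    let lost = all
        .iter()
        .filter(|schedule| final_value(schedule.as_slice()) != 2)
        .count();
    (all.len(), lost)
}
```

For one increment per thread there are six interleavings, and four of them end with a counter value of 1. Only the two fully sequential orderings are correct.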
But the cool thing is, if we fix the issue by using atomic increment, …
impl AtomicU32 {
pub fn fetch_add(
&self,
value: u32,
ordering: Ordering,
) -> u32 {
pause();
let result = self.inner.fetch_add(value, ordering);
pause();
result
}
}
impl Counter {
pub fn increment(&self) {
self.value.fetch_add(1, SeqCst);
}
}
… we can get a rather specific correctness statement out of our test, that any sequence of at most five increments is correct:
$ t cargo t -r -- exhaustytest --nocapture
running 1 test
all 81133 interleavings are fine!
test exhaustytest ... ok
real 8.65s
cpu 8.16s (2.22s user + 5.94s sys)
rss 63.91mb
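As a sanity check outside of the managed-thread harness, the same fix can be exercised with plain std atomics and real threads. This is a sketch with a hypothetical `hammer` helper; note that a passing run of a timing-based test like this one says much less than the exhaustive run above.

```rust
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};

/// Increments a shared counter from `thread_count` threads,
/// `increment_count` times each, and returns the final value.
pub fn hammer(thread_count: u32, increment_count: u32) -> u32 {
    let value = AtomicU32::new(0);
    std::thread::scope(|scope| {
        for _ in 0..thread_count {
            scope.spawn(|| {
                for _ in 0..increment_count {
                    // A single atomic read-modify-write: no window
                    // between the load and the store.
                    value.fetch_add(1, SeqCst);
                }
            });
        }
    });
    // All scoped threads have been joined by the time `scope` returns.
    value.load(SeqCst)
}
```

With `fetch_add` no increment can be lost, so `hammer(8, 1000)` returns 8000 regardless of how the threads are scheduled.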
And the last small thing. Recall that our PBT minimized the first sequence it found …:
begin trace
0: increment
1: increment
0: unpause
1: unpause
1: unpause
0: unpause
0: unpause
1: unpause
0: unpause
0: increment
1: unpause
0: unpause
1: increment
0: unpause
0: unpause
1: unpause
0: unpause
thread 'test_counter' panicked at src/lib.rs:56:7:
assertion `left == right` failed
left: 4
right: 3
arbtest failed!
Seed: 0x4fd7ddff00000020
… down to just
begin trace
0: increment
1: increment
0: unpause
1: unpause
thread 'test_counter' panicked at src/lib.rs:57:7:
assertion `left == right` failed
left: 2
right: 1
arbtest failed!
Seed: 0x9c2a13a600000001
But we never implemented shrinking! How is this possible? Well, strictly speaking, this is out of scope for this post. And I’ve already described this elsewhere. And, at 32k, this is the third-longest post on this blog. And it’s 3AM here in Lisbon right now. But of course I’ll explain!
The trick is the simplified hypothesis approach. The arbtest PBT library we use in this post is based on a familiar interface of a PRNG:
arbtest::arbtest(|rng| {
let random_int: usize = rng.int_in_range(0..=100)?;
let random_bool: bool = rng.arbitrary()?;
Ok(())
});
But there’s a twist! This is a finite PRNG. So, if you ask it to flip a coin it can give you heads. And next time it might give you tails. But if you continue asking it for more, at some point it’ll give you Err(OutOfEntropy).
That’s why all these ? and the outer loop of while !rng.is_empty() {.
In other words, as soon as the test runs out of entropy, it short-circuits and completes. And that means that by reducing the amount of entropy available the test becomes shorter, and this works irrespective of how complex the logic inside the test is!
And “entropy” is a big scary word here, what actually happens is that the PRNG is just an &mut &[u8] inside. That is, a slice of random bytes, which is shortened every time you ask for a random number. And the shorter the initial slice, the simpler the test gets. Minimization can be this simple!
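To make the "slice of bytes" picture concrete, here is a tiny sketch of a finite random source. It is not arbtest's actual implementation: `FiniteRng`, `flip` and `count_steps` are hypothetical names, and the coin flip simply looks at the lowest bit of the next byte.

```rust
/// A "PRNG" that is just a shrinking slice of pre-generated bytes.
pub struct FiniteRng<'a> {
    bytes: &'a [u8],
}

impl<'a> FiniteRng<'a> {
    pub fn new(bytes: &'a [u8]) -> Self {
        FiniteRng { bytes }
    }

    pub fn is_empty(&self) -> bool {
        self.bytes.is_empty()
    }

    /// Consumes one byte; returns `None` once the entropy runs out.
    pub fn flip(&mut self) -> Option<bool> {
        let (&first, rest) = self.bytes.split_first()?;
        self.bytes = rest;
        Some(first & 1 == 1)
    }
}

/// A stand-in for a test body: it takes one step per available coin flip.
pub fn count_steps(bytes: &[u8]) -> usize {
    let mut rng = FiniteRng::new(bytes);
    let mut steps = 0;
    while let Some(_coin) = rng.flip() {
        steps += 1;
    }
    steps
}
```

Four bytes of input give a four-step run and one byte gives a one-step run. A minimizer only has to try shorter byte slices and keep the ones that still fail.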
You can find source code for this article at https://github.com/matklad/properly-concurrent