Min of Three Part 2
This is a continuation of the previous post about optimizing a 2D grid-based dynamic programming algorithm for CPU-level parallelism.
In The Previous Episode
This is the code we are trying to make faster:
Code on Rust playground (293 ms)
It calculates the dynamic time warping distance between two vectors of doubles, using an update rule where each cell is computed from its upper, left, and upper-left neighbors.
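In code, the row-by-row table fill looks roughly like this. This is a sketch with a stand-in absolute-difference cost function, not the exact kernel from the playground links; only the min-of-three structure matters here.

```rust
// Sketch of the DTW table fill; the cost function is a stand-in.
fn dtw(xs: &[f64], ys: &[f64]) -> f64 {
    let (n, m) = (xs.len(), ys.len());
    // table[i][j] is the DTW distance between xs[..i] and ys[..j].
    let mut table = vec![vec![f64::INFINITY; m + 1]; n + 1];
    table[0][0] = 0.0;
    for i in 1..=n {
        for j in 1..=m {
            let cost = (xs[i - 1] - ys[j - 1]).abs();
            // Update rule: each cell depends on its upper, left,
            // and upper-left neighbors -- the "min of three".
            table[i][j] = cost
                + table[i - 1][j]
                    .min(table[i][j - 1])
                    .min(table[i - 1][j - 1]);
        }
    }
    table[n][m]
}
```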
This code takes 293 milliseconds to run on a particular input data. The speedup from 435 milliseconds stated in the previous post is due to Moore’s law: I’ve upgraded the CPU :)
We can bring run time down by tweaking how we calculate the minimum of three elements.
Code on Rust playground (210 ms)
This version takes only 210 milliseconds, presumably because the minimum of two elements in the previous row can be calculated without waiting for the preceding element in the current row to be computed.
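The reassociated minimum can be sketched like this (`prev` and `curr` follow the post's naming; the cost function is a stand-in):

```rust
// Rolling two rows; the two previous-row cells are combined first.
fn dtw_rows(xs: &[f64], ys: &[f64]) -> f64 {
    let m = ys.len();
    let mut prev = vec![f64::INFINITY; m + 1];
    let mut curr = vec![f64::INFINITY; m + 1];
    prev[0] = 0.0;
    for &x in xs {
        curr[0] = f64::INFINITY;
        for iy in 1..=m {
            let cost = (x - ys[iy - 1]).abs();
            // prev[iy - 1].min(prev[iy]) depends only on the previous
            // row, so the CPU can compute it without waiting for
            // curr[iy - 1] to be ready.
            curr[iy] = cost + prev[iy - 1].min(prev[iy]).min(curr[iy - 1]);
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[m]
}
```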
The assembly for the main loop looks like this (AT&T syntax, destination register on the right).
Check the previous post for more details!
The parallel plan
Can we loosen dependencies between cells even more to benefit from instruction-level parallelism? What if, instead of filling the table row by row, we fill it diagonal by diagonal?
We'd need to remember two previous diagonals instead of one previous row, but all the cells on the next diagonal would be independent! In theory, the compiler should be able to use SIMD instructions to make the computation truly parallel.
Implementation Plan
Coding up this diagonal traversal is a bit tricky, because you need to map linear vector indices to diagonal indices.
The original indexing worked like this:
- `ix` and `iy` are indices in the input vectors.
- The outer loop is over `ix`.
- On each iteration, we remember two rows (`curr` and `prev` in the code).
For our grand plan, we need to fit a rhombus peg in a square hole:
- `id` is the index of the diagonal. There are twice as many diagonals as rows.
- The outer loop is over `id`.
- On each iteration, we remember three diagonals (`d1`, `d2` and `d3` in the code).
- There is a phase transition once we've crossed the main diagonal.
- We can derive `iy` from the fact that `ix + iy = id`.
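The index arithmetic above can be sketched as a small helper (the function name is ours) that enumerates the cells of one diagonal in an n × n table:

```rust
// For diagonal `id`, the cells are the pairs (ix, iy) with
// ix + iy == id that fit inside the n x n table.
fn diagonal_cells(n: usize, id: usize) -> Vec<(usize, usize)> {
    // Phase transition: before the main diagonal, ix starts at 0;
    // after crossing it, the lower bound moves up instead.
    let lo = if id < n { 0 } else { id - n + 1 };
    let hi = if id < n { id } else { n - 1 };
    (lo..=hi).map(|ix| (ix, id - ix)).collect()
}
```

Note that `id` runs over `2 * n - 1` values, matching "twice as many diagonals as rows".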
Code
The actual code looks like this:
Code on Rust playground (185 ms)
It takes 185 milliseconds to run. The assembly for the main loop is quite interesting:
First of all, we don't see any vectorized instructions: the code does roughly the same operations as in the previous version. Also, there is a whole bunch of extra branching instructions at the top. These are bounds checks which were not eliminated this time. And this is great: if I summed up all the off-by-one errors I've made implementing diagonal indexing, I would get an integer overflow! Nevertheless, we've got some speedup.
Can we go further and get SIMD instructions here? At the moment, Rust does not have a stable way to explicitly emit SIMD (it's going to change some day) (UPDATE: we have SIMD on stable now!), so the only choice we have is to tweak the source code until LLVM sees an opportunity for vectorization.
SIMD
Although bounds checks themselves don't slow down the code that much, they can prevent LLVM from vectorizing. So let's dip our toes into `unsafe`:
Code on Rust playground (52 ms)
The code is as fast as it is ugly: it finishes in a whopping 52 milliseconds! And of course we see SIMD in the assembly:
Safe SIMD
How can we get the same results with safe Rust? One possible way is to use iterators, but in this case the resulting code would be rather ugly, because you'd need a lot of nested `.zip`s. So let's try a simple trick: hoist the bounds checks out of the loop, replacing a check on every iteration with a single check up front.
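As a toy illustration (our own example, not the post's actual kernel), the hoisting trick looks like this:

```rust
// Without the explicit slice, `xs[i]` is bounds-checked on every
// iteration; after `let xs = &xs[..n]`, the compiler can prove
// `i < n == xs.len()` and drop the per-iteration checks.
fn sum_first(xs: &[f64], n: usize) -> f64 {
    let xs = &xs[..n]; // one bounds check, here
    let mut total = 0.0;
    for i in 0..n {
        total += xs[i]; // check elided
    }
    total
}
```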
In Rust, this is possible by explicitly slicing the buffer before the loop:
Code on Rust playground (107 ms)
This is definitely an improvement over the best safe version, but is
still twice as slow as the unsafe variant. Looks like some bounds
checks are still there! It is possible to find them by selectively
using `unsafe` to replace some indexing operations. And it turns out that only `ys` is still checked!
Code on Rust playground (52 ms)
If we use `unsafe` only for `ys`, we regain all the performance.
LLVM is having trouble iterating `ys` in reverse, but the fix is easy: just reverse it once at the beginning of the function:
Code on Rust playground (50 ms)
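The reversal itself is a one-liner (the helper name is ours):

```rust
// Materialize ys in reverse order once, up front, so the hot loop
// can then iterate it forward.
fn reversed(ys: &[f64]) -> Vec<f64> {
    ys.iter().rev().copied().collect()
}
```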
Conclusions
We’ve gone from almost 300 milliseconds to only 50 in safe Rust. That is quite impressive! However, the resulting code is rather brittle and even small changes can prevent vectorization from triggering.
It’s also important to understand that to allow for SIMD, we had to change the underlying algorithm. This is not something even a very smart compiler could do!