Systems and
Formalisms Lab

Software optimization for a RISC-V Accelerator

A story about writing high-performance code for a custom accelerator in a RISC-V CPU, with the help of some semi-automated tools.

Writing high-performance software these days is a challenging task. Today's hardware ecosystem is highly heterogeneous, with a trend towards application-tailored solutions. Despite years of effort in automation, by and large the burden still falls on the developer to write a program in a manner that can take advantage of this hardware. Oftentimes, the solution is to use a library carefully written by a performance engineer for the specific combination of application and hardware. But why is that? What happens when there is no library?

Take the example of dense matrix multiplication, a ubiquitous workload in deep learning and HPC in general. The code at the left is a simple naive implementation, while the code at the right is the corresponding implementation in the popular OpenBLAS linear algebra library. More precisely, it is only one part of that implementation: a single highly optimized "microkernel" that computes matrix multiplication on a fixed-size matrix for Intel Haswell CPUs. One of these exists for every generation of CPU that introduces new SIMD instructions, or even changes certain microarchitectural details. On top of that there is more code, not shown here, to "tile" a matrix multiplication of arbitrary size so that it uses these microkernels. In all, OpenBLAS ends up at almost 4 million lines of code, 100x more than the reference BLAS implementation. Clearly, this represents a huge engineering effort.

Figure: side-by-side comparison between a naive matmul (0.33 GFLOPS) and a ~500-line OpenBLAS microkernel for GEMM (495.19 GFLOPS).

But the performance difference is stark: a 1487x speedup is observed testing these programs on 1024x1024 matrices. Between these performance gains, the lack of compelling automated solutions, and the complexity of writing such optimized programs by hand, it's clear why these libraries have become the gold standard.

This reality is especially bleak for a hardware vendor looking to sell a new domain-specific accelerator, for instance. Standing up a new platform with the libraries programmers are used to is a monumental hurdle, because of all the performance engineering the vendor needs to catch up on.

As such, there has been considerable interest in reducing this effort. Although fully automated compilation is still actively researched, a semi-automated language-level solution, user scheduling, has proved particularly effective. The idea is to express an optimized program as a simpler algorithm with a series of optimizations, called a "schedule," applied to it. The language in turn makes this separation explicit: the programmer writes out the algorithm for their program, and expresses their schedule as a series of transformations on this program. By decoupling correctness and speed in this way, the goal is to make these programs easier to iterate on, as well as to maintain and audit.

This blog post is a story about using user-scheduled languages to develop high-performance software for a particular custom accelerator. We'll start with the background behind our platform and the software we wanted to run.

Background

The hardware in question accelerates 4x4 dense matrix multiplications with a systolic array. It is implemented as an ISA extension to RISC-V, RVM (RISC-V Matrix) [1], allowing it to be programmed through CPU instructions. In particular, our implementation is connected to a microcontroller called X-HEEP. For those familiar with RVV, the RISC-V Vector extension, RVM works similarly. There is a set of 8 "tile" registers in the accelerator, each storing a 4x4 matrix of 32-bit values. Arithmetic instructions compute operations on these tile registers, and load/store instructions transfer data between the register file and main memory. These instructions are issued by the CPU. For example, mmasa.w m2, m0, m1 computes $m_2 \mathrel{+}= m_0 m_1^T$. The full ISA listing is available here.
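
To make the tile arithmetic concrete, here is a minimal software model of the multiply-accumulate described above, written in plain Python. It is only a sketch: the function name `mmasa` and the list-of-rows tile representation are assumptions made for illustration, and the model ignores 32-bit overflow and any timing behavior of the real instruction.

```python
# Hypothetical software model of one tile operation; not the RVM reference model.
TILE = 4

def mmasa(acc, a, b):
    """Return acc + a @ b^T for TILE x TILE integer tiles (lists of rows)."""
    return [
        [acc[i][j] + sum(a[i][k] * b[j][k] for k in range(TILE))
         for j in range(TILE)]
        for i in range(TILE)
    ]

identity = [[1 if i == j else 0 for j in range(TILE)] for i in range(TILE)]
b = [[4 * i + j for j in range(TILE)] for i in range(TILE)]
zero = [[0] * TILE for _ in range(TILE)]

# Multiplying the identity by b^T simply yields the transpose of b.
result = mmasa(zero, identity, b)
```

Because the second operand is transposed, each output element is a dot product of one row of `a` with one row of `b`, which is the access pattern the convolution code later relies on.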

One notable property of this accelerator combined with the CPU is the way its instructions are handled. The CPU issues instructions in-order, including RVM instructions, but the accelerator is allowed to complete out of order. Consequently, the CPU can perform and retire instructions while the accelerator completes operations. As we will see, optimizing software for this accelerator is in large part a matter of devising how to best overlap the time spent on the CPU with the accelerator.

On the software side, we were mainly interested in embedded machine learning applications (tinyML). A common workload in this space is convolution, specifically the 1D variant. This routine is a core part of neural networks designed to recognize speech, or monitor EKGs in wearable devices. Thus, we chose to study a simplified 1D convolution layer. Basic familiarity with convolution is assumed, but the mathematical specification for this particular operator is as follows. Given an input stream $I[IC][N]$ of length $N$ with $IC$ input channels, and an array of kernels $K[OC][IC][W]$ of width $W$ with $OC$ output channels, we intend to compute an output $O[OC][N]$ such that

$$O[i][j] = \sum_{c=0}^{IC-1} \sum_{r=0}^{W-1} \begin{cases} I[c][j+r] \cdot K[i][c][r] & \text{if } j + r < N \\ 0 & \text{otherwise} \end{cases}, \qquad 0 \le i < OC, \quad 0 \le j < N$$

Before we jump into the process of using user-scheduled languages to make this development easier, let's look at what it is like to write this software by hand, to serve as a baseline.

Manual optimization

void conv1d_cpu(int32_t *data, int32_t *kernels, int32_t *out) {
    for (int i = 0; i < OC; i++) {
        for (int j = 0; j < N; j++) {
            out[i][j] = 0;
            for (int c = 0; c < IC; c++) {
                for (int r = 0; r < W; r++) {
                    if ((j + r) < N) {
                        out[i][j] += data[c][j + r] * kernels[i][c][r];
                    }
                }
            }
        }
    }
}

The above code is more or less a direct translation of the formula written above. While the arrays $I$, $K$, $O$ from the formula are passed as the pointers data, kernels, and out, note that the accesses use bracket syntax for multidimensional arrays for readability. In reality, an expression such as out[i*N+j] would be needed instead of out[i][j].
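
For readers who want to see the flat indexing spelled out, the following Python sketch mirrors the scalar routine using one-dimensional, row-major lists. The function name `conv1d_flat` and the explicit size parameters are assumptions for illustration; it is meant as a reference to test against, not as the code that runs on the device.

```python
def conv1d_flat(data, kernels, oc, ic, n, w):
    """1D convolution over flat row-major lists, mirroring the scalar C loop nest."""
    out = [0] * (oc * n)
    for i in range(oc):
        for j in range(n):
            for c in range(ic):
                for r in range(w):
                    if j + r < n:
                        # data[c][j + r] and kernels[i][c][r], written as flat offsets
                        out[i * n + j] += (
                            data[c * n + j + r] * kernels[(i * ic + c) * w + r]
                        )
    return out

# Tiny example: one output channel, one input channel, N = 4, W = 2.
data = [1, 2, 3, 4]
kernels = [10, 1]
out = conv1d_flat(data, kernels, oc=1, ic=1, n=4, w=2)
```

The last output element only sees the first kernel weight, because the second one would read past the end of the input and is treated as a multiplication by zero.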

For simplicity, we'll make a lot of assumptions in the following sections. Notably, whenever we tile a loop by some factor, we assume the factor divides evenly. Compensating for this is usually a matter of adding some loop epilogue to handle the remainder. Another large assumption is that the size of a kernel matches the tile size supported by the accelerator. Usually, the kernel will be smaller, so the tile needs to be padded. While this assumption is unrealistic, we will see that there is still a great deal of nuance to study in optimizing this simplified routine.

Our first order of business is to actually make use of our accelerator hardware. Currently, everything is running on the CPU. For that, we need to somehow express this computation in terms of matrix multiplication. Notice that the inner loops vaguely resemble the accumulation of a matrix multiply. For each element of the out matrix, we are doing a dot product between a row and column of data and kernels respectively. The one caveat is that we don't access one row of data at a time, but rather an irregular pattern given by j+r. We can thus separate out this access into its own loop, storing the result as another array, which we will call y:

void conv1d_im2col(int32_t *data, int32_t *kernels, int32_t *out) {
    int32_t y[N][IC][W];
    // perform im2col
    for (int j = 0; j < N; j++) {
        for (int c = 0; c < IC; c++) {
            for (int r = 0; r < W; r++) {
                if ((j + r) < N) {
                    y[j][c][r] = data[c][j+r];
                } else {
                    y[j][c][r] = 0;
                }
            }
        }
    }
    // matrix multiplication
    for (int i = 0; i < OC; i++) {
        for (int j = 0; j < N; j++) {
            out[i][j] = 0;
            for (int c = 0; c < IC; c++) {
                for (int r = 0; r < W; r++) {
                    out[i][j] += y[j][c][r] * kernels[i][c][r];
                }
            }
        }
    }
}

This transformation is so common it has the name im2col, as matrix multiply hardware is often used to accelerate convolutions. Having separated out these sections we could consider writing an optimized tiled matrix multiplication routine and invoking that, the same way that OpenBLAS solves this problem. However, this would be missing a key property of our hardware. Recall that accelerator instructions do not block the CPU. So, we could overlap the time spent computing results with time spent repacking the data (im2col). Instead of doing im2col on the entire array, we should repack only a tile (the size the accelerator can handle) at a time. This means we are computing as much as possible as soon as the data is ready.
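
As a quick sanity check that separating out the repacking step does not change the result, the sketch below compares a direct convolution against an im2col-then-multiply version on random inputs. It is plain Python with nested lists; the function names and the chosen sizes are assumptions for illustration only.

```python
import random

def conv1d_direct(data, kernels, n):
    oc, ic, w = len(kernels), len(kernels[0]), len(kernels[0][0])
    out = [[0] * n for _ in range(oc)]
    for i in range(oc):
        for j in range(n):
            for c in range(ic):
                for r in range(w):
                    if j + r < n:
                        out[i][j] += data[c][j + r] * kernels[i][c][r]
    return out

def conv1d_im2col(data, kernels, n):
    oc, ic, w = len(kernels), len(kernels[0]), len(kernels[0][0])
    # Repack: y[j][c][r] holds the input element that output column j needs,
    # or 0 where the window runs past the end of the input.
    y = [[[data[c][j + r] if j + r < n else 0 for r in range(w)]
          for c in range(ic)] for j in range(n)]
    # What remains is a dot product between kernels[i] and y[j], i.e. a matmul.
    return [[sum(y[j][c][r] * kernels[i][c][r]
                 for c in range(ic) for r in range(w))
             for j in range(n)] for i in range(oc)]

random.seed(0)
OC, IC, N, W = 4, 3, 8, 4
data = [[random.randint(-9, 9) for _ in range(N)] for _ in range(IC)]
kernels = [[[random.randint(-9, 9) for _ in range(W)] for _ in range(IC)]
           for _ in range(OC)]
same = conv1d_direct(data, kernels, N) == conv1d_im2col(data, kernels, N)
```

The bounds check has moved entirely into the repacking step, which is what leaves a clean multiply-accumulate for the accelerator.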

#define TILE 4
void conv1d_im2col_tile(int32_t *data, int32_t *kernels, int32_t *out) {
    for (int tile_i = 0; tile_i < OC/TILE; tile_i++) {
        for (int tile_j = 0; tile_j < N/TILE; tile_j++) {
            // zero the 4x4 output tile once, before accumulating over channels
            for (int i = 0; i < TILE; i++) {
                for (int j = 0; j < TILE; j++) {
                    out[tile_i*TILE + i][tile_j*TILE + j] = 0;
                }
            }
            for (int c = 0; c < IC; c++) {
                int32_t y[TILE][TILE];
                // perform im2col
                for (int j = 0; j < TILE; j++) {
                    // assumed that W == TILE!
                    for (int r = 0; r < TILE; r++) {
                        if (((tile_j*TILE + j) + r) < N) {
                            y[j][r] = data[c][(tile_j*TILE + j)+r];
                        } else {
                            y[j][r] = 0;
                        }
                    }
                }
                // matrix multiplication, accumulated into this tile of out
                for (int i = 0; i < TILE; i++) {
                    for (int j = 0; j < TILE; j++) {
                        for (int r = 0; r < TILE; r++) {
                            out[tile_i*TILE + i][tile_j*TILE + j] += y[j][r]
                             * kernels[tile_i*TILE + i][c][r];
                        }
                    }
                }
            }
        }
    }
}

Note some subtle changes in this process: the order of loops has been changed from the original program. Before, for each input channel (c), we accumulated one scalar output result. Now, after having tiled the i and j loops, we are accumulating a 4x4 tile.

Now, the operations in the routine correspond nicely with the instructions supported by our accelerator. Instead of performing the 4x4 matrix multiplication on the CPU, we can directly offload this to the accelerator. We can also hold the intermediate result out[i][j] in a tile register until the end of the loop, when we store the accumulated register to main memory. To properly load the subset of the matrices into tile registers, we use the stride parameter of the load instruction, which represents the width of a row in bytes. For instance, when loading kernels, the width of a row is $IC \cdot W$ elements, so we pass $4 \cdot IC \cdot W$ (4 = sizeof(int32_t)).
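
The stride arithmetic is easy to get wrong, so here is a small Python model of a strided tile load from a flat array of 32-bit words. The helper `load_tile` is hypothetical and only encodes the behavior described in the paragraph above (each tile row starts one stride after the previous one); it is not taken from the ISA specification.

```python
TILE = 4
ELEM_BYTES = 4  # sizeof(int32_t)

def load_tile(memory, base_index, stride_bytes):
    """Model a strided tile load: row k starts stride_bytes after row k - 1."""
    stride = stride_bytes // ELEM_BYTES
    return [
        memory[base_index + row * stride: base_index + row * stride + TILE]
        for row in range(TILE)
    ]

OC, IC, W = 8, 2, 4
# kernels[i][c][r] stored flat; each value encodes its own indices for checking.
kernels = [100 * i + 10 * c + r
           for i in range(OC) for c in range(IC) for r in range(W)]

tile_i, c = 1, 1
base = ((tile_i * TILE) * IC + c) * W      # offset of kernels[tile_i*TILE][c][0]
tile = load_tile(kernels, base, IC * W * ELEM_BYTES)
```

Each row of the loaded tile comes from a different output channel but the same input channel `c`, which is why the stride has to skip over all `IC * W` weights of one output channel.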

#define TILE 4
void conv1d_im2col_tile(int32_t *data, int32_t *kernels, int32_t *out) {
    for (int tile_i = 0; tile_i < OC/TILE; tile_i++) {
        for (int tile_j = 0; tile_j < N/TILE; tile_j++) {
            asm volatile ("mzero m2");
            for (int c = 0; c < IC; c++) {
                int32_t y[TILE][TILE];
                // perform im2col
                for (int j = 0; j < TILE; j++) {
                    for (int r = 0; r < TILE; r++) {
                        if (((tile_j*TILE + j) + r) < N) {
                            y[j][r] = data[c][(tile_j*TILE + j)+r];
                        } else {
                            y[j][r] = 0;
                        }
                    }
                }
                // matrix multiplication
                asm volatile ("mld.w m0, (%0), %1"
                    :: "r"(y), "r"(TILE*4));
                asm volatile ("mld.w m1, (%0), %1"
                    :: "r"(&kernels[tile_i*TILE][c][0]), "r"(IC * W * 4));
                asm volatile ("mmasa.w m2, m0, m1");
            }
            asm volatile ("mst.w m2, (%0), %1"
                :: "r"(&out[tile_i*TILE][tile_j*TILE]), "r"(N * 4));
        }
    }
}

At this point, this code performs around 4x faster than the scalar code we started with. [2] Still, we can further optimize it. Profiling the code reveals that the majority of the time is still spent simply doing im2col, and that the computation adds practically nothing to the total runtime, once again due to the nonblocking nature of the instructions. There is ample time for the matrix load and multiply to complete before the im2col loop provides the next piece of data. However, notice that the im2col result is unnecessarily recomputed for every tile_i iteration: the result does not depend on tile_i at all. If we could reorder the loops and share the value of y across every iteration of tile_i, then in theory we could speed up by a factor of up to $OC/\mathrm{TILE}$. In reality, since we are reducing over the c loop, we would need to keep the accumulator tile for each tile_i in a different register, which is not feasible as we only have 8. But 8 registers are still enough to hold the accumulators for 4 different tile_i iterations, so we can unroll by a factor of 4.
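
The arithmetic behind this estimate can be written down in a few lines of Python. The helper name `im2col_tile_repacks` and the example sizes (OC = 16, IC = 4, N = 16) are assumptions for illustration, and the count says nothing about actual cycle times; it only tallies how many 4x4 repacks the CPU performs.

```python
def im2col_tile_repacks(oc, ic, n, tile=4, unroll=1):
    """Number of tile-sized im2col repacks when `unroll` output tiles share each one."""
    assert oc % (tile * unroll) == 0 and n % tile == 0
    return (oc // (tile * unroll)) * (n // tile) * ic

OC, IC, N = 16, 4, 16
before = im2col_tile_repacks(OC, IC, N, unroll=1)   # one repack per output tile
after = im2col_tile_repacks(OC, IC, N, unroll=4)    # one repack shared by 4 tiles

# Register budget for the unrolled version in the listing that follows:
# 4 accumulators + 1 shared data tile + 3 rotating kernel tiles.
registers_used = 4 + 1 + 3
```

With these sizes the repack count drops from 64 to 16, and the unrolled version uses exactly the 8 tile registers available.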

#define TILE 4
void conv1d_im2col_tile(int32_t *data, int32_t *kernels, int32_t *out) {
    for (int tile_i = 0; tile_i < OC/(TILE*4); tile_i++) {
        for (int tile_j = 0; tile_j < N/TILE; tile_j++) {
            asm volatile("mzero m1");
            asm volatile("mzero m2");
            asm volatile("mzero m3");
            asm volatile("mzero m4");
            for (int c = 0; c < IC; c++) {
                int32_t y[TILE][TILE];
                for (int j = 0; j < TILE; j++) {
                    for (int r = 0; r < TILE; r++) {
                        if (((tile_j*TILE + j) + r) < N) {
                            y[j][r] = data[c][(tile_j*TILE + j)+r];
                        } else {
                            y[j][r] = 0;
                        }
                    }
                }
                // first kernel row handled by this group of 4 tiles, for channel c
                int32_t *kernel_base = &kernels[tile_i*TILE*4][c][0];
                // matrix multiplication
                asm volatile("mld.w m0, (%0), %1"
                    :: "r"(y), "r"(TILE * 4));
                asm volatile("mld.w m5, (%0), %1"
                    :: "r"(kernel_base), "r"(IC * W * 4));
                asm volatile("mmasa.w m1, m0, m5");
                asm volatile("mld.w m6, (%0), %1"
                    :: "r"(kernel_base+TILE * IC * W), "r"(IC * W * 4));
                asm volatile("mmasa.w m2, m0, m6");
                asm volatile("mld.w m7, (%0), %1"
                    :: "r"(kernel_base+TILE * IC * W*2), "r"(IC * W * 4));
                asm volatile("mmasa.w m3, m0, m7");
                asm volatile("mld.w m5, (%0), %1"
                    :: "r"(kernel_base+TILE * IC * W*3), "r"(IC * W * 4));
                asm volatile("mmasa.w m4, m0, m5");
            }
            // each tile_i iteration covers 4 tiles, i.e. TILE*4 output rows
            asm volatile("mst.w m1, (%0), %1"
                :: "r"(&out[tile_i*TILE*4][tile_j*TILE]), "r"(N * 4));
            asm volatile("mst.w m2, (%0), %1"
                :: "r"(&out[tile_i*TILE*4+TILE*1][tile_j*TILE]), "r"(N * 4));
            asm volatile("mst.w m3, (%0), %1"
                :: "r"(&out[tile_i*TILE*4+TILE*2][tile_j*TILE]), "r"(N * 4));
            asm volatile("mst.w m4, (%0), %1"
                :: "r"(&out[tile_i*TILE*4+TILE*3][tile_j*TILE]), "r"(N * 4));
        }
    }
}

As we would expect, the new code yields another roughly 4x speedup from the previous iteration.

Here we saw firsthand the impact of optimizing for specialized hardware. The final routine is roughly 16x faster than the initial one, but it also takes about 5x as many lines of code. Notably, we lost the connection to the mathematical formula we started with. While the correctness of the original code was straightforward to audit at a glance, our final routine employs bespoke inline assembly and data movement techniques which don't readily correspond to the original specification. Any further maintenance on this code requires thinking through these techniques and understanding why they are correct, which only becomes a bigger problem as more optimizations are applied.

Now, let's see how user-scheduled languages may improve this workflow.

Exo

Exo is a user-scheduled language implemented as a DSL embedded in Python. The programmer writes their algorithm in the Exo language, which resembles plain Python. Then, they manipulate the schedule of this program using a handful of scheduling directives, such as unrolling a loop, or reordering two loops in a nest. These directives are written as ordinary Python functions manipulating the object corresponding to the program, which inform changes in the program's AST. Finally, the code is lowered down to C, which can be processed with a standard C compiler.

Let's dive right in with the convolution routine. We'll start by expressing our algorithm, which corresponds to the "direct translation" version we started with in C:

@proc
def generic_conv1d(
    data: i32[IC, N],
    kernels: i32[OC, IC, W],
    out: i32[OC, N],
):
    # do the convolution
    for i in seq(0, OC):
        for j in seq(0, N):
            # zero out the result memory
            out[i, j] = 0.0
            for c in seq(0, IC):
                for r in seq(0, W):
                    y: i32
                    if j + r < N:
                        y = data[c, j + r]
                    else:
                        y = 0
                    out[i, j] += kernels[i, c, r] * y

To optimize this routine, we'll pass the newly defined function generic_conv1d as an object to Exo's scheduling directives, which are just Python functions. The return value is a new procedure with the rewrite applied, which we can pass to further directives. We continue the process until we have arrived at a satisfactory schedule.

Along with the program itself, Exo's scheduling directives often need to be passed locations in the program to manipulate. For example, we may want to tell Exo to unroll one specific loop, rather than all of the loops in the program. One option Exo provides is to pass a string which is pattern matched against the program. So, to unroll a for loop with index "i", one could write p = unroll_loop(p, "for i in _:_") where p is the procedure object.

Better yet, Exo provides a system called "cursors," which lets you maintain a reference to a certain location, carried throughout the transformations you make on the program. In our case, we'll be repeatedly manipulating a lot of the same loops, so grabbing some cursors early on will help a lot with readability:

# Before scheduling, grab cursors to the object code.
i_loop = p.find("for i in _:_")
j_loop = p.find("for j in _:_")
c_loop = p.find("for c in _:_")
y_alloc = p.find("y : _")
y_assign = p.find("y = data[_]")

Note that in this snippet and those that follow, p refers to the generic_conv1d routine. We've shortened it for brevity; these snippets belong to a function taking p as a parameter, which we are passing generic_conv1d to at the top level.

Having defined some useful cursors, we can begin scheduling the program. We'll go about things in a different order than we presented by hand, which was first im2col, tile for the accelerator's register, tile again to run 4 iterations in parallel, then reorder and unroll the 4 iterations. While this was an intuitive way to understand the performance evolution of the program, it's harder to express as transformations on the program when written in this order.

Even though we'll present it here in the order most conducive to Exo, a big benefit of a system like Exo is that the schedule can be written in any order. In fact, we originally wrote this schedule in the same order as our manual optimizations. The upside of having the schedule written out explicitly is that it's quite easy to go back and revise it somewhere in the middle, incrementally improving it with new directives. Writing by hand, every optimization we added was carried out on the result of all those before it, requiring a sort of global reasoning about the program.

With that said, we'll start with all of our tiling:

# Tile outer loops to TILE size for RVM
p, _ = tile_loops(p, [(i_loop, TILE), (j_loop, TILE)], perfect=True)
# Compute 4 registers at a time
p, _ = tile_loops(p, [(i_loop, 4)], perfect=True)

We're using the scheduling directive tile_loops, passing the program p along with the cursors we selected and the factor by which we'd like to tile each loop. In this case, we tile both the i and j loops by a factor of 4, matching the tile size of our accelerator. Once again, we're doing things out of order, so here we also tile the i loop a second time for the 4x compute optimization we made at the end.

# Exo adds "o" and "i" suffix for outer and inner tiled loops respectively
i_loop_reg = p.find("for ioi in _:_")
p = reorder_loops(p, i_loop_reg)

Finally, we also want to reorder this new loop corresponding to the 4 registers so that we can unroll it on the innermost loops.

In Exo, we can print() the program at any point to see the result of our scheduling. Printing p right now yields:

def exo_conv1d_tile_lt_kw(data: i32[4, 16] @ DRAM,
                          kernels: i32[16, 4, 4] @ DRAM,
                          out: i32[16, 16] @ DRAM):
    for ioo in seq(0, 1):
        for jo in seq(0, 4):
            for ioi in seq(0, 4):
                for ii in seq(0, 4):
                    for ji in seq(0, 4):
                        out[ii + 4 * ioi + 16 * ioo, ji + 4 * jo] = 0.0
                        for c in seq(0, 4):
                            for r in seq(0, 4):
                                y: i32 @ DRAM
                                if ji + r + 4 * jo < 16:
                                    y = data[c, ji + r + 4 * jo]
                                else:
                                    y = 0
                                out[ii + 4 * ioi + 16 * ioo, ji +
                                    4 * jo] += kernels[ii + 4 * ioi + 16 * ioo,
                                                       c, r] * y
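
To convince ourselves that the tiled and reordered loop nest still touches every output element exactly once, we can replay the index expressions from the printout in plain Python. This is only a sketch: the constant names `TILE` and `REGS` are introduced here for illustration, and the sizes are the ones visible in the printed procedure.

```python
OC, N, TILE, REGS = 16, 16, 4, 4

visited = []
for ioo in range(OC // (TILE * REGS)):
    for jo in range(N // TILE):
        for ioi in range(REGS):
            for ii in range(TILE):
                for ji in range(TILE):
                    # Same index expressions as in the printed Exo procedure.
                    i = ii + TILE * ioi + TILE * REGS * ioo
                    j = ji + TILE * jo
                    visited.append((i, j))

covers_everything_once = sorted(visited) == [
    (i, j) for i in range(OC) for j in range(N)
]
```

The order of visits changes (after one 4x4 tile, the next iteration jumps to row 4 of the same column block), but the set of (i, j) pairs is identical to the untiled loops.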

What we're going to aim for in the steps that follow is to expose the parts of this computation that can be offloaded to the accelerator. Much like we exposed matrix multiplication by hand with im2col so it was clear to us as programmers, we're going to "stage" the memory accesses into buffers matching the size supported by our accelerator, so we can later tell Exo to use our instructions instead.

We'll start with out[]. We want the compute loops to operate on our staged buffer (which will eventually become registers), so we do that first. stage_mem is a directive provided by Exo that replaces all accesses to an array in some loop with a new array, and then inserts code after the loop to copy it back. We use it here, via the auto_stage_mem call below, to introduce out_tile, replacing all accesses in the c loop:

# Stage output to out_tile
p, (out_alloc, out_tile, body, _) = auto_stage_mem(
    p, p.find_loop("c").expand(1, 0), "out", "out_tile", rc=True
)

Printing p gives us

def exo_conv1d_tile_lt_kw(data: i32[4, 16] @ DRAM,
                       kernels: i32[16, 4, 4] @ DRAM,
                       out: i32[16, 16] @ DRAM):
 for ioo in seq(0, 1):
     for jo in seq(0, 4):
         for ioi in seq(0, 4):
             for ii in seq(0, 4):
                 for ji in seq(0, 4):
                     out_tile: i32 @ DRAM
                     out_tile = 0.0
                     for c in seq(0, 4):
                         for r in seq(0, 4):
                             y: i32 @ DRAM
                             if ji + r + 4 * jo < 16:
                                 y = data[c, ji + r + 4 * jo]
                             else:
                                 y = 0
                             out_tile += kernels[ii + 4 * ioi + 16 * ioo, c,
                                                 r] * y
                     out[ii + 4 * ioi + 16 * ioo, ji + 4 * jo] = out_tile

Now all the code inside this loop operates on out_tile instead. But we want out_tile to correspond to the size of a register. Actually, recall that for each iteration we're trying to handle 4 output tiles, so we want it to be the size of 4 tiles. For this, we should lift this scalar allocation out of the loops ii, ji, and ioi, which correspond to the dimensions of these tiles, and turn it into a buffer of that size. In the end, out_tile will be a three-dimensional array: 4 registers of 4x4 each.

Exo provides a directive in its standard library called lift_alloc, which moves the allocation out of a loop, and another called expand_dim which adds another dimension to a buffer dependent on some index variable. We can easily repeat these directives for each loop we want to lift out of. But even better, we can take advantage of the fact that Exo directives are just Python functions: we can create new rules composing existing ones. So here, we instead make a new function which repeats the lift_alloc - expand_dim process until some threshold size is reached:

def autolift_alloc(p, alloc_c, dep_set=None, max_size=0, lift=True):
    """
    for i in seq(0, 10):
        for j in seq(0, 20):
            a : R          <- alloc_c, dep_set = {'i'}
            a[i] = ...
    ---->
    a : R[10]              <- if size is at most max_size
    for i in seq(0, 10):
        for j in seq(0, 20):
            a[i] = ...
    """
    alloc_c = p.forward(alloc_c)
    loop_c = get_enclosing_loop(p, alloc_c)
    accum_size = 1
    while True:
        try:
            if not isinstance(loop_c, pc.ForCursor):
                break
            if dep_set is None or loop_c.name() in dep_set:
                if (
                    isinstance(loop_c.hi(), LiteralCursor)
                    and accum_size * loop_c.hi().value() <= max_size
                ):
                    p = expand_dim(p, alloc_c, loop_c.hi().value(), loop_c.name())
                    accum_size = accum_size * loop_c.hi().value()
                    if lift:
                        p = lift_alloc(p, alloc_c)
            loop_c = loop_c.parent()
        except:
            break
    return p

This way, we can more concisely express the intent of our transformation (lifting out out_tile until it spans 4 tile registers) in our schedule:

# lift out_tile to span 4 tile (4x4) registers
p = autolift_alloc(p, out_tile, max_size=4 * 4 * 4, dep_set=["ioi","ii","ji"])

This yields the following code when we print p:

def exo_conv1d_tile_lt_kw(data: i32[4, 16] @ DRAM,
                       kernels: i32[16, 4, 4] @ DRAM,
                       out: i32[16, 16] @ DRAM):
 for ioo in seq(0, 1):
     for jo in seq(0, 4):
         out_tile: i32[4, 4, 4] @ DRAM
         for ioi in seq(0, 4):
             for ii in seq(0, 4):
                 for ji in seq(0, 4):
                     out_tile[ioi, ii, ji] = 0.0
                     for c in seq(0, 4):
                         for r in seq(0, 4):
                             y: i32 @ DRAM
                             if ji + r + 4 * jo < 16:
                                 y = data[c, ji + r + 4 * jo]
                             else:
                                 y = 0
                             out_tile[ioi, ii,
                                      ji] += kernels[ii + 4 * ioi +
                                                     16 * ioo, c, r] * y
                     out[ii + 4 * ioi + 16 * ioo,
                         ji + 4 * jo] = out_tile[ioi, ii, ji]
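
The size threshold is what stops the lifting at 4 tiles. The following plain-Python model, which does not use Exo at all, walks the enclosing loops from the inside out and reports which ones would become buffer dimensions. The function `plan_autolift` and its tuple-based loop description are assumptions made for this sketch; the real helper operates on Exo cursors and may differ in corner cases.

```python
def plan_autolift(enclosing_loops, dep_set, max_size):
    """enclosing_loops: (name, extent) pairs, innermost first.

    Returns the loop names that become buffer dimensions (outermost first)
    and the resulting number of elements.
    """
    dims, size = [], 1
    for name, extent in enclosing_loops:
        if name in dep_set and size * extent <= max_size:
            dims.insert(0, name)
            size *= extent
    return dims, size

# Loops around the out_tile allocation before lifting, innermost first.
loops = [("ji", 4), ("ii", 4), ("ioi", 4), ("jo", 4), ("ioo", 1)]
dims, size = plan_autolift(loops, {"ioi", "ii", "ji"}, max_size=4 * 4 * 4)
small_dims, small_size = plan_autolift(loops, {"ioi", "ii", "ji"}, max_size=4 * 4)
```

With a budget of 64 elements the buffer grows to three dimensions indexed by ioi, ii, and ji; with a budget of 16 it would stop at a single 4x4 tile.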

Next, we want to reorder the loops so that ioi, ii, and ji are on the inside of c: for each channel c, we are doing 4 (ioi) matrix multiplications on each tile (ii x ji). Currently they are in the opposite order. Exo won't let us reorder nested loops with other statements in the way, however. Indeed, here it would be wrong to simply swap ji and c because we set out_tile based on the index given by ji. So first, we should use fission to split this statement, as well as the store back to out, into their own loops:

# Block the zero initialization and store blocks
p = fission_as_much_as_possible(p, body)
p = fission_as_much_as_possible(p, body[0])

Once again, we've used a new helper function, fission_as_much_as_possible, which applies fission until Exo complains that it is invalid to do so.

Now we're ready to do the reordering:

# Reorder c loop to the top
p = lift_scope_n(p, c_loop, 3)

We see that the c loop now encloses ioi, ii, and ji:

def exo_conv1d_tile_lt_kw(data: i32[4, 16] @ DRAM,
                       kernels: i32[16, 4, 4] @ DRAM,
                       out: i32[16, 16] @ DRAM):
 for ioo in seq(0, 1):
     for jo in seq(0, 4):
         out_tile: i32[4, 4, 4] @ DRAM
         for ioi in seq(0, 4):
             for ii in seq(0, 4):
                 for ji in seq(0, 4):
                     out_tile[ioi, ii, ji] = 0.0
         for c in seq(0, 4):
             for ioi in seq(0, 4):
                 for ii in seq(0, 4):
                     for ji in seq(0, 4):
                         for r in seq(0, 4):
                             y: i32 @ DRAM
                             if ji + r + 4 * jo < 16:
                                 y = data[c, ji + r + 4 * jo]
                             else:
                                 y = 0
                             out_tile[ioi, ii,
                                      ji] += kernels[ii + 4 * ioi +
                                                     16 * ioo, c, r] * y
         for ioi in seq(0, 4):
             for ii in seq(0, 4):
                 for ji in seq(0, 4):
                     out[ii + 4 * ioi + 16 * ioo,
                         ji + 4 * jo] = out_tile[ioi, ii, ji]
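
The fission and reordering steps are only legal because the zero-initialization, the accumulation, and the store do not interfere once they live in separate loop nests. A small Python experiment illustrates this for a single 4x4 tile; the array names and random test data are assumptions for the sketch, and the input is stored already repacked to keep the loops short.

```python
import random

random.seed(1)
IC, W, T = 4, 4, 4
# data_tile[c][ji][r] and kern[ii][c][r], filled with small random integers.
data_tile = [[[random.randint(-5, 5) for _ in range(W)] for _ in range(T)]
             for _ in range(IC)]
kern = [[[random.randint(-5, 5) for _ in range(W)] for _ in range(IC)]
        for _ in range(T)]

# Before: zero-init, accumulate and store all sit inside the (ii, ji) loops.
fused = [[None] * T for _ in range(T)]
for ii in range(T):
    for ji in range(T):
        acc = 0
        for c in range(IC):
            for r in range(W):
                acc += kern[ii][c][r] * data_tile[c][ji][r]
        fused[ii][ji] = acc

# After: three separate loop nests, with c hoisted outside ii and ji.
tile = [[0] * T for _ in range(T)]            # fissioned zero-initialization
for c in range(IC):
    for ii in range(T):
        for ji in range(T):
            for r in range(W):
                tile[ii][ji] += kern[ii][c][r] * data_tile[c][ji][r]
split = [row[:] for row in tile]              # fissioned store
```

Both versions produce the same tile because integer accumulation does not depend on the order in which the per-channel contributions are added.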

Our next step is to apply the im2col transformation, where we separate out the setting of y into its own loop nest, making y a large buffer in the process. We can express this through applying fission between setting y and doing the multiply-accumulate, then lifting the y allocation up the loop nest. Afterwards, we stage the kernel and data matrices into new buffers just like with the output. We used most of these constructs to stage out, so we'll just show the whole block of the schedule:

# Stage y
p = autolift_alloc(p, y_alloc, max_size=4 * 4, dep_set=["r","ji"])
p = lift_alloc(p, y_alloc, n_lifts=2)

# Fission the initialization loop and remove redundant loops
p = fission_as_much_as_possible(p, y_assign.parent())
p = remove_redundant_loops(p, y_assign.parent(), num=2)

# Stage kernels to kernel_tile and y to data_tile
ii_loop = p.forward(c_loop).body()[2].body()[0]
p, (kernel_alloc, _, _, _) = auto_stage_mem(
    p, ii_loop, "kernels", "kernel_tile", rc=True
)
p = simplify(expand_dim(p, kernel_alloc, 4, ii_loop.parent().name()))
p = lift_alloc(p, kernel_alloc)
p, (data_alloc, _, _, _) = auto_stage_mem(
    p, ii_loop.parent(), "y", "data_tile", rc=True
)

Now im2col and matmul are in their own distinct loop nests, as we'd expect:

def exo_conv1d_tile_lt_kw(data: i32[4, 16] @ DRAM,
                       kernels: i32[16, 4, 4] @ DRAM,
                       out: i32[16, 16] @ DRAM):
 for ioo in seq(0, 1):
     for jo in seq(0, 4):
         out_tile: i32[4, 4, 4] @ DRAM
         for ioi in seq(0, 4):
             for ii in seq(0, 4):
                 for ji in seq(0, 4):
                     out_tile[ioi, ii, ji] = 0.0
         for c in seq(0, 4):
             y: i32[4, 4] @ DRAM
             for ji in seq(0, 4):
                 for r in seq(0, 4):
                     if ji + r + 4 * jo < 16:
                         y[ji, r] = data[c, ji + r + 4 * jo]
                     else:
                         y[ji, r] = 0
             kernel_tile: i32[4, 4, 4] @ DRAM
             data_tile: i32[4, 4] @ DRAM
             for i0 in seq(0, 4):
                 for i1 in seq(0, 4):
                     data_tile[i0, i1] = y[i0, i1]
             for ioi in seq(0, 4):
                 for i0 in seq(0, 4):
                     for i1 in seq(0, 4):
                         kernel_tile[ioi, i0,
                                     i1] = kernels[i0 + 4 * ioi + 16 * ioo,
                                                   c, i1]
                 for ii in seq(0, 4):
                     for ji in seq(0, 4):
                         for r in seq(0, 4):
                             out_tile[ioi, ii,
                                      ji] += kernel_tile[ioi, ii,
                                                         r] * data_tile[ji,
                                                                        r]
         for ioi in seq(0, 4):
             for ii in seq(0, 4):
                 for ji in seq(0, 4):
                     out[ii + 4 * ioi + 16 * ioo,
                         ji + 4 * jo] = out_tile[ioi, ii, ji]
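
As a sanity check on the restructuring above, here is a minimal plain-Python sketch (not Exo code) that compares a direct 1D convolution against the "im2col, then matmul" form. The sizes mirror the signature above; the function names and the test inputs are assumptions made up for illustration, and the output tiling over ioi/ii is collapsed into a single loop over i for brevity.

```python
IC, IW, OC, KW, TILE = 4, 16, 16, 4, 4  # sizes taken from the Exo signature above

def make_inputs():
    # Hypothetical deterministic test data (not from the original post).
    data = [[c * IW + j + 1 for j in range(IW)] for c in range(IC)]
    kernels = [[[(i + 2 * c + 3 * r) % 5 - 2 for r in range(KW)]
                for c in range(IC)] for i in range(OC)]
    return data, kernels

def conv1d_direct(data, kernels):
    # Direct convolution with implicit zero padding past the right edge.
    out = [[0] * IW for _ in range(OC)]
    for i in range(OC):
        for j in range(IW):
            for c in range(IC):
                for r in range(KW):
                    if j + r < IW:
                        out[i][j] += kernels[i][c][r] * data[c][j + r]
    return out

def conv1d_im2col(data, kernels):
    out = [[0] * IW for _ in range(OC)]
    for jo in range(IW // TILE):
        for c in range(IC):
            # im2col: gather the sliding windows into a dense TILE x KW buffer.
            y = [[data[c][ji + r + TILE * jo] if ji + r + TILE * jo < IW else 0
                  for r in range(KW)] for ji in range(TILE)]
            # matmul: a plain multiply-accumulate over the dense buffer.
            for i in range(OC):
                for ji in range(TILE):
                    for r in range(KW):
                        out[i][TILE * jo + ji] += kernels[i][c][r] * y[ji][r]
    return out

data, kernels = make_inputs()
print(conv1d_im2col(data, kernels) == conv1d_direct(data, kernels))
```

The point of the separation is visible in the second function: the bounds check lives only in the gather step, so the multiply-accumulate nest is a dense 4x4 tile operation.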

We've just exposed several opportunities to offload work to our accelerator. The loop nests to load, multiply, and store data_tile, kernel_tile and out_tile correspond nicely with the behavior of the RVM instructions. But how do we express this equivalence to Exo?

When writing by hand in C, we ripped out scalar code and replaced it with our special instructions as inline assembly, trusting that they did the same thing. For example, we got rid of the out += data * kernels statements and in the end replaced them with mmasa.w, the RVM matmul instruction. So far we've been able to express everything else we wrote in C equivalently, but inline assembly simply doesn't exist in Exo.

We could use some workaround like implementing a compiler backend that detects places in the C code to offload, selecting the appropriate RVM instruction. This could work, but it cripples the utility of the schedule as a "record" for the program's optimizations: the full performance picture is dependent on the behavior of this compiler, and the instructions it selects!

Fortunately, Exo offers a clever and unique solution to this problem, one that nicely encapsulates all the behavior in the schedule itself. The Exo programmer defines each hardware instruction as an ordinary procedure, written with standard scalar operations. For example, here is how we define our mmasa.w instruction:

@instr('asm volatile("mmasa.w "{md_int}", "{ms1_int}", "{ms2_int});')
def rvm_mmasa(
   md: [i32][4, 4] @ RVM_TILE, ms1: [i32][4, 4] @ RVM_TILE,
   ms2: [i32][4, 4] @ RVM_TILE
):
   assert stride(md, 1) == 1
   assert stride(ms1, 1) == 1
   assert stride(ms2, 1) == 1
   for i in seq(0, 4):
      for j in seq(0, 4):
         for k in seq(0, 4):
            md[i, j] += ms2[i, k] * ms1[j, k]

The body of the procedure is nothing special: it is simply the Exo code for doing matrix multiply with scalar instructions. The key is in the @instr decorator we gave it, and the replace() directive in Exo.

replace() takes as arguments a cursor to some fragment of code inside a procedure, and a procedure whose body the fragment will be matched against. If Exo succeeds in unifying (i.e. pattern matching) the fragment with the body, then it will replace the fragment with a call to that procedure, automatically determining the correct arguments. You may be able to see where this is going: in the case of an instruction like we defined above, the body of the procedure is acting as a sort of specification for the instruction, encoding the semantics in a way that allows Exo to reason about when an offload is sound.
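
To make the "body as specification" idea concrete, here is a small plain-Python model (not Exo, and not how Exo's unifier works internally). It runs the specification body of rvm_mmasa and the inner multiply-accumulate nest from our convolution side by side, under the argument binding md = out_tile, ms1 = data_tile, ms2 = kernel_tile. The helper names and test values are assumptions for illustration.

```python
import copy

def rvm_mmasa_spec(md, ms1, ms2):
    # Scalar semantics copied from the instruction definition above.
    for i in range(4):
        for j in range(4):
            for k in range(4):
                md[i][j] += ms2[i][k] * ms1[j][k]

def conv_inner_nest(out_tile, kernel_tile, data_tile):
    # The loop nest that appears in the scheduled convolution.
    for ii in range(4):
        for ji in range(4):
            for r in range(4):
                out_tile[ii][ji] += kernel_tile[ii][r] * data_tile[ji][r]

# Hypothetical 4x4 test tiles.
kernel_tile = [[i - k for k in range(4)] for i in range(4)]
data_tile = [[2 * j + k for k in range(4)] for j in range(4)]
acc_a = [[1] * 4 for _ in range(4)]
acc_b = copy.deepcopy(acc_a)

conv_inner_nest(acc_a, kernel_tile, data_tile)
rvm_mmasa_spec(md=acc_b, ms1=data_tile, ms2=kernel_tile)
print(acc_a == acc_b)
```

Renaming ii, ji, r to i, j, k turns one nest into the other, which is the kind of correspondence the unification has to discover, along with which buffer plays which argument.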

This seems like magic! We've managed to express the behavior of our accelerator inside Python. We can take arbitrary pieces of code in Exo and reason about whether it's safe to offload some instruction there. No compiler backend needed. But the result is just procedure calls in the Exo language. When we go to compile it to C, how do we actually generate code which is appropriate for the accelerator?

This is where the @instr decorator comes in. The string we provided is a snippet of arbitrary C code with holes in it. When Exo goes to compile the call to the corresponding procedure, it pastes this piece of C code, filling in the holes with the names of the arguments in the compiled C code. For example, if the Exo array out is passed to the store instruction rvm_mst, then the C code will be an inline assembly which is passed the C array out.
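
The hole-filling step can be pictured with ordinary Python string formatting. This is only a sketch of the idea, not Exo's actual lowering code: the template string is copied from the decorator above, while lower_call and the chosen C names are hypothetical, and how Exo resolves each hole suffix is not modeled.

```python
# Template copied verbatim from the @instr decorator of rvm_mmasa.
INSTR_TEMPLATE = 'asm volatile("mmasa.w "{md_int}", "{ms1_int}", "{ms2_int});'

def lower_call(template, **c_names):
    # Paste the C-level names of the call's arguments into the holes.
    return template.format(**c_names)

line = lower_call(
    INSTR_TEMPLATE,
    md_int="out_tile_0",
    ms1_int="data_tile",
    ms2_int="kernel_tile_0",
)
print(line)
# asm volatile("mmasa.w "out_tile_0", "data_tile", "kernel_tile_0);
```

Note that the pasted names end up outside the string literals of the asm statement, so the C preprocessor and string concatenation get a chance to act on them. That detail matters for how registers are named later on.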

You may be wondering how we are dealing with the custom memory that our accelerator supports. Indeed, we have completely glossed over this detail, but it's not valid to offload an operator if the array passed to the Exo procedure is in main memory, while the accelerator expects it to reside in some kind of scratchpad or register. We have to manually orchestrate this data movement. Besides, it's not clear how we'd express an accelerator-specific memory in C code in a generic manner.

Once again, Exo has a solution for this specific problem. You may have noticed the @ RVM_TILE annotations in the above procedure definition. This is actually a custom class representing our accelerator's tile register memory. Once again, it is defined by us, the programmer:

class RVM_TILE(StaticMemory):
 NUM_RVM_TILES = 8
 StaticMemory.init_state(NUM_RVM_TILES)
 tile_dict = {}

 ...

 @classmethod
 def alloc(cls, new_name, prim_type, shape, srcinfo):
     if not (shape[0].isdecimal() and int(shape[0]) == 4):
         raise MemGenError("Number of tile rows must be 4.")
     if not (shape[1].isdecimal() and int(shape[1]) == 4):
         raise MemGenError("Number of tile columns must be 4.")

     tile_num = cls.find_free_chunk()
     cls.mark(tile_num)
     cls.tile_dict[new_name] = tile_num
     return f'#define {new_name} "m{7-tile_num}"'

 @classmethod
 def free(cls, new_name, prim_type, shape, srcinfo):
     tile_num = cls.tile_dict[new_name]
     del cls.tile_dict[new_name]
     cls.unmark(tile_num)
     return f"#undef {new_name}"

The effect of this class is two-fold:

  • The annotation allows programmers to specify where a buffer is located. The class corresponding to the allocation is then used to describe how to lower operations on the memory to C code. Exo does not verify the consistency of these annotations during the scheduling phase, and you can even freely modify which memory is selected for some buffer with the directive set_memory. However, since the methods in the class also define what operations are allowed on the memory, that restricts what Exo programs will compile successfully. As a result, Exo can check in the backend that the code does not violate the semantics of the memory annotations. For example, load(), meaning a scalar load like A[0], is not defined for RVM_TILE, because we have no sense of what it means to load a single element from a tile register. With only alloc() and free() defined, we can only use RVM_TILE to invoke the special instruction procedures rvm_mmasa, rvm_mld, rvm_mst, necessitating that we orchestrate the proper data movement.

  • The implementations of these methods offer complete control for the accelerator code generation. For example, what we have in alloc() and free() is essentially a trivial register allocator: a free list is maintained in the parent class StaticMemory, and every alloc() takes a new free register, while free() puts it back. Spilling is not handled. The return value of these methods is yet another C fragment which Exo pastes when it compiles the program. Our solution here is on the hacky side, but it demonstrates the flexibility we have: we "allocate" a register by #define-ing the variable name to the register we selected. This macro gets copy-pasted into the inline assembly we use for our instructions.
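
The register-allocation behavior described in the second point can be modeled in a few lines of standalone Python. This is a sketch under assumptions: TileRegisterFile is a hypothetical stand-in for RVM_TILE plus its StaticMemory parent, and the first-fit policy is a guess that happens to be consistent with the register names in the generated code shown later.

```python
class TileRegisterFile:
    """Toy model of a free-list allocator over a fixed set of tile registers."""

    def __init__(self, num_tiles=8):
        self.in_use = [False] * num_tiles
        self.tile_of = {}

    def alloc(self, name):
        # First fit: take the lowest-numbered free slot (assumed policy).
        for tile_num, used in enumerate(self.in_use):
            if not used:
                break
        else:
            # No spilling: running out of registers is simply an error.
            raise RuntimeError("out of tile registers")
        self.in_use[tile_num] = True
        self.tile_of[name] = tile_num
        # "Allocate" by defining the buffer name as a register name.
        return f'#define {name} "m{len(self.in_use) - 1 - tile_num}"'

    def free(self, name):
        tile_num = self.tile_of.pop(name)
        self.in_use[tile_num] = False
        return f"#undef {name}"

regs = TileRegisterFile()
for name in ["out_tile_0", "out_tile_1", "out_tile_2", "out_tile_3",
             "kernel_tile_0", "kernel_tile_1", "kernel_tile_2", "data_tile"]:
    print(regs.alloc(name))
print(regs.free("kernel_tile_1"))
print(regs.alloc("kernel_tile_3"))  # reuses the register that was just freed
```

With all eight registers taken, one more alloc() raises instead of spilling, which mirrors the "spilling is not handled" caveat above.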

Let's take a step back and return to our convolution routine. We had scheduled things to line up nicely with the structure of the RVM instructions. So the next step is to tell Exo to offload, using the constructs we discussed:

# Set adequate memories
p = set_memory(p, y_alloc, DRAM_STATIC)
p = set_memory(p, out_tile, RVM_TILE)
p = set_memory(p, kernel_alloc, RVM_TILE)
p = set_memory(p, data_alloc, RVM_TILE)

# Replace inner loops to calls to RVM instructions
p = replace_all(p, [rvm_mzero, rvm_mst, rvm_mld, rvm_mmasa])

The replace_all we used here is a wrapper around replace: it finds every fragment that matches one of the provided instructions and replaces it accordingly, rather than requiring each fragment to be passed explicitly.

The resulting Exo code has all offloadable sections replaced by calls to the functions we wrote earlier. Once again, from the Exo perspective, these are just calls to other Exo functions, but they carry a special meaning during code generation which allows them to represent our assembly instructions.

def exo_conv1d_tile_lt_kw(data: i32[4, 16] @ DRAM,
                       kernels: i32[16, 4, 4] @ DRAM,
                       out: i32[16, 16] @ DRAM):
 for ioo in seq(0, 1):
     for jo in seq(0, 4):
         out_tile: i32[4, 4, 4] @ RVM_TILE
         for ioi in seq(0, 4):
             rvm_mzero(out_tile[ioi, 0:4, 0:4])
         for c in seq(0, 4):
             y: i32[4, 4] @ DRAM_STATIC
             for ji in seq(0, 4):
                 for r in seq(0, 4):
                     if ji + r + 4 * jo < 16:
                         y[ji, r] = data[c, ji + r + 4 * jo]
                     else:
                         y[ji, r] = 0
             kernel_tile: i32[4, 4, 4] @ RVM_TILE
             data_tile: i32[4, 4] @ RVM_TILE
             rvm_mld(data_tile[0:4, 0:4], y[0:4, 0:4])
             for ioi in seq(0, 4):
                 rvm_mld(
                     kernel_tile[ioi, 0:4, 0:4],
                     kernels[4 * ioi + 16 * ioo:4 + 4 * ioi + 16 * ioo, c,
                             0:4])
                 rvm_mmasa(out_tile[ioi, 0:4, 0:4], data_tile[0:4, 0:4],
                           kernel_tile[ioi, 0:4, 0:4])
         for ioi in seq(0, 4):
             rvm_mst(
                 out_tile[ioi, 0:4, 0:4],
                 out[4 * ioi + 16 * ioo:4 + 4 * ioi + 16 * ioo,
                     4 * jo:4 + 4 * jo])

For our final transformations, we'd like to unroll each of the ioi loops, and allocate 4 separate out_tile buffers, rather than having only one with an extra dimension. Exo of course provides the unroll_loop directive for loops, but also a directive to "unroll" a buffer: replace an allocation with a constant size n on a given dimension with n buffers without that dimension. We utilize these directives below.

# Clean up
p = unroll_loop(p, "ioi")
p = unroll_loop(p, "ioi")
p = unroll_loop(p, "ioi")
p = simplify(p)
p = unroll_buffer(p, kernel_alloc, 0)
p = reuse_buffer(p, "kernel_tile_0: _", "kernel_tile_3: _")

We can now compile our full Exo program to C, and see that the result is quite similar to the one we wrote by hand:

// exo_conv1d_tile_lt_kw(
//     data : i32[4, 16] @DRAM,
//     kernels : i32[16, 4, 4] @DRAM,
//     out : i32[16, 16] @DRAM
// )
void exo_conv1d_tile_lt_kw( void *ctxt, const int32_t* data, const int32_t* kernels, int32_t* out ) {
for (int_fast32_t ioo = 0; ioo < 1; ioo++) {
for (int_fast32_t jo = 0; jo < 4; jo++) {
   #define out_tile_0 "m7"
   #define out_tile_1 "m6"
   #define out_tile_2 "m5"
   #define out_tile_3 "m4"
   asm volatile("mzero "out_tile_0);
   asm volatile("mzero "out_tile_1);
   asm volatile("mzero "out_tile_2);
   asm volatile("mzero "out_tile_3);
   for (int_fast32_t c = 0; c < 4; c++) {
      static int32_t y[4 * 4];
      for (int_fast32_t ji = 0; ji < 4; ji++) {
      for (int_fast32_t r = 0; r < 4; r++) {
         if (ji + r + 4 * jo < 16) {
            y[ji * 4 + r] = data[c * 16 + ji + r + 4 * jo];
         } else {
            y[ji * 4 + r] = ((int32_t) 0);
         }
      }
      }
      #define kernel_tile_0 "m3"
      #define kernel_tile_1 "m2"
      #define kernel_tile_2 "m1"
      #define data_tile "m0"
      asm volatile("mld.w "data_tile", (%1), %0" :: "r"(4*(((struct exo_win_2i32c){
         &y[0], { 4, 1 } }).strides[0])), "r"(&y[0]));
      asm volatile("mld.w "kernel_tile_0", (%1), %0" :: "r"(4*(((struct exo_win_2i32c){
         &kernels[(16 * ioo) * (16) + (c) * 4], { 16, 1 } }).strides[0])),
         "r"(&kernels[(16 * ioo) * (16) + (c) * 4]));
      asm volatile("mmasa.w "out_tile_0", "data_tile", "kernel_tile_0);
      asm volatile("mld.w "kernel_tile_1", (%1), %0" :: "r"(4*(((struct exo_win_2i32c){
         &kernels[(4 + 16 * ioo) * (16) + (c) * 4], { 16, 1 } }).strides[0])),
         "r"(&kernels[(4 + 16 * ioo) * (16) + (c) * 4]));
      asm volatile("mmasa.w "out_tile_1", "data_tile", "kernel_tile_1);
      #undef kernel_tile_1
      asm volatile("mld.w "kernel_tile_2", (%1), %0" :: "r"(4*(((struct exo_win_2i32c){
         &kernels[(8 + 16 * ioo) * (16) + (c) * 4], { 16, 1 } }).strides[0])),
         "r"(&kernels[(8 + 16 * ioo) * (16) + (c) * 4]));
      asm volatile("mmasa.w "out_tile_2", "data_tile", "kernel_tile_2);
      #undef kernel_tile_2
      asm volatile("mld.w "kernel_tile_0", (%1), %0" :: "r"(4*(((struct exo_win_2i32c){
         &kernels[(12 + 16 * ioo) * (16) + (c) * 4], { 16, 1 } }).strides[0])),
         "r"(&kernels[(12 + 16 * ioo) * (16) + (c) * 4]));
      asm volatile("mmasa.w "out_tile_3", "data_tile", "kernel_tile_0);
      #undef data_tile
      #undef kernel_tile_0
   }
   asm volatile("mst.w "out_tile_0", (%1), %0" :: "r"(4*(((struct exo_win_2i32){
     &out[(16 * ioo) * (16) + 4 * jo], { 16, 1 } }).strides[0])), "r"(&out[(16 * ioo) * (16) + 4 * jo]));
   #undef out_tile_0
   asm volatile("mst.w "out_tile_1", (%1), %0" :: "r"(4*(((struct exo_win_2i32){
     &out[(4 + 16 * ioo) * (16) + 4 * jo], { 16, 1 } }).strides[0])), "r"(&out[(4 + 16 * ioo) * (16) + 4 * jo]));
   #undef out_tile_1
   asm volatile("mst.w "out_tile_2", (%1), %0" :: "r"(4*(((struct exo_win_2i32){
     &out[(8 + 16 * ioo) * (16) + 4 * jo], { 16, 1 } }).strides[0])), "r"(&out[(8 + 16 * ioo) * (16) + 4 * jo]));
   #undef out_tile_2
   asm volatile("mst.w "out_tile_3", (%1), %0" :: "r"(4*(((struct exo_win_2i32){
     &out[(12 + 16 * ioo) * (16) + 4 * jo], { 16, 1 } }).strides[0])), "r"(&out[(12 + 16 * ioo) * (16) + 4 * jo]));
   #undef out_tile_3
}
}
}

That's it! The full Exo code corresponding to this demo can be found here.

One aspect of Exo which we did not get to see, but which is worth noting, is its static analysis system. Exo verifies that rewrites applied in a schedule are sound through an 'effect' analysis. As of today, this effect analysis mainly captures when and where locations in memory are accessed, and less about the values at those locations (although this is an active area of improvement). For example, broadly speaking, swapping two statements in the code is considered valid only when the effects of the two statements are invisible to each other: neither uses the result of the other, they do not both write to the same location, and so on. Similar rules are defined for reordering loops, loop fusion and fission, etc.
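
The statement-swapping rule can be illustrated with a deliberately simplified model: each statement is summarized by the set of locations it reads and the set it writes, and two statements may be swapped only if no write of one touches a read or write of the other. The Effect class and can_swap function are hypothetical names for this sketch. Exo's real analysis is symbolic and considerably more involved.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Effect:
    reads: frozenset
    writes: frozenset

def can_swap(a, b):
    # No read-after-write, write-after-read, or write-after-write conflict.
    return (a.writes.isdisjoint(b.reads)
            and b.writes.isdisjoint(a.reads)
            and a.writes.isdisjoint(b.writes))

# y[0] = data[0]
s1 = Effect(reads=frozenset({("data", 0)}), writes=frozenset({("y", 0)}))
# out[0] = y[0]   (consumes the result of s1)
s2 = Effect(reads=frozenset({("y", 0)}), writes=frozenset({("out", 0)}))
# y[1] = data[1]  (touches unrelated locations)
s3 = Effect(reads=frozenset({("data", 1)}), writes=frozenset({("y", 1)}))
# y[0] = 0        (overwrites the location s1 writes)
s4 = Effect(reads=frozenset(), writes=frozenset({("y", 0)}))

print(can_swap(s1, s2), can_swap(s1, s3), can_swap(s1, s4))
```

Two statements that only read the same location remain swappable under this rule, since reads do not make one statement's effect visible to the other.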

In our case, behind the scenes, this system has been verifying that we have not made an unsound step at any point in our schedule. This is a powerful guarantee: Assuming the original program is correct (which is much easier to audit because we've maintained the algorithm itself), then Exo guarantees that our optimized code is correct too.

Overall, Exo's explicit design around programming for custom hardware lent itself quite well to our use case for RVM. However, we found it slightly limited in one area. Exo doesn't currently offer a way to control scalar code generation like you can with accelerator instructions. Normally, this is fine: C compilers are good at handling common cases of array access expressions (like A[i][j], which is really A[i*stride+j]), performing instruction selection for CPUs, etc.

But our study actually uncovered a case where it was slightly better (5% faster, to be precise) to write out all the array accesses by hand, in the C code. Essentially, instead of A[i][j], we'd maintain an offset A_ofs that we'd increment in the appropriate places and have A[A_ofs] instead. We skipped over this optimization in the first section, but we'll see it in action when discussing our next tool. Also, our report contains the full details for the interested reader.

In Exo, we ran into trouble trying to implement this, as we were trying to manipulate the low-level details of the C source. In Exo, array accesses are denoted abstractly with square brackets. The compiler lowers it down to an expression like A[i*stride+j] in C. We couldn't find a way to manipulate such details in our schedule, because it's invisible to the language.

With that said, the C code generation is an active area of improvement for the Exo developers. Work is being done to not only make the C code generation more optimized in general, but to investigate ways to give the programmer more control over such details.

Inspired by the current limitations in Exo, we turned to another tool for our study, OptiTrust, offering a more radical approach: source to source rewriting in C itself. Let's now see how OptiTrust worked for our use case.

OptiTrust

OptiTrust is a tool allowing for source-to-source transformations of C/C++ code. These transformations, similar to the scheduling rules we have seen previously, are driven by an OCaml script which manipulates parts of an AST corresponding to the program. As in Exo, the soundness of the transformations offered by OptiTrust is verified.

We'll give a quick tour of using OptiTrust, once again through the example of our convolution routine. Since we're rewriting C, we don't need to write any special algorithm this time - the direct convolution program we started with in Section 1 will do! So let's look at the OCaml script which captures our schedule. The full script is available here. Some knowledge of OCaml is assumed.

let _ = Run.script_cpp (fun () ->
   (* tile according to size supported by accelerator *)
   !! Loop.tile (trm_int 4) ~index:"tile_i"
   ~iter:TileIterGlobal ~bound:TileDivides [cFor "i"];
   !! Loop.tile (trm_int 4) ~index:"tile_j"
   ~iter:TileIterGlobal ~bound:TileDivides [cFor "j"];
   !! Loop.reorder ~order:["tile_j"; "i"] [cFor "i"];
   (* tile again to have 4x compute *)
   !! Loop.tile (trm_int 4) ~index:"tile_i_hi"
   ~iter:TileIterGlobal ~bound:TileDivides [cFor "tile_i"];
   !! Loop.reorder ~order:["tile_j"; "tile_i"] [cFor "tile_i"];

   (* sum a tile at a time *)
   !! Loop.hoist_alloc_loop_list [1; 1; 1;] [cVarDef "sum"];
   (* zeroing *)
   !! Loop.fission ~nest_of:3 [tBefore; cFor "k"; occFirst];
   (* storing *)
   !! Loop.fission ~nest_of:3 [tAfter; cFor "k"; occFirst];

   let prefix_indices (idxs: (string * target) list) (pfx: string): unit =
   let _ = (List.map (fun idx ->
      Loop.rename_index (pfx ^ "_" ^ fst idx) (snd idx)) idxs) in () in
   !! prefix_indices [("tile_i", [cFor "tile_i"; occFirst]);
   ("i", [cFor "i"; occFirst]); ("j", [cFor "j"; occFirst])] "zero";
   !! prefix_indices [("tile_i", [cFor "tile_i"; occLast]);
   ("i", [cFor "i"; occLast]); ("j", [cFor "j"; occLast])] "st";

In OptiTrust, the schedule is written as a function with the signature unit -> unit (in our case, the anonymous function fun () -> ... passed to Run.script_cpp; the let _ merely discards the result). Transformations are functions which cause side effects on the program's AST. The schedule itself is a composition of many transformation functions. The familiar loop tiling, reordering, fission, fusion, and others are present.

Note that the snippets of the schedule that follow are also part of this function definition, even though the let _ = is no longer shown.

Here, we are performing the same first steps of tiling the loop nest to match our accelerator parameters as we did with Exo. The !! starting each line allows OptiTrust to generate a trace showing the diff in the code caused by a rewrite. There's also some interactivity, as this is connected to the text editor, allowing the programmer to open the trace for the currently highlighted line. For example, here's the diff for the transformation Loop.reorder ~order:["tile_j"; "i"] [cFor "i"]; above:

40,41c40,41
<     for (int i = 0; i < 4; i++) {
<       for (int tile_j = 0; tile_j < exact_div(IW, 4); tile_j++) {
---
>     for (int tile_j = 0; tile_j < exact_div(IW, 4); tile_j++) {
>       for (int i = 0; i < 4; i++) {

To select fragments of the program to rewrite, OptiTrust uses a constraint system. Every transformation takes a target describing the location(s) in the code to transform; a target is a list of constraints matched against parts of the AST. For example, the target [cFor "i"; occLast] matches the AST node for the last occurrence of a for loop statement with iteration index "i".

We can also freely compose these transformation functions to make our own, such as the convenience function prefix_indices we defined which allows for bulk renaming of loop indices.

Many of the primitive transformations we had in Exo also exist in OptiTrust, and work similarly. Using these, we were able to implement most of the optimized schedule we wrote using Exo. The report contains more details on the specifics of how we accomplished this.

But we were also able to take advantage of the ability to rewrite C directly to implement some new optimizations, such as the indexing case we couldn't express in Exo. Let's look at how we can write that with OptiTrust.

!! Variable.bind_multi ~dest:[cIf (); tBefore] "ofs" [sExpr "tile_j * 4 + j + r"];
Matrix_basic.elim_mindex [cMindex ~args:[[cVar "IC"]; [cVar "IW"]; [cVar "k"]; [cVar "ofs"]] ()];
Matrix_basic.elim_mindex [nbMulti; cMindex ~args:[[cInt 4]; [cInt 4]; [cVar "j"]; [cVar "r"]] ()];

First, we use the rewrites OptiTrust offers to extract the index expression used to access data into a new variable ofs. As OptiTrust represents these index expressions with a special macro by default (so that it can perform some reasoning on them), we use elim_mindex to expand them out. The diff given by OptiTrust is:

54,57c54,57
<             y[MINDEX2(4, 4, j, r)] = 0.f;
<             if (tile_j * 4 + j + r < IW) {
<               y[MINDEX2(4, 4, j, r)] =
<                   I[MINDEX2(IC, IW, k, tile_j * 4 + j + r)];
---
>             y[0 + j * 4 + r] = 0.f;
>             int ofs = tile_j * 4 + j + r;
>             if (ofs < IW) {
>               y[0 + j * 4 + r] = I[0 + k * IW + ofs];

We can see that this is essentially a common subexpression elimination.

!! loopize_decl [cFor "r"; occFirst] [cVarDef "ofs"];
Arith_basic.simplify [sExpr "0 + k * IW + ofs"];
Variable.bind "drow_ofs" [sExpr "k * IW + ofs"];
hoist_if [cIf ()] [cVarDef "drow_ofs"];
loopize_decl ~loop_ind_in:(find_var_in_current_ast "ofs") [cFor "r"; occFirst] [cVarDef "drow_ofs"];
loopize_expr [cFor "tile_j"] "data_base" [sExpr "tile_j * 4"];
loopize_expr [cFor   "k"] "data_row" [sExpr "k * IW"]

The first rewrite in this next block, loopize_decl, is actually not provided by OptiTrust, nor is it a composition of existing transformations. OptiTrust actually lets us define transformations which manipulate the AST directly. Since we're performing these rewrites in C, this gives us quite extensive control over the output.

Here, loopize_decl is a custom transformation which takes the initialization of the variable provided (the target [cVarDef "ofs"]), extracts the index of the for loop provided (the target [cFor "r"; occFirst]), and instead adds new statements that increment the variable where the index is incremented. The idea is to replace the arithmetic we're doing every time we compute ofs with statements that increment it in the appropriate places.

We apply this repeatedly for the other array accesses we need. Similarly, we define loopize_expr, which performs the same operation on some arbitrary expression, rather than specifically the variable initialization. In the end, we get this diff:

40a41
>     int data_base = 0;
49a51
>       int data_row = 0;
52a55,56
>           int ofs = data_base + j + 0;
>           int drow_ofs = data_row + ofs;
55d58
<             int ofs = tile_j * 4 + j + r;
57c60
<               y[0 + j * 4 + r] = I[0 + k * IW + ofs];
---
>               y[0 + j * 4 + r] = I[drow_ofs];
58a62,63
>             ofs = ofs + 1;
>             drow_ofs = drow_ofs + 1;
84a90
>         data_row = data_row + IW;
94a101
>     data_base = data_base + 4;
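
To see that the incremental form visits the same elements of I as the original index arithmetic, here is a plain-Python model of both versions. The loop nest is an assumed, simplified shape (tile_j, then k, then j, then r) based on the diff above; the real program has additional loops, and the function names are hypothetical.

```python
def indices_direct(IC, IW):
    # Original form: recompute the full index expression on every iteration.
    seen = []
    for tile_j in range(IW // 4):
        for k in range(IC):
            for j in range(4):
                for r in range(4):
                    ofs = tile_j * 4 + j + r
                    if ofs < IW:
                        seen.append(k * IW + ofs)
    return seen

def indices_incremental(IC, IW):
    # Rewritten form: carry running offsets and bump them where the loop
    # indices are incremented, as in the diff above.
    seen = []
    data_base = 0
    for tile_j in range(IW // 4):
        data_row = 0
        for k in range(IC):
            for j in range(4):
                ofs = data_base + j
                drow_ofs = data_row + ofs
                for r in range(4):
                    if ofs < IW:
                        seen.append(drow_ofs)
                    ofs += 1
                    drow_ofs += 1
            data_row += IW
        data_base += 4
    return seen

print(indices_direct(4, 16) == indices_incremental(4, 16))
```

This is a form of strength reduction: the multiplications by 4 and IW are replaced with additions carried across iterations, and the two functions produce the same sequence of indices in the same order.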

The striking thing about this optimization is that, at a glance, it seems quite arbitrary. Indeed, there's nothing really logical about what we did here, if we're operating under the reasonable assumption that GCC/Clang wouldn't be so naive as to compile an expression like y[j*TILE+r] inside a hot loop as a multiply/shift, followed by an add, then passed to a load. What we found while writing this is that it just so happens that the way we've written it out here has pushed Clang's optimizer into a path that lets it save 3 instructions in the assembly. This blows up to the 5% figure we measured because these instructions sit in the innermost part of the loop nest (the hottest path) in our code.

Cases like these are an interesting demonstration of how finicky high-performance code can be, to put it bluntly. That said, certainly not all high-performance code is like this - this only makes such a big difference due to the use of a simpler embedded, in-order CPU. Still, there is demand for these sorts of applications on embedded platforms, and being able to accommodate such low-level tweaking while bringing great productivity improvements is a great testament to OptiTrust.

So far, in our discussion of OptiTrust, we've omitted how exactly we take C code and offload it to the inline assembly instructions used to program our accelerator. It turns out that as of today, OptiTrust doesn't have support for inline assembly in the language. As a result, the one major problem we faced with OptiTrust was not being able to express offloading in the schedule, even though we could rewrite the program to expose these opportunities clearly. [3]

Inline assembly is a uniquely tricky C construct to incorporate for OptiTrust because the semantics are HW-dependent. Without knowledge of the instruction semantics, OptiTrust can't readily validate a schedule rewriting a program using inline assembly. The interested reader can refer to our report for a deeper discussion of this problem and how it may impact the design considerations for OptiTrust. Regardless, the issue of supporting custom HW like ours is also an active area of research for the OptiTrust team.

Honorable mentions: Halide & TVM

Halide is known for popularizing the concept of a user-scheduled language, proving to be highly effective at optimizing kernels for image processing specifically. TVM was in turn inspired by Halide and targeted kernels used in deep learning, adding other components such as a runtime for neural network execution. TVM has likewise proliferated in the deep learning space.

We didn't discuss these here, in large part because Exo addresses one of the major shortcomings of Halide & TVM for our use case, which is the reliance on a compiler backend for important hardware-specific scheduling decisions. To be more precise, Halide & TVM don't offer extensive control over instruction selection and memory management for custom accelerators. With Exo, these could be exposed to the schedule. For the well-established targets like SIMD CPU instructions and GPUs they were designed around, this was not a concern as mature compiler backends could be relied upon to produce good code. As RVM doesn't have a "compiler" at all, this was a bigger issue. Our report contains an in-depth discussion on our experience with TVM.

Conclusion

We hope this post has shed some light on some of the challenges in developing high-performance software for custom accelerators. Even though we started with a trivial convolution routine, we saw how having to tailor software towards the nuances of our hardware made for an unintuitive and unwieldy end result once thoroughly optimized.

Exo & OptiTrust both worked quite well to alleviate the productivity woes of writing such software. The decoupling of hardware-centric optimizations from functionality applied nicely to our needs. For both, we were able to recreate most of what we wrote by hand in the schedule. With Exo, however, we ran into some limitations regarding its code generation, since we needed to manipulate very low-level details of the C code it generated automatically. OptiTrust handled this case gracefully by allowing custom rewriting over C source, but we struggled to program our custom hardware in a way OptiTrust could incorporate. Both of these cases are active areas of development for the teams behind these tools.

We encourage you to check out the tools discussed in this post, if you're interested in writing this kind of software, or are just curious and want to learn more:

This talk also provides a broader presentation of the problem of writing optimized software.