Implementing Autodiff: First Attempts at Zig Abstractions

Designing Abstractions

Zig has no native language feature for interfaces, which lend themselves naturally to designing flexible abstractions in an extensible deep learning framework. In other languages, we could define an interface for an Operation or Module, similar to PyTorch's autograd.Function and nn.Module. An Operation, acting as a node in the computation graph, might define forward and backward methods that are called during the forward pass and backpropagation, respectively. This gives us a few desirable properties I am looking for in my framework:

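  1. Adding operations should not require modifying core logic. This is essentially the Open/Closed Principle (open for extension, closed for modification), with the Liskov Substitution Principle ensuring that any operation can stand in wherever the core expects one. Custom operations are very common in deep learning, where we often want to define our own loss functions, activation functions, etc., so we should design for extensibility.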
  2. Computation graph management should be decoupled from the specific implementations of operations, interacting with them only through standardized abstractions. In other words, we should apply the Single Responsibility Principle to maintain a good separation of concerns where possible. This should also make performance optimizations, such as caching in the forward pass, easier to maintain.

First Pass

The first, naive implementation of autograd used callbacks attached to a Value struct, where Value was analogous to a scalar-valued tensor. The approach looked something like this:

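pub const Value = struct {
 const Self = @This();
 value: f64,
 grad: f64 = 0.0, // accumulated gradient of the loss w.r.t. value
 children: ?[]*Value = null, // operands that produced this node
 backward: ?*const fn (*Self) void = null, // attached by the op that created the node
 // ...
};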

Operations were responsible for recording their operands as children and attaching their backward callback to the newly allocated intermediate. A simplified version of the implementation went something like this:

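const std = @import("std");

// Allocator is already a small (ptr, vtable) handle, so it is passed by value.
pub fn add(allocator: std.mem.Allocator, v1: *Value, v2: *Value) !*Value {
 // Record the operands so the backward pass can reach them.
 const children = try allocator.alloc(*Value, 2);
 children[0] = v1;
 children[1] = v2;

 const out = try Value.init(allocator, v1.value + v2.value); // assumes init allocates and can fail
 out.children = children;
 out.backward = add_backward; // attach backward function
 return out;
}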

pub fn add_backward(v: *Value) void {
 v.children.?[0].grad += v.grad;
 v.children.?[1].grad += v.grad;
}
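A note on why add_backward uses += rather than plain assignment. Writing $\bar{v}$ for $\partial L/\partial v$ (what the grad field holds), both local derivatives of $z = x + y$ are 1, and a node that feeds several downstream nodes $z_1, \dots, z_k$ receives the sum of their contributions by the multivariable chain rule:

```latex
$$
\frac{\partial z}{\partial x} = \frac{\partial z}{\partial y} = 1
\quad\Longrightarrow\quad
\bar{x} \mathrel{+}= \bar{z}, \qquad \bar{y} \mathrel{+}= \bar{z}
$$
$$
\bar{x} = \sum_{k} \bar{z}_k \, \frac{\partial z_k}{\partial x}
$$
```

So each callback adds its share, which is also why grad defaults to 0.0. Reuse is handled for free: in add(x, x) both children point at x, so x.grad receives $2\bar{z}$, the correct derivative of $2x$.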

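Since methods in Zig are just functions that take the struct as their first argument, the distinction feels like mere syntactic sugar, and we can lean on that to attach the backward function to each node at runtime. The one catch is that a function pointer stored in a field is not a method: calling it as v.backward() does not pass v implicitly, so the node has to be passed in explicitly. During the backward pass, we simply invoked each node's callback:

// Named backprop so it does not clash with the `backward` field.
pub fn backprop(self: *Self) !void {
 // ...
 // `backward` is an optional function-pointer field: unwrap it and pass the node.
 if (v.?.backward) |backward_fn| backward_fn(v.?);
 // ...
}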
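For a version that runs end to end, here is a minimal sketch of the callback design. The init, mul, and backprop functions, the visited flag, and the non-optional children slice are my own assumptions filling in what the snippets above elide, and an arena allocator is assumed to own every node so nothing is freed individually. backprop builds a topological order with a depth-first search and then calls each node's callback in reverse.

```zig
const std = @import("std");

// Hypothetical completion of the callback design (not the original code).
pub const Value = struct {
    const Self = @This();
    value: f64,
    grad: f64 = 0.0,
    children: []const *Value = &.{},
    backward: ?*const fn (*Self) void = null,
    visited: bool = false, // used only by the topological sort

    pub fn init(allocator: std.mem.Allocator, value: f64) !*Value {
        const v = try allocator.create(Value);
        v.* = .{ .value = value };
        return v;
    }
};

// Shared helper: allocate the output node, record operands, attach callback.
fn binary(
    allocator: std.mem.Allocator,
    a: *Value,
    b: *Value,
    value: f64,
    backward_fn: *const fn (*Value) void,
) !*Value {
    const children = try allocator.alloc(*Value, 2);
    children[0] = a;
    children[1] = b;
    const out = try Value.init(allocator, value);
    out.children = children;
    out.backward = backward_fn;
    return out;
}

fn addBackward(v: *Value) void {
    v.children[0].grad += v.grad;
    v.children[1].grad += v.grad;
}

fn mulBackward(v: *Value) void {
    const a = v.children[0];
    const b = v.children[1];
    a.grad += b.value * v.grad;
    b.grad += a.value * v.grad;
}

pub fn add(allocator: std.mem.Allocator, a: *Value, b: *Value) !*Value {
    return binary(allocator, a, b, a.value + b.value, &addBackward);
}

pub fn mul(allocator: std.mem.Allocator, a: *Value, b: *Value) !*Value {
    return binary(allocator, a, b, a.value * b.value, &mulBackward);
}

// Post-order DFS: every node is appended after all of its children.
fn buildTopo(
    allocator: std.mem.Allocator,
    v: *Value,
    order: *std.ArrayListUnmanaged(*Value),
) std.mem.Allocator.Error!void {
    if (v.visited) return;
    v.visited = true;
    for (v.children) |child| try buildTopo(allocator, child, order);
    try order.append(allocator, v);
}

/// Reverse-mode pass from `root`: seed its gradient, then run each
/// node's callback from the output back toward the leaves.
pub fn backprop(allocator: std.mem.Allocator, root: *Value) !void {
    var order: std.ArrayListUnmanaged(*Value) = .{};
    defer order.deinit(allocator);
    try buildTopo(allocator, root, &order);

    root.grad = 1.0;
    var i = order.items.len;
    while (i > 0) {
        i -= 1;
        const node = order.items[i];
        // Field-held function pointer: the node is passed explicitly.
        if (node.backward) |backward_fn| backward_fn(node);
    }
}
```

The graph walk never needs to know which operation produced a node; it only calls whatever function pointer the node carries, which is what made this approach feel clean.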

This approach was quite clean compared to another approach common in toy autodiff projects: switching on an op flag. For example:

pub const Op = enum { ADD, SUB };

pub const Value = struct {
 const Self = @This();
 value: f64,
 grad: f64 = 0.0,
 children: ?[]*Value = null,
 op: ?Op = null,
 // ...
};

pub fn add(allocator: std.mem.Allocator, v1: *Value, v2: *Value) !*Value {
 const children = try allocator.alloc(*Value, 2);
 children[0] = v1;
 children[1] = v2;

 const out = try Value.init(allocator, v1.value + v2.value);
 out.children = children;
 out.op = Op.ADD; // set Op flag instead of attaching a callback
 return out;
}

pub fn backward(self: *Self) !void {
 // ...
 // Dispatch on the op tag; every new op needs a new tag and a new arm.
 switch (v.op.?) {
  .ADD => add_backward(v),
  .SUB => sub_backward(v),
  // ...
 }
 // ...
}

This pattern suffers from many issues, not least of which is the lack of extensibility. One possible argument in its favor is simplicity, since we can implement the backward logic directly in the switch (rather than call a function as shown above). If we commit to this pattern, then we can also implement (forward) operations as methods directly in the struct definition, which may make things more readable, at least for a set of core operations.
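To illustrate that argument, here is a sketch of the switch pattern taken to its conclusion: forward operations are methods on the struct and the backward logic is written inline in the switch. The lhs/rhs fields and the leaf, binary, and backwardStep functions are hypothetical names, MUL stands in for the SUB tag above, and only a single backward step is shown (no graph traversal).

```zig
const std = @import("std");

// Hypothetical switch-pattern sketch; names are illustrative only.
pub const Op = enum { ADD, MUL };

pub const Value = struct {
    const Self = @This();
    value: f64,
    grad: f64 = 0.0,
    lhs: ?*Self = null,
    rhs: ?*Self = null,
    op: ?Op = null, // null for leaf nodes

    pub fn leaf(allocator: std.mem.Allocator, value: f64) !*Self {
        const out = try allocator.create(Self);
        out.* = .{ .value = value };
        return out;
    }

    fn binary(
        self: *Self,
        allocator: std.mem.Allocator,
        other: *Self,
        op: Op,
        value: f64,
    ) !*Self {
        const out = try allocator.create(Self);
        out.* = .{ .value = value, .lhs = self, .rhs = other, .op = op };
        return out;
    }

    // Forward operations live directly on the struct as methods.
    pub fn add(self: *Self, allocator: std.mem.Allocator, other: *Self) !*Self {
        return self.binary(allocator, other, .ADD, self.value + other.value);
    }

    pub fn mul(self: *Self, allocator: std.mem.Allocator, other: *Self) !*Self {
        return self.binary(allocator, other, .MUL, self.value * other.value);
    }

    /// Pushes this node's gradient into its operands, with the backward
    /// rules written inline in the switch.
    pub fn backwardStep(self: *Self) void {
        const op = self.op orelse return; // leaf: nothing to propagate
        const a = self.lhs.?;
        const b = self.rhs.?;
        switch (op) {
            .ADD => {
                a.grad += self.grad;
                b.grad += self.grad;
            },
            .MUL => {
                a.grad += b.value * self.grad;
                b.grad += a.value * self.grad;
            },
        }
    }
};
```

Zig's exhaustive switch does give one safety net: adding a tag to Op without a matching arm is a compile error. It does nothing for extensibility, though, since every new operation still means editing both the enum and this central switch.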

I briefly explored some other options, but this first pass was intended to verify the feasibility and gauge the practicality of implementing autograd in Zig. This is about as far as I got in exploring design patterns for the scalar-valued version of autograd, and I would not revisit this aspect of the framework until I began the tensor-valued implementation.

Second Pass

The tensor-valued design is still very much a work in progress, as special care is being taken to make room for performance optimizations, and the entire design is subject to a few more refactors. For the first attempt at re-implementing autograd with tensors, I went with the switch pattern. Did I not just say that was a bad design? I did, but I kept seeing this pattern pop up in other autograd projects, so I wondered if I was missing something. Spoiler: I was not.

In general, implementing N-dimensional tensors is a straightforward task, even in lower-level languages. However, this second pass was really about performance, since scalar-valued computation graphs do not scale (every scalar operation becomes its own graph node). Hence, implementing N-dimensional tensors was really a means to an end. It followed, then, that the tensor abstraction should be designed with performance in mind from the start. Unfortunately, that makes implementation a less straightforward task, to say the least. Due to the complexity that started to accrue in the tensor logic, I decided to use the switch pattern for backpropagation in an attempt to minimize complexity, at least temporarily. As anticipated, this resulted in quite the switch block, and it only grew with time.

Autodiff was soon functional, but when it came time to train even a simple neural network, the cracks began to show. Naturally, we would need a few basic activation functions and loss functions, all of which need to be backpropagated through and thus would require modifying both the Op enum and the backward switch. I tried to avoid this by implementing the loss function as a composition of existing tensor operations, ignoring the fact that this incurs a performance cost in the backward pass.
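To make that cost concrete, take mean squared error as an example (which loss was actually used is not stated, so this is an assumption). Writing $\bar{v}$ for $\partial L/\partial v$ and seeding $\bar{L} = 1$, a fused loss op needs a single elementwise rule, while the composed version $d = \hat{y} - y$, $s = d \odot d$, $L = \operatorname{mean}(s)$ keeps the intermediate tensors $d$ and $s$ alive and walks three nodes to reach the same gradient:

```latex
$$
L = \frac{1}{n}\sum_{i=1}^{n} (\hat{y}_i - y_i)^2,
\qquad
\bar{\hat{y}}_i = \frac{2}{n}(\hat{y}_i - y_i) \quad \text{(fused)}
$$
$$
\bar{s}_i = \frac{1}{n}, \qquad
\bar{d}_i = 2\, d_i\, \bar{s}_i, \qquad
\bar{\hat{y}}_i = \bar{d}_i = \frac{2}{n}(\hat{y}_i - y_i) \quad \text{(composed)}
$$
```

The factor of 2 in $\bar{d}_i$ comes from both operands of the multiply being $d$, so their two contributions accumulate.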

Once I could train a basic linear layer, it was time to revisit the design pattern discussion, since it was clear the current approach did not satisfy the core requirements.

Going Forward

Looking back, the callback implementation was actually quite clean, but it too had some disadvantages. First, it couples the forward and backward logic: since the forward pass is responsible for attaching the backward function to the tensor, we limit our injection points for customization. Second, the functions are stateless. As much as I enjoy state management in functional programming, I was not too keen on implementing generic state management protocols with support for unforeseen complexity. A concrete example is saving context for the backward pass, which allows us to avoid recomputation when calculating some gradients. Aside from performance, I quickly realized that shape mutation was a serious consideration to design for with backpropagation, but that is a topic for another time.
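As an illustration of saving context, consider a sigmoid op that stores its output during the forward pass so the backward pass can use $\sigma'(x) = \sigma(x)(1 - \sigma(x))$ without recomputing the exponential. This is a hypothetical standalone struct, loosely modeled on the role of ctx.save_for_backward in PyTorch's autograd.Function; the point is that the op needs somewhere to keep state between the two passes, which a bare function pointer does not provide.

```zig
const std = @import("std");

/// Hypothetical stateful op: `forward` stores its output as context so
/// `backward` can reuse it instead of recomputing exp(-x).
pub const Sigmoid = struct {
    saved_output: f64 = 0.0,

    pub fn forward(self: *Sigmoid, x: f64) f64 {
        const y = 1.0 / (1.0 + @exp(-x));
        self.saved_output = y; // saved context for the backward pass
        return y;
    }

    /// d/dx sigmoid(x) = y * (1 - y), where y is the saved output.
    pub fn backward(self: *const Sigmoid, grad_out: f64) f64 {
        const y = self.saved_output;
        return grad_out * y * (1.0 - y);
    }
};
```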

There are some creative ways to implement stateful operations using the callback approach, such as a pattern analogous to closures in higher-level languages, but this seemed a bit hacky and still felt lacking. I needed more options to consider, which meant considering Zig interfaces. It is worth noting that Zig interfaces had always been on the table but were avoided because they add complexity. Wait, but Zig doesn't have interfaces, right? Not in the traditional sense, no. Zig interfaces are more a design pattern than a language feature and actually cover a range of implementation patterns.
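One common shape this pattern takes, and the one the standard library's std.mem.Allocator uses, is a type-erased pointer to the implementation plus a pointer to a table of function pointers. The sketch below applies it to a hypothetical Operation interface; Operation, VTable, and Square are illustrative names, not the framework's API. The ptr field plays the role of a closure's captured environment: each implementation can keep whatever state it needs, here the saved input for the backward pass.

```zig
const std = @import("std");

/// Hypothetical interface: type-erased implementation pointer + vtable.
pub const Operation = struct {
    ptr: *anyopaque,
    vtable: *const VTable,

    pub const VTable = struct {
        forward: *const fn (ctx: *anyopaque, x: f64) f64,
        backward: *const fn (ctx: *anyopaque, grad_out: f64) f64,
    };

    pub fn forward(self: Operation, x: f64) f64 {
        return self.vtable.forward(self.ptr, x);
    }

    pub fn backward(self: Operation, grad_out: f64) f64 {
        return self.vtable.backward(self.ptr, grad_out);
    }
};

/// A concrete op that keeps state (its saved input) between passes.
pub const Square = struct {
    saved_input: f64 = 0.0,

    fn forwardImpl(ctx: *anyopaque, x: f64) f64 {
        const self: *Square = @ptrCast(@alignCast(ctx)); // recover concrete type
        self.saved_input = x;
        return x * x;
    }

    fn backwardImpl(ctx: *anyopaque, grad_out: f64) f64 {
        const self: *Square = @ptrCast(@alignCast(ctx));
        return 2.0 * self.saved_input * grad_out; // d(x^2)/dx = 2x
    }

    /// Wraps this instance in the interface; `self` must outlive the result.
    pub fn operation(self: *Square) Operation {
        return .{
            .ptr = self,
            .vtable = &.{ .forward = forwardImpl, .backward = backwardImpl },
        };
    }
};
```

In principle, graph code written only against Operation would not need to change when a new op is added, and each op owns its own context. The trade-off is an indirect call per dispatch plus some pointer-casting boilerplate in every implementation.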