Implementing Autodiff: Static Dispatch and Compile Time Interfaces
Zig has an explicit approach to generics that is worth introducing briefly here. To implement a generic Container
type that can store a value of some type T, we write a function that takes a type and returns an anonymous struct type:
pub fn Container(comptime T: type) type {
    return struct {
        value: T,
    };
}

pub fn main() void {
    const FloatContainer = Container(f32);
    const f = FloatContainer{ .value = 1.2 };
    const b = Container(bool){ .value = true };
    // Zig rejects unused locals, so discard them explicitly.
    _ = f;
    _ = b;
}
The compiler knows which variants to generate at compile time, hence the requirement that T must be comptime-known. Perhaps not so interesting, but the Zig compiler is even more clever:
const std = @import("std");

// anytype: the function is instantiated for the concrete types at each call site.
pub fn add(a: anytype, b: anytype) f32 {
    return a.value + b.value;
}

pub fn main() void {
    const FloatContainer = Container(f32);
    const a = FloatContainer{ .value = 1.2 };
    const b = FloatContainer{ .value = 4 };
    std.debug.print("{}\n", .{add(a, b)});
}
Yes, this works! But no way this is safe, right? Actually, because add is instantiated for the concrete argument types, the Zig compiler checks at compile time that a value field exists on each parameter type. This also means that the compiler is clever enough to catch this error:
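pub fn add(a: anytype, b: anytype) f16 {
    return a.value + b.value;
}

pub fn main() void {
    // Same containers as in the previous example.
    const a = Container(f32){ .value = 1.2 };
    const b = Container(f32){ .value = 4 };
    std.debug.print("{}\n", .{add(a, b)});
}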
// src/main.zig:10:20: error: expected type 'f16', found 'f32'
//     return a.value + b.value;
//            ~~~~~~~~^~~~~~~~~
// src/main.zig:9:36: note: function return type declared here
// pub fn add(a: anytype, b: anytype) f16 {
//                                    ^~~
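The implicit check can also be spelled out. The following is only a minimal sketch: the helper name addChecked and its error message are my own inventions for illustration, not part of the code above. It uses the @hasField and @compileError builtins so that a struct type without a value field is rejected with a readable message instead of an error pointing into the function body.

```zig
const std = @import("std");

pub fn Container(comptime T: type) type {
    return struct {
        value: T,
    };
}

// Hypothetical helper (not from the original post): same idea as `add`,
// but the requirement on the parameter types is stated explicitly.
pub fn addChecked(a: anytype, b: anytype) f32 {
    const A = @TypeOf(a);
    const B = @TypeOf(b);
    // Evaluated at compile time for every instantiation of addChecked.
    comptime {
        if (!@hasField(A, "value") or !@hasField(B, "value")) {
            @compileError("addChecked expects struct types with a `value` field");
        }
    }
    return a.value + b.value;
}

pub fn main() void {
    const a = Container(f32){ .value = 1.5 };
    const b = Container(f32){ .value = 4 };
    std.debug.print("{d}\n", .{addChecked(a, b)});
}
```

Passing a struct that lacks a value field should now fail at the @compileError line. The check is still a convention rather than a named interface, which is the limitation discussed next.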
So we should be good to go then, right? We could just access fields and methods like forward and backward on the nodes in our computation graph during the backward pass and let the compiler take care of the rest. We could, but this is not really a formal contract. I don’t know about you, but I am not a compiler, and in a few months I will not remember what “interface” I implicitly defined. The Zig LSP also cannot offer much help with anytype parameters during development, among other drawbacks. Nonetheless, this does give us a sort of compile-time interface, and we can use the same idea to improve our approach a bit and define the contract explicitly at compile time:
const std = @import("std");

// Wraps any pointer type T together with the functions that implement
// the forward and backward passes for it.
pub fn Operation(
    comptime T: type,
    comptime forwardFn: *const fn (ptr: T) void,
    comptime backwardFn: *const fn (ptr: T) void,
) type {
    return struct {
        ptr: T,

        const Self = @This();

        pub fn init(p: T) Self {
            return .{ .ptr = p };
        }

        pub fn forward(self: Self) void {
            forwardFn(self.ptr);
        }

        pub fn backward(self: Self) void {
            backwardFn(self.ptr);
        }
    };
}

pub const AddOperation = struct {
    operands: [2]f32,

    const Self = @This();

    pub fn forward(self: *const Self) void {
        std.debug.print("forward(): a + b = {}\n", .{self.operands[0] + self.operands[1]});
    }

    pub fn backward(_: *const Self) void {
        std.debug.print("backward(): grad = 1\n", .{});
    }
};

// Still anytype: every instantiation of Operation is a distinct type.
pub fn runOperation(op: anytype) void {
    op.forward();
    op.backward();
}

pub fn main() void {
    var add = AddOperation{ .operands = [_]f32{ 9, 1 } };
    const addOp = Operation(*AddOperation, AddOperation.forward, AddOperation.backward).init(&add);
    runOperation(addOp);
}
That is a bit better: we now have a way to turn things into Operation types that can be treated uniformly. However, notice we are still stuck with anytype in runOperation, which highlights one of the more frustrating aspects of Zig. For all the type introspection abilities Zig provides, there is seemingly no way to name the return type of Operation in a function signature, since every instantiation is a distinct type. This pattern is still viable, though, and is even used by the Zig standard library (std.io). Still, it is not exactly what I meant when I said I wanted an interface. A possible solution to investigate next is using some more advanced casting concepts.
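To make the limitation concrete, here is a minimal sketch under a few assumptions of my own: MulOperation and sumForward are hypothetical names, the wrapper is reduced to forward only, and forward returns an f32 so the result can be inspected. It shows that two instantiations of Operation are different types, and that anytype combined with inline for is what lets one function handle both.

```zig
const std = @import("std");

// Reduced version of the Operation wrapper: forward only, returning a value
// (a simplification assumed for this sketch).
pub fn Operation(
    comptime T: type,
    comptime forwardFn: fn (ptr: T) f32,
) type {
    return struct {
        ptr: T,

        const Self = @This();

        pub fn init(p: T) Self {
            return .{ .ptr = p };
        }

        pub fn forward(self: Self) f32 {
            return forwardFn(self.ptr);
        }
    };
}

// Hypothetical operations used only for illustration.
pub const AddOperation = struct {
    operands: [2]f32,

    pub fn forward(self: *const AddOperation) f32 {
        return self.operands[0] + self.operands[1];
    }
};

pub const MulOperation = struct {
    operands: [2]f32,

    pub fn forward(self: *const MulOperation) f32 {
        return self.operands[0] * self.operands[1];
    }
};

pub const AddOp = Operation(*const AddOperation, AddOperation.forward);
pub const MulOp = Operation(*const MulOperation, MulOperation.forward);

// Each instantiation is its own type, so no single parameter type covers both.
comptime {
    std.debug.assert(AddOp != MulOp);
}

// anytype accepts a tuple of differently typed operations; inline for
// unrolls the loop so each element keeps its concrete type.
pub fn sumForward(ops: anytype) f32 {
    var total: f32 = 0;
    inline for (ops) |op| {
        total += op.forward();
    }
    return total;
}

pub fn main() void {
    const add = AddOperation{ .operands = .{ 9, 1 } };
    const mul = MulOperation{ .operands = .{ 2, 3 } };
    const total = sumForward(.{ AddOp.init(&add), MulOp.init(&mul) });
    std.debug.print("{d}\n", .{total});
}
```

This works only because the set of operations is known at compile time. A runtime list of mixed operations, such as the nodes of a computation graph, would need some form of type erasure instead.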