Fast Matrix Multiplication program in C, Zig, Rust and Python
Introduction
With the weekend finally here, I've got some spare time to write about one of the benchmark programs I mentioned in my previous article on Memory Scope and Lifetimes. This program is particularly interesting, and I'm excited to dive into its details.
What is Matrix Multiplication?
Matrix multiplication is a fundamental operation in linear algebra and a crucial component of many scientific and machine learning applications. However, as matrix sizes increase, the standard O(n^3) algorithm can become a significant bottleneck.
This program implements a high-performance matrix multiplication algorithm in Zig, leveraging the language's low-level memory management and parallelization capabilities to achieve significant speedups over traditional implementations.
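To make the baseline concrete, here is the textbook O(n^3) algorithm as a minimal Zig sketch. The naiveMatMul function below is illustrative only and is not one of the benchmarked versions: each element of C is the dot product of a row of A and a column of B, with all matrices stored row-major as in the rest of this article.
fn naiveMatMul(A: []const f32, B: []const f32, C: []f32, M: usize, N: usize, K: usize) void {
    for (0..M) |i| {
        for (0..N) |j| {
            var sum: f32 = 0;
            // C[i][j] = sum over k of A[i][k] * B[k][j]
            for (0..K) |k| {
                sum += A[i * K + k] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
}
The tiled version below computes exactly the same result, but reorders the work so that small blocks of A and B stay in cache and different output tiles can be computed on different threads.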
The Code!
When deciding which version of the code to feature in this article, I faced a choice: with implementations in four languages, including them all would make this post far too long. So I'm showcasing the one that delivers the best performance, and the winner is... the optimized Zig code!
Defining WorkItem and ThreadContext Structs
The WorkItem struct has two fields:
- i: the tile row index
- j: the tile column index
The ThreadContext struct has eight fields:
- A: the matrix A
- B: the matrix B
- C: the matrix C
- M: the number of rows in matrix A
- N: the number of columns in matrix B
- K: the number of columns in matrix A and the number of rows in matrix B
- work_queue: a queue of work items
- mutex: a mutex for synchronizing access to the work queue
const std = @import("std");
const print = std.debug.print;
const tMilli = std.time.milliTimestamp;
const Prng = std.rand.DefaultPrng;
const Allocator = std.mem.Allocator;
const PageAlloc = std.heap.page_allocator;
const ArenaAlloc = std.heap.ArenaAllocator;
// Configuration
const T: usize = 64; // Tile size (adjust as necessary)
const V: usize = 32; // Vector size for SIMD operations (adjust based on target architecture)
const WorkItem = struct {
i: usize, // Tile row index
j: usize, // Tile column index
};
const ThreadContext = struct {
A: []const f32,
B: []const f32,
C: []f32,
M: usize,
N: usize,
K: usize,
work_queue: *std.ArrayList(WorkItem),
mutex: std.Thread.Mutex,
};
tiledMultiplyKernel Function
Computes one tile of the product: it copies zero-padded T x T tiles of A and B into aligned local buffers, then accumulates the partial products into local_C using SIMD vectors of width V.
fn tiledMultiplyKernel(
A: []const f32,
B: []const f32,
local_C: *[T][T]f32,
N: usize,
K: usize,
i_start: usize,
j_start: usize,
k_start: usize,
i_end: usize,
j_end: usize,
k_end: usize,
) void {
var A_local: [T][T]f32 align(32) = undefined;
var B_local: [T][T]f32 align(32) = undefined;
// Load tiles of A and B with zero-padding
for (0..T) |i| {
@memset(&A_local[i], 0);
if (i_start + i < i_end) {
@memcpy(A_local[i][0..k_end - k_start], A[(i_start + i) * K + k_start .. (i_start + i) * K + k_end]);
}
}
for (0..T) |k| {
@memset(&B_local[k], 0);
if (k_start + k < k_end) {
@memcpy(B_local[k][0..j_end - j_start], B[(k_start + k) * N + j_start .. (k_start + k) * N + j_end]);
}
}
// Compute the multiplication using SIMD
for (0..T) |i| {
var j: usize = 0;
while (j < T) : (j += V) {
var vec_sum: @Vector(V, f32) = @splat(0);
for (0..T) |k| {
const a_val = A_local[i][k];
const a_vec = @as(@Vector(V, f32), @splat(a_val));
const b_vec = @as(*align(1) const @Vector(V, f32), @ptrCast(B_local[k][j..].ptr)).*;
vec_sum += a_vec * b_vec;
}
@as(*align(1) @Vector(V, f32), @ptrCast(local_C[i][j..].ptr)).* += vec_sum;
}
}
}
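One note on the V = 32 setting used above: @Vector(32, f32) is 1024 bits wide, which is wider than the SIMD registers on most current CPUs (256-bit AVX2, 512-bit AVX-512). Zig, via LLVM, will split such a vector across several registers, so V is a tuning knob rather than a hard hardware limit; this is why the configuration comment suggests adjusting it for the target architecture.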
workerThread Function
Each worker repeatedly pops a tile from the shared work queue (under the mutex), accumulates the full K dimension for that output tile into a thread-local buffer, and then adds the result into the shared C matrix.
fn workerThread(context: *ThreadContext) void {
var local_C: [T][T]f32 align(32) = undefined;
while (true) {
context.mutex.lock();
const work_item = context.work_queue.popOrNull() orelse {
context.mutex.unlock();
break;
};
context.mutex.unlock();
const i_start = work_item.i * T;
const j_start = work_item.j * T;
const i_end = @min(i_start + T, context.M);
const j_end = @min(j_start + T, context.N);
@memset(&local_C, [_]f32{0} ** T); // zero the local tile; results are accumulated into it across k-tiles
var k: usize = 0;
while (k < context.K) : (k += T) {
const k_end = @min(k + T, context.K);
tiledMultiplyKernel(
context.A,
context.B,
&local_C,
context.N,
context.K,
i_start,
j_start,
k,
i_end,
j_end,
k_end,
);
}
for (i_start..i_end) |i| {
for (j_start..j_end) |j| {
context.C[i * context.N + j] += local_C[i - i_start][j - j_start];
}
}
}
}
tiledMatMul Function
The public entry point: it builds a queue of output tiles, shuffles it, spawns one worker thread per CPU core, and waits for them all to finish.
pub fn tiledMatMul(
A: []const f32,
B: []const f32,
C: []f32,
M: usize,
N: usize,
K: usize,
) !void {
const num_threads = try std.Thread.getCpuCount();
const tiles_M = (M + T - 1) / T;
const tiles_N = (N + T - 1) / T;
var work_queue = try std.ArrayList(WorkItem).initCapacity(PageAlloc, tiles_M * tiles_N);
defer work_queue.deinit();
for (0..tiles_M) |i| {
for (0..tiles_N) |j| {
try work_queue.append(.{ .i = i, .j = j });
}
}
var rng = Prng.init(@intCast(tMilli()));
rng.random().shuffle(WorkItem, work_queue.items);
var thread_pool = try std.ArrayList(std.Thread).initCapacity(PageAlloc, num_threads);
defer thread_pool.deinit();
var context = ThreadContext{
.A = A,
.B = B,
.C = C,
.M = M,
.N = N,
.K = K,
.work_queue = &work_queue,
.mutex = .{},
};
for (0..num_threads) |_| {
try thread_pool.append(try std.Thread.spawn(.{}, workerThread, .{&context}));
}
for (thread_pool.items) |thread| {
thread.join();
}
}
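To put the work decomposition into numbers (my own back-of-the-envelope example, not something the program prints): for the 1024x1024x1024 case with T = 64, tiles_M = tiles_N = 1024 / 64 = 16, so the queue holds 16 * 16 = 256 WorkItems that the worker threads drain until popOrNull returns null. Shuffling the queue randomizes which tiles each thread picks up, which should help spread the work more evenly across threads.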
calculateGflops Function
Allocates and initializes the test matrices, performs one warm-up multiplication, then times iterations runs of tiledMatMul and returns the achieved GFLOPS.
pub fn calculateGflops(M: usize, N: usize, K: usize, iterations: usize) !f64 {
var arena = ArenaAlloc.init(PageAlloc);
defer arena.deinit();
const allocator = arena.allocator();
const A = try allocator.alignedAlloc(f32, 32, M * K);
defer allocator.free(A);
const B = try allocator.alignedAlloc(f32, 32, K * N);
defer allocator.free(B);
const C = try allocator.alignedAlloc(f32, 32, M * N);
defer allocator.free(C);
for (A, 0..) |*val, i| {
val.* = @as(f32, @floatFromInt(i % 10));
}
for (B, 0..) |*val, i| {
val.* = @as(f32, @floatFromInt((i + 1) % 10));
}
@memset(C, 0); // C is accumulated into with +=, so it must start zeroed
// Warm-up run before timing
try tiledMatMul(A, B, C, M, N, K);
var timer = try std.time.Timer.start();
for (0..iterations) |_| {
try tiledMatMul(A, B, C, M, N, K);
}
const elapsed_ns = timer.read();
const ops = 2 * M * N * K * iterations;
const seconds = @as(f64, @floatFromInt(elapsed_ns)) / 1e9;
const gflops = @as(f64, @floatFromInt(ops)) / seconds / 1e9;
return gflops;
}
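As a quick sanity check on the units (my own arithmetic, not output from the program): a single 1024x1024x1024 multiplication performs 2 * 1024^3 ≈ 2.15 billion floating-point operations (one multiply and one add per inner step), so finishing one run in about 13 ms would correspond to roughly 2.15e9 / 0.013 / 1e9 ≈ 165 GFLOPS, which is right in the range of the Zig results in the table below.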
main Function
Defines the benchmark matrix sizes and iteration count, then prints the measured GFLOPS for each size.
pub fn main() !void {
const sizes = [_][3]usize{
.{ 256, 256, 256 },
.{ 512, 512, 512 },
.{ 1024, 1024, 1024 },
.{ 1024, 2048, 1024 },
.{ 2048, 2048, 2048 },
};
const iterations = 10;
print("T = {},\nV = {}\n", .{ T, V });
for (sizes) |size| {
const M = size[0];
const N = size[1];
const K = size[2];
const gflops = try calculateGflops(M, N, K, iterations);
print("Matrix size: {}x{}x{}, GFLOPS: {d:.2}\n", .{ M, N, K, gflops });
}
}
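If you want to reproduce the numbers, the file builds as a standalone executable. Assuming it is saved as matmul.zig (a name I picked for illustration), something along these lines should work with a recent Zig toolchain:
zig build-exe matmul.zig -O ReleaseFast
./matmul
Building in ReleaseFast matters: Debug builds keep safety checks and skip optimizations, so the measured GFLOPS will be much lower.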
Benchmark results
MULTI-THREADED (values are GFLOPS; columns are the matrix dimensions)
Language | 256 | 512 | 1024 | 2048 | Notes |
---|---|---|---|---|---|
Python | 0.12 | 0.16 | N/A | N/A | Baseline |
C -ori | 22.69 | 53.77 | 77.69 | 77.90 | 188x; 335x speedup |
C -simd | 40.65 | 65.26 | 78.11 | 140.46 | 338x; 407x speedup |
Rust -std | 27.95 | 37.58 | 65.97 | 109.28 | 232x; 234x speedup |
Rust -Rayon | 26.13 | 36.83 | 72.28 | 109.77 | 217x; 229x speedup |
Zig -ori | 26.29 | 92.17 | 145.97 | 198.26 | 218x; 575x speedup |
Zig -opt | 51.64 | 135.39 | 166.41 | 193.86 | 429x; 845x speedup |
Note: I only ran the Python version on the 256 and 512 matrices, as it took an excessively long time to complete and my CPU was smoking hot, with every thread at 100% and everything flashing red. The two speedup figures in the Notes column are measured against that Python baseline at 256 and 512 respectively.
In contrast, the other versions, which also used multiple threads, completed their tests in a matter of seconds. Unfortunately, I ran into issues with the Codon version, but I'll update the results if I can get it working properly.