*ฅ^•ﻌ•^ฅ* ✨✨  HWisnu's blog  ✨✨ *ฅ^•ﻌ•^ฅ*

Fast Matrix Multiplication program in C, Zig, Rust and Python

Introduction

With the weekend finally here, I've got some spare time to write about one of the benchmark programs I mentioned in my previous article on Memory Scope and Lifetimes. This program is particularly interesting, and I'm excited to dive into its details.

What is Matrix Multiplication??

Matrix multiplication is a fundamental operation in linear algebra and a crucial component of many scientific and machine learning applications. However, as matrix sizes increase, the standard O(n^3) algorithm can become a significant bottleneck.
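
For reference, the standard algorithm is just three nested loops. Here is a minimal Zig sketch (illustrative only - the baselines benchmarked below are separate programs):

// Naive O(n^3) multiplication of row-major matrices:
// A is M x K, B is K x N, and the result C is M x N.
fn naiveMatMul(A: []const f32, B: []const f32, C: []f32, M: usize, N: usize, K: usize) void {
    for (0..M) |i| {
        for (0..N) |j| {
            var sum: f32 = 0;
            for (0..K) |k| {
                sum += A[i * K + k] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
}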

This program implements a high-performance matrix multiplication algorithm in Zig, leveraging the language's low-level memory management and parallelization capabilities to achieve significant speedups over traditional implementations.

The Code!

When deciding which version of the code to feature in this article, I faced a dilemma: with multiple versions to choose from, I couldn't include them all. However, I had an idea - why not showcase the one that delivers the best performance? And the winner is... the optimized Zig code!

Defining WorkItem and ThreadContext Structs

The WorkItem struct has two fields: the row and column indices of one output tile.

The ThreadContext struct has eight fields: the input slices A and B, the output slice C, the matrix dimensions M, N and K, a pointer to the shared work queue, and the mutex that guards it.

const std = @import("std");
const print = std.debug.print;
const tMilli = std.time.milliTimestamp;
const Prng = std.rand.DefaultPrng;
const Allocator = std.mem.Allocator;
const PageAlloc = std.heap.page_allocator;
const ArenaAlloc = std.heap.ArenaAllocator;

// Configuration
const T: usize = 64; // Tile size (adjust as necessary)
const V: usize = 32; // Vector size for SIMD operations (adjust based on target architecture)

const WorkItem = struct {
    i: usize, // Tile row index
    j: usize, // Tile column index
};

const ThreadContext = struct {
    A: []const f32,
    B: []const f32,
    C: []f32,
    M: usize,
    N: usize,
    K: usize,
    work_queue: *std.ArrayList(WorkItem),
    mutex: std.Thread.Mutex,
};

tiledMultiplyKernel Function

Computes one T×T tile of the product. It first copies the relevant tiles of A and B into aligned, zero-padded local buffers, then accumulates their product into local_C using @Vector SIMD math of width V. The zero-padding means the SIMD loops can always run over full T×T tiles, even at the ragged edges of the matrices.

fn tiledMultiplyKernel(
    A: []const f32,
    B: []const f32,
    local_C: *[T][T]f32,
    N: usize,
    K: usize,
    i_start: usize,
    j_start: usize,
    k_start: usize,
    i_end: usize,
    j_end: usize,
    k_end: usize,
) void {
    var A_local: [T][T]f32 align(32) = undefined;
    var B_local: [T][T]f32 align(32) = undefined;

    // Load tiles of A and B with zero-padding
    for (0..T) |i| {
        @memset(&A_local[i], 0);
        if (i_start + i < i_end) {
            @memcpy(A_local[i][0..k_end - k_start], A[(i_start + i) * K + k_start .. (i_start + i) * K + k_end]);
        }
    }

    for (0..T) |k| {
        @memset(&B_local[k], 0);
        if (k_start + k < k_end) {
            @memcpy(B_local[k][0..j_end - j_start], B[(k_start + k) * N + j_start .. (k_start + k) * N + j_end]);
        }
    }

    // Compute the multiplication using SIMD
    for (0..T) |i| {
        var j: usize = 0;
        while (j < T) : (j += V) {
            var vec_sum: @Vector(V, f32) = @splat(0);
            for (0..T) |k| {
                const a_val = A_local[i][k];
                const a_vec = @as(@Vector(V, f32), @splat(a_val));
                const b_vec = @as(*align(1) const @Vector(V, f32), @ptrCast(B_local[k][j..].ptr)).*;
                vec_sum += a_vec * b_vec;
            }
            @as(*align(1) @Vector(V, f32), @ptrCast(local_C[i][j..].ptr)).* += vec_sum;
        }
    }
}

workerThread Function

The worker loop: each thread repeatedly pops a tile from the shared work queue (under the mutex), accumulates the full K dimension for that tile into a thread-local buffer, then writes the finished tile into C. No lock is needed for the write-back, because each (i, j) tile is owned by exactly one worker.

fn workerThread(context: *ThreadContext) void {
    var local_C: [T][T]f32 align(32) = undefined;

    while (true) {
        context.mutex.lock();
        const work_item = context.work_queue.popOrNull() orelse {
            context.mutex.unlock();
            break;
        };
        context.mutex.unlock();

        const i_start = work_item.i * T;
        const j_start = work_item.j * T;
        const i_end = @min(i_start + T, context.M);
        const j_end = @min(j_start + T, context.N);

        // Zero the thread-local accumulator tile before summing over K
        // (the kernel accumulates into it, so it must start at zero)
        for (&local_C) |*row| {
            @memset(row, 0);
        }

        var k: usize = 0;
        while (k < context.K) : (k += T) {
            const k_end = @min(k + T, context.K);
            tiledMultiplyKernel(
                context.A,
                context.B,
                &local_C,
                context.N,
                context.K,
                i_start,
                j_start,
                k,
                i_end,
                j_end,
                k_end,
            );
        }

        for (i_start..i_end) |i| {
            for (j_start..j_end) |j| {
                context.C[i * context.N + j] += local_C[i - i_start][j - j_start];
            }
        }
    }
}

tiledMatMul Function

The public entry point. It splits the output matrix into tiles_M × tiles_N work items, shuffles them, and spawns one worker thread per CPU core to drain the queue.

pub fn tiledMatMul(
    A: []const f32,
    B: []const f32,
    C: []f32,
    M: usize,
    N: usize,
    K: usize,
) !void {
    const num_threads = try std.Thread.getCpuCount();
    const tiles_M = (M + T - 1) / T;
    const tiles_N = (N + T - 1) / T;

    var work_queue = try std.ArrayList(WorkItem).initCapacity(PageAlloc, tiles_M * tiles_N);
    defer work_queue.deinit();

    for (0..tiles_M) |i| {
        for (0..tiles_N) |j| {
            try work_queue.append(.{ .i = i, .j = j });
        }
    }

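    // Randomize the processing order of the tiles (a simple load-balancing measure)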
    var rng = Prng.init(@intCast(tMilli()));
    rng.random().shuffle(WorkItem, work_queue.items);

    var thread_pool = try std.ArrayList(std.Thread).initCapacity(PageAlloc, num_threads);
    defer thread_pool.deinit();

    var context = ThreadContext{
        .A = A,
        .B = B,
        .C = C,
        .M = M,
        .N = N,
        .K = K,
        .work_queue = &work_queue,
        .mutex = .{},
    };

    for (0..num_threads) |_| {
        try thread_pool.append(try std.Thread.spawn(.{}, workerThread, .{&context}));
    }

    for (thread_pool.items) |thread| {
        thread.join();
    }
}
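
Before benchmarking, it's worth a quick sanity check. Here is a small test I'd add (hypothetical - it's not part of the original program) that compares tiledMatMul against a plain triple loop, using sizes that aren't multiples of T so the zero-padded edge tiles get exercised:

test "tiledMatMul matches a naive triple loop" {
    const M = 70; // deliberately not multiples of T = 64,
    const N = 65; // so the edge tiles are covered
    const K = 70;
    const gpa = std.testing.allocator;

    const A = try gpa.alloc(f32, M * K);
    defer gpa.free(A);
    const B = try gpa.alloc(f32, K * N);
    defer gpa.free(B);
    const C = try gpa.alloc(f32, M * N);
    defer gpa.free(C);

    for (A, 0..) |*v, i| v.* = @as(f32, @floatFromInt(i % 7));
    for (B, 0..) |*v, i| v.* = @as(f32, @floatFromInt((i + 3) % 5));
    @memset(C, 0); // tiledMatMul accumulates into C, so start from zero

    try tiledMatMul(A, B, C, M, N, K);

    // Compare every element against the reference result
    for (0..M) |i| {
        for (0..N) |j| {
            var expected: f32 = 0;
            for (0..K) |k| expected += A[i * K + k] * B[k * N + j];
            try std.testing.expectApproxEqAbs(expected, C[i * N + j], 0.01);
        }
    }
}

Run it with zig test on the source file.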

calculateGflops Function

Benchmarks tiledMatMul: it allocates and fills the matrices, does a warm-up run, then times the requested number of iterations and converts the operation count into GFLOPS.

pub fn calculateGflops(M: usize, N: usize, K: usize, iterations: usize) !f64 {
    var arena = ArenaAlloc.init(PageAlloc);
    defer arena.deinit();
    const allocator = arena.allocator();

    const A = try allocator.alignedAlloc(f32, 32, M * K);
    defer allocator.free(A);
    const B = try allocator.alignedAlloc(f32, 32, K * N);
    defer allocator.free(B);
    const C = try allocator.alignedAlloc(f32, 32, M * N);
    defer allocator.free(C);

    for (A, 0..) |*val, i| {
        val.* = @as(f32, @floatFromInt(i % 10));
    }
    for (B, 0..) |*val, i| {
        val.* = @as(f32, @floatFromInt((i + 1) % 10));
    }

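    // Warm-up run (not timed): first-touch page faults and thread startup
    // happen here instead of inside the measured loop. Note that tiledMatMul
    // accumulates into C, so C keeps growing across iterations; only
    // throughput is measured here, the numeric results are not validated.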
    try tiledMatMul(A, B, C, M, N, K);

    var timer = try std.time.Timer.start();
    for (0..iterations) |_| {
        try tiledMatMul(A, B, C, M, N, K);
    }
    const elapsed_ns = timer.read();

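    // Each of the M*N output elements takes K multiply-add pairs,
    // i.e. 2*M*N*K floating-point operations per multiplication.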
    const ops = 2 * M * N * K * iterations;
    const seconds = @as(f64, @floatFromInt(elapsed_ns)) / 1e9;
    const gflops = @as(f64, @floatFromInt(ops)) / seconds / 1e9;

    return gflops;
}

main Function

Defines the benchmark configuration (the matrix sizes and the iteration count), then runs calculateGflops for each size and prints the result.

pub fn main() !void {
    const sizes = [_][3]usize{
        .{ 256, 256, 256 },
        .{ 512, 512, 512 },
        .{ 1024, 1024, 1024 },
        .{ 1024, 2048, 1024 },
        .{ 2048, 2048, 2048 },
    };

    const iterations = 10;

    print("T = {},\nV = {}\n", .{ T, V });

    for (sizes) |size| {
        const M = size[0];
        const N = size[1];
        const K = size[2];

        const gflops = try calculateGflops(M, N, K, iterations);

        print("Matrix size: {}x{}x{}, GFLOPS: {d:.2}\n", .{ M, N, K, gflops });
    }
}
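
To reproduce the Zig numbers, build in release mode. Assuming the source lives in matmul.zig (file name assumed), something like:

zig build-exe matmul.zig -O ReleaseFast
./matmul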

Benchmark results

MULTI-THREADED

All figures are GFLOPS (higher is better) at matrix sizes 256 through 2048; the last column is the speedup over the Python baseline at sizes 256 and 512.

Language      256     512     1024    2048    Speedup vs Python (256; 512)
Python        0.12    0.16    N/A     N/A     baseline
-----
C -ori        22.69   53.77   77.69   77.90   188x; 335x
C -simd       40.65   65.26   78.11   140.46  338x; 407x
-----
Rust -std     27.95   37.58   65.97   109.28  232x; 234x
Rust -Rayon   26.13   36.83   72.28   109.77  217x; 229x
-----
Zig -ori      26.29   92.17   145.97  198.26  218x; 575x
Zig -opt      51.64   135.39  166.41  193.86  429x; 845x

Note: I only ran the Python version on the 256 and 512 matrices, as it took excessively long to complete and my CPU was smoking hot, with every thread pinned at 100% and flashing red everywhere.

In contrast, the other versions, which also utilized multiple threads, completed their tests in a matter of seconds. Unfortunately, I encountered issues with running the Codon version, but I'll be sure to update the results if I can get it working properly.

#c #high level #low level #matrix multiplication #programming #python #rust #zig