summary refs log tree commit diff
path: root/src/matrix.zig
blob: eaf968521a4d929156bf102857643b8b51960868 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
const std = @import("std");
const ArrayList = std.ArrayList;

/// A dense `h` x `w` matrix of `f32` values stored row by row in `x`.
///
/// The matrix owns the list in `x`; release it with `deinit` when done.
const matrix = struct {
    /// Number of rows.
    h: usize,
    /// Number of columns.
    w: usize,
    /// The `h * w` entries in row-major order: entry (r, c) is `x.items[r * w + c]`.
    x: ArrayList(f32),

    /// Free the storage backing `x`. The matrix must not be used afterwards.
    pub fn deinit(self: matrix) void {
        self.x.deinit();
    }
};

/// Build an `h` x `w` matrix by copying the first `h * w` values of `x`
/// (row-major order).
///
/// The storage is allocated from `std.heap.page_allocator` and owned by the
/// returned matrix. Indexing past the end of `x` (fewer than `h * w` values)
/// is a safety-checked panic. An allocation failure also panics, so the
/// result never has fewer entries than `h * w`.
pub fn new(h: usize, w: usize, x: []const f32) matrix {
    var m = matrix{
        .h = h,
        .w = w,
        .x = ArrayList(f32).init(std.heap.page_allocator),
    };

    // Copy everything in one call. A silently ignored failure would leave
    // `h`/`w` disagreeing with `x.items.len`, so treat OOM as fatal instead.
    m.x.appendSlice(x[0 .. h * w]) catch @panic("matrix.new: out of memory");

    return m;
}

/// Return the sum of `a[k] * b[k * step]` for `k` in `0..len`.
///
/// `a` is read contiguously and `b` is read every `step` elements. With `step`
/// set to a row-major matrix's width, `b` walks down one column. Returns 0
/// when `len` is 0. Out-of-range indices (`a.len < len`, or no element of `b`
/// at `(len - 1) * step`) are safety-checked panics. Neither slice is modified.
pub fn dot(a: []const f32, b: []const f32, len: usize, step: usize) f32 {
    var sum: f32 = 0;
    for (a[0..len], 0..) |a_val, k| {
        sum += a_val * b[k * step];
    }
    return sum;
}

/// Return the matrix product `a * b` as a new `a.h` x `b.w` matrix.
///
/// Both operands are read in row-major order: entry (r, c) of `a` is
/// `a.x.items[r * a.w + c]`. The operands are not modified. The result's
/// storage is allocated from `std.heap.page_allocator` and owned by the
/// caller. Panics if the inner dimensions disagree (`a.w != b.h`) or if the
/// allocation fails.
pub fn multiply(a: matrix, b: matrix) matrix {
    if (a.w != b.h) @panic("matrix.multiply: a.w must equal b.h");

    const count = a.h * b.w;
    var out = ArrayList(f32).init(std.heap.page_allocator);
    // Reserve the whole result up front so none of the appends below can fail.
    // The list is handed straight to the result, so no temporary buffer is
    // left unfreed.
    out.ensureTotalCapacityPrecise(count) catch @panic("matrix.multiply: out of memory");

    if (a.w == 0) {
        // Empty inner dimension: every entry is an empty sum. Handled apart
        // because slicing `b.x.items[col..]` would be out of range when `b` is empty.
        out.appendNTimesAssumeCapacity(0, count);
    } else {
        for (0..a.h) |row| {
            const a_row = a.x.items[row * a.w ..];
            for (0..b.w) |col| {
                // Column `col` of `b` starts at index `col` and advances by `b.w`.
                out.appendAssumeCapacity(dot(a_row, b.x.items[col..], a.w, b.w));
            }
        }
    }

    return matrix{ .h = a.h, .w = b.w, .x = out };
}