1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
|
const std = @import("std");
const ArrayList = std.ArrayList;
/// A dense `h`×`w` matrix of `f32` values stored row-major in `x`.
const matrix = struct {
    /// Number of rows.
    h: usize,
    /// Number of columns (also the stride between consecutive rows in `x`).
    w: usize,
    /// Row-major element storage; `x.items` is expected to hold `h * w` values.
    /// The list owns heap memory, so release it with `deinit` when done.
    x: ArrayList(f32),

    /// Frees the element storage. `self` must not be used afterwards.
    pub fn deinit(self: *matrix) void {
        self.x.deinit();
        self.* = undefined;
    }
};
/// Builds an `h`×`w` matrix by copying the first `h * w` values of `x`, in the
/// row-major order that `multiply` indexes, into a newly allocated list.
///
/// The list is backed by `std.heap.page_allocator`. The caller owns it and frees
/// it with `m.x.deinit()`. Extra values in `x` beyond `h * w` are ignored.
///
/// Panics if `x` holds fewer than `h * w` values, or if allocation fails. The
/// signature cannot return an error, and a partially filled matrix would break
/// the `h * w` invariant that later indexing relies on.
pub fn new(h: usize, w: usize, x: []const f32) matrix {
    const n = h * w;
    if (x.len < n) {
        std.debug.panic("new: need {d} values for a {d}x{d} matrix, got {d}", .{ n, h, w, x.len });
    }
    var list = ArrayList(f32).init(std.heap.page_allocator);
    // A single bulk copy instead of per-element appends; failure is fatal (see above).
    list.appendSlice(x[0..n]) catch @panic("new: out of memory");
    return .{ .h = h, .w = w, .x = list };
}
/// Returns `sum(a[k] * b[k * step])` for `k` in `0..len`. This is the dot
/// product of the first `len` elements of `a` with `len` elements of `b`
/// spaced `step` apart.
///
/// With `step == 1` this is an ordinary vector dot product. With `step` equal
/// to a row-major matrix's width, it walks down one column of that matrix,
/// which is how `multiply` uses it. Returns 0 when `len == 0`.
///
/// Both slices are only read, so const data is accepted. Reading past the end
/// of `a` or `b` panics in safe builds.
pub fn dot(a: []const f32, b: []const f32, len: usize, step: usize) f32 {
    var sum: f32 = 0;
    var b_idx: usize = 0;
    // Same left-to-right accumulation order as a plain index loop, so results are bit-identical.
    for (a[0..len]) |a_val| {
        sum += a_val * b[b_idx];
        b_idx += step;
    }
    return sum;
}
/// Returns the matrix product `a * b` as a new `a.h`×`b.w` matrix.
///
/// Both operands are read as row-major. Each output element (row, col) is the
/// dot product of row `row` of `a` with column `col` of `b`. The result's
/// storage is backed by `std.heap.page_allocator`; the caller owns it and frees
/// it with `result.x.deinit()`.
///
/// Panics if the inner dimensions differ (`a.w != b.h`) or if allocation fails.
/// The signature cannot report either as an error, and continuing would
/// produce a wrong or truncated product.
pub fn multiply(a: matrix, b: matrix) matrix {
    if (a.w != b.h) {
        std.debug.panic("multiply: inner dimensions differ ({d}x{d} times {d}x{d})", .{ a.h, a.w, b.h, b.w });
    }

    // Build the result in place: no temporary list to copy (and leak), and the
    // whole buffer is reserved up front so no append below can fail part-way.
    var out = ArrayList(f32).init(std.heap.page_allocator);
    out.ensureTotalCapacityPrecise(a.h * b.w) catch @panic("multiply: out of memory");

    var row: usize = 0;
    while (row < a.h) : (row += 1) {
        const a_row = a.x.items[row * a.w ..];
        var col: usize = 0;
        while (col < b.w) : (col += 1) {
            // With a zero inner dimension nothing is read from b. b.x.items may
            // then be empty, and slicing it at `col` would be out of bounds.
            const b_col = if (a.w == 0) b.x.items[0..0] else b.x.items[col..];
            out.appendAssumeCapacity(dot(a_row, b_col, a.w, b.w));
        }
    }
    return .{ .h = a.h, .w = b.w, .x = out };
}
|