Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from pathlib import Path
from typing import Union
import numpy as np
import torch
from torch.nn import Conv2d, Linear, MaxPool2d, Module, ReLU, Sequential, Softmax
class MiniERA(Module):
    """Small CNN classifier: four 3x3 conv layers, two FC layers, 5-way softmax.

    The first fully-connected layer takes 1600 = 64 * 5 * 5 features, which is
    what the conv stack yields for a 3x32x32 input
    (32 -> 30 -> 28 -> pool 14 -> 12 -> 10 -> pool 5). Other spatial sizes will
    fail at the ``Linear(1600, 256)`` layer.
    """

    def __init__(self):
        super().__init__()
        self.convs = Sequential(
            Conv2d(3, 32, 3),
            ReLU(),
            Conv2d(32, 32, 3),
            ReLU(),
            MaxPool2d(2, 2),
            Conv2d(32, 64, 3),
            ReLU(),
            Conv2d(64, 64, 3),
            ReLU(),
            MaxPool2d(2, 2),
        )
        self.fcs = Sequential(Linear(1600, 256), ReLU(), Linear(256, 5))
        self.softmax = Softmax(1)

    def forward(self, input):
        """Return per-class probabilities of shape (batch, 5).

        ``input`` is an NCHW batch; see the class docstring for the expected
        spatial size.
        """
        outputs = self.convs(input)
        # Flatten everything but the batch dimension before the FC stack.
        outputs = self.fcs(outputs.flatten(1, -1))
        return self.softmax(outputs)

    @staticmethod
    def _read_float32_bin(path: Path) -> torch.Tensor:
        """Read a raw (headerless) float32 dump as a flat 1-D tensor."""
        return torch.from_numpy(np.fromfile(path, dtype=np.float32))

    def load_legacy_hpvm_weights(self, prefix: Union[Path, str]) -> "MiniERA":
        """Load weights from legacy HPVM raw ``.bin`` files into this model.

        Expected files (1-based index, in layer order within ``convs`` / ``fcs``):
        ``conv2d_{i}_w.bin``, ``conv2d_{i}_b.bin``, ``dense_{i}_w.bin``,
        ``dense_{i}_b.bin``, each a flat float32 array.

        Values are copied *into* the existing parameters, so a model that was
        already moved to another device or dtype (``.cuda()``, ``.double()``)
        keeps that device/dtype. Rebinding ``param.data`` would silently reset
        it to CPU float32.

        Args:
            prefix: Directory containing the ``.bin`` files.

        Returns:
            ``self``, to allow chaining.

        Raises:
            FileNotFoundError: If an expected file is missing.
            RuntimeError: If a file's element count does not match the layer.
        """
        prefix = Path(prefix)
        convs = [layer for layer in self.convs if isinstance(layer, Conv2d)]
        linears = [layer for layer in self.fcs if isinstance(layer, Linear)]
        with torch.no_grad():
            for idx, conv in enumerate(convs, start=1):
                weight = self._read_float32_bin(prefix / f"conv2d_{idx}_w.bin")
                bias = self._read_float32_bin(prefix / f"conv2d_{idx}_b.bin")
                conv.weight.copy_(weight.reshape(conv.weight.shape))
                conv.bias.copy_(bias.reshape(conv.bias.shape))
            for idx, linear in enumerate(linears, start=1):
                weight = self._read_float32_bin(prefix / f"dense_{idx}_w.bin")
                bias = self._read_float32_bin(prefix / f"dense_{idx}_b.bin")
                cout, cin = linear.weight.shape
                # The files store dense weights in (in, out) order; PyTorch's
                # Linear expects (out, in), hence the transpose.
                linear.weight.copy_(weight.reshape(cin, cout).T)
                linear.bias.copy_(bias.reshape(linear.bias.shape))
        return self