# miniera.py
from pathlib import Path
from typing import Union
import numpy as np
import torch
from torch.nn import Conv2d, Linear, MaxPool2d, Module, ReLU, Sequential, Softmax


class MiniERA(Module):
    def __init__(self):
        super().__init__()
        self.convs = Sequential(
            Conv2d(3, 32, 3),
            ReLU(),
            Conv2d(32, 32, 3),
            ReLU(),
            MaxPool2d(2, 2),
            Conv2d(32, 64, 3),
            ReLU(),
            Conv2d(64, 64, 3),
            ReLU(),
            MaxPool2d(2, 2),
        )
        # 1600 = 64 channels * 5 * 5 spatial, i.e. the conv stack applied to a 32x32 RGB input.
        self.fcs = Sequential(Linear(1600, 256), ReLU(), Linear(256, 5))
        self.softmax = Softmax(1)

    def forward(self, input):
        outputs = self.convs(input)
        outputs = self.fcs(outputs.flatten(1, -1))
        return self.softmax(outputs)
    def load_legacy_hpvm_weights(self, prefix: Union[Path, str]):
        """Load weights stored in the legacy HPVM flat-binary format under `prefix`."""
        prefix = Path(prefix)
        # Load the convolution weights (conv2d_{i}_w.bin / conv2d_{i}_b.bin).
        count = 0
        for conv in self.convs:
            if not isinstance(conv, Conv2d):
                continue
            weight_np = np.fromfile(prefix / f"conv2d_{count+1}_w.bin", dtype=np.float32)
            bias_np = np.fromfile(prefix / f"conv2d_{count+1}_b.bin", dtype=np.float32)
            conv.weight.data = torch.tensor(weight_np).reshape(conv.weight.shape)
            conv.bias.data = torch.tensor(bias_np).reshape(conv.bias.shape)
            count += 1
        # Load the fully connected weights (dense_{i}_w.bin / dense_{i}_b.bin).
        # The legacy format stores dense weights as (in_features, out_features),
        # so transpose into PyTorch's (out_features, in_features) layout.
        count = 0
        for linear in self.fcs:
            if not isinstance(linear, Linear):
                continue
            weight_np = np.fromfile(prefix / f"dense_{count+1}_w.bin", dtype=np.float32)
            bias_np = np.fromfile(prefix / f"dense_{count+1}_b.bin", dtype=np.float32)
            cout, cin = linear.weight.shape
            linear.weight.data = torch.tensor(weight_np).reshape(cin, cout).T.contiguous()
            linear.bias.data = torch.tensor(bias_np).reshape(linear.bias.shape)
            count += 1
        return self
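

# Example usage (a minimal sketch): constructs the model, loads legacy HPVM
# weights, and runs a forward pass. The weight directory "path/to/hpvm_weights"
# and the 32x32 RGB input size are assumptions, chosen to match the 1600-wide
# flatten in the model above.
if __name__ == "__main__":
    model = MiniERA().load_legacy_hpvm_weights("path/to/hpvm_weights")
    model.eval()
    with torch.no_grad():
        batch = torch.randn(1, 3, 32, 32)  # one random 32x32 RGB image
        probs = model(batch)
    print(probs.shape)  # torch.Size([1, 5]): softmax over the 5 output classes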