Skip to content
Snippets Groups Projects
miniera.py 1.96 KiB
from pathlib import Path
from typing import Union

import numpy as np
import torch
from torch.nn import Conv2d, Linear, MaxPool2d, Module, ReLU, Sequential, Softmax


class MiniERA(Module):
    """Small CNN classifier: four 3x3 convolutions, two FC layers, softmax over 5 classes.

    The first fully-connected layer expects 1600 = 64 * 5 * 5 flattened
    features. The unpadded conv/pool stack produces that from, e.g., a
    3x32x32 input image.
    """

    def __init__(self):
        super().__init__()
        self.convs = Sequential(
            Conv2d(3, 32, 3),
            ReLU(),
            Conv2d(32, 32, 3),
            ReLU(),
            MaxPool2d(2, 2),
            Conv2d(32, 64, 3),
            ReLU(),
            Conv2d(64, 64, 3),
            ReLU(),
            MaxPool2d(2, 2),
        )
        self.fcs = Sequential(Linear(1600, 256), ReLU(), Linear(256, 5))
        self.softmax = Softmax(1)

    def forward(self, input):
        """Run the network and return per-class probabilities (softmax over dim 1).

        :param input: batch of images, shaped (N, 3, H, W) with H, W such that
            the conv stack yields 64x5x5 feature maps (e.g. 32x32).
        :returns: tensor of shape (N, 5) whose rows sum to 1.
        """
        outputs = self.convs(input)
        # Flatten everything except the batch dimension before the FC layers.
        outputs = self.fcs(outputs.flatten(1, -1))
        return self.softmax(outputs)

    @staticmethod
    def _read_bin(path: Path) -> torch.Tensor:
        """Read a raw, native-endian float32 file into a flat CPU tensor."""
        return torch.from_numpy(np.fromfile(path, dtype=np.float32))

    def load_legacy_hpvm_weights(self, prefix: Union[Path, str]) -> "MiniERA":
        """Load weights from the legacy HPVM raw-binary dump in directory ``prefix``.

        Expected files (1-based index, in layer order):

        * ``conv2d_{i}_w.bin`` / ``conv2d_{i}_b.bin`` for each ``Conv2d``.
          Weights are stored row-major in PyTorch's own weight shape.
        * ``dense_{i}_w.bin`` / ``dense_{i}_b.bin`` for each ``Linear``.
          Weights are stored row-major as (in_features, out_features) and are
          transposed into PyTorch's (out_features, in_features) layout.

        Values are copied *into* the existing parameters. The model's current
        dtype, device and memory layout are therefore preserved, e.g. after
        ``.double()`` or ``.cuda()``.

        :param prefix: directory containing the ``.bin`` files.
        :returns: ``self``, to allow chaining.
        :raises FileNotFoundError: if an expected file is missing.
        :raises RuntimeError: if a file's element count does not match the
            layer's parameter size (from ``reshape``).
        """
        prefix = Path(prefix)
        # In-place copy_ into leaf parameters that require grad must not be
        # recorded by autograd.
        with torch.no_grad():
            conv_layers = (m for m in self.convs if isinstance(m, Conv2d))
            for idx, conv in enumerate(conv_layers, start=1):
                weight = self._read_bin(prefix / f"conv2d_{idx}_w.bin")
                bias = self._read_bin(prefix / f"conv2d_{idx}_b.bin")
                conv.weight.copy_(weight.reshape(conv.weight.shape))
                conv.bias.copy_(bias.reshape(conv.bias.shape))

            fc_layers = (m for m in self.fcs if isinstance(m, Linear))
            for idx, linear in enumerate(fc_layers, start=1):
                weight = self._read_bin(prefix / f"dense_{idx}_w.bin")
                bias = self._read_bin(prefix / f"dense_{idx}_b.bin")
                cout, cin = linear.weight.shape
                # The file is laid out (in, out); transpose to PyTorch's
                # (out, in). copy_ writes into the existing contiguous storage.
                linear.weight.copy_(weight.reshape(cin, cout).T)
                linear.bias.copy_(bias.reshape(linear.bias.shape))
        return self