Skip to content
Snippets Groups Projects
model.py 1.99 KiB
Newer Older
from pathlib import Path
from typing import Union

import numpy as np
import torch
from torch.nn import Conv2d, Linear, MaxPool2d, Module, ReLU, Sequential, Softmax


class MiniERA(Module):
    """Small CNN classifier: four 3x3 convolutions, two dense layers, 5 classes.

    The dense head is sized for 3x32x32 inputs. Each unpadded 3x3 conv trims
    2 pixels per side length and each 2x2 max-pool halves it, so
    32 -> 30 -> 28 -> 14 -> 12 -> 10 -> 5, and 64 channels * 5 * 5 = 1600
    features reach the first ``Linear``.
    """

    def __init__(self):
        super().__init__()
        self.convs = Sequential(
            Conv2d(3, 32, 3),
            ReLU(),
            Conv2d(32, 32, 3),
            ReLU(),
            MaxPool2d(2, 2),
            Conv2d(32, 64, 3),
            ReLU(),
            Conv2d(64, 64, 3),
            ReLU(),
            MaxPool2d(2, 2),
        )
        # 1600 = 64 * 5 * 5; only valid for 32x32 spatial input (see class docstring).
        self.fcs = Sequential(Linear(1600, 256), ReLU(), Linear(256, 5))
        self.softmax = Softmax(1)

    def forward(self, input):
        """Return class probabilities of shape (N, 5) for an (N, 3, 32, 32) batch.

        NOTE(review): softmax is applied inside the model, so the output is
        already normalized; pairing it with a loss that applies its own
        (log-)softmax, such as ``CrossEntropyLoss``, would normalize twice.
        """
        outputs = self.convs(input)
        # Keep the batch dim, collapse channels/height/width into one feature axis.
        outputs = self.fcs(outputs.flatten(1, -1))
        return self.softmax(outputs)

    @staticmethod
    def _read_f32(path: Path, shape: tuple) -> torch.Tensor:
        """Read raw float32 values from ``path`` and reshape them to ``shape``.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            RuntimeError: if the file does not hold exactly ``prod(shape)`` values.
        """
        data = np.fromfile(path, dtype=np.float32)
        expected = torch.Size(shape).numel()
        if data.size != expected:
            raise RuntimeError(
                f"{path}: expected {expected} float32 values for shape "
                f"{tuple(shape)}, found {data.size}"
            )
        return torch.from_numpy(data).reshape(shape)

    def load_legacy_hpvm_weights(self, prefix: Union[Path, str]):
        """Load weights stored in the legacy HPVM raw-binary layout; return ``self``.

        Files are read from the directory ``prefix``:

        * ``conv2d_{i}_w.bin`` / ``conv2d_{i}_b.bin`` for the i-th ``Conv2d``
          in ``self.convs`` (1-based). Weights are reshaped directly to the
          parameter shape (out_channels, in_channels, kh, kw).
        * ``dense_{i}_w.bin`` / ``dense_{i}_b.bin`` for the i-th ``Linear``
          in ``self.fcs`` (1-based). Weights are stored as (in, out) and are
          transposed into PyTorch's (out, in) layout.

        Every file is read in full as native-endian float32 values.

        Values are copied into the existing parameters in place. The
        parameters therefore keep their device and dtype, and references to
        them (e.g. held by an optimizer) stay valid.

        Raises:
            FileNotFoundError: if an expected file is missing.
            RuntimeError: if a file's element count does not match its parameter.
        """
        prefix = Path(prefix)
        conv_layers = [m for m in self.convs if isinstance(m, Conv2d)]
        dense_layers = [m for m in self.fcs if isinstance(m, Linear)]
        # no_grad: autograd rejects in-place writes to leaf parameters that
        # require grad; copy_ also casts/moves to the parameter's dtype/device.
        with torch.no_grad():
            for idx, conv in enumerate(conv_layers, start=1):
                conv.weight.copy_(
                    self._read_f32(prefix / f"conv2d_{idx}_w.bin", conv.weight.shape)
                )
                conv.bias.copy_(
                    self._read_f32(prefix / f"conv2d_{idx}_b.bin", conv.bias.shape)
                )
            for idx, linear in enumerate(dense_layers, start=1):
                cout, cin = linear.weight.shape
                # File holds the matrix as (in, out); Linear stores (out, in).
                weight = self._read_f32(prefix / f"dense_{idx}_w.bin", (cin, cout))
                linear.weight.copy_(weight.T)
                linear.bias.copy_(
                    self._read_f32(prefix / f"dense_{idx}_b.bin", linear.bias.shape)
                )
        return self