miniera.py
from pathlib import Path
from typing import Union

import numpy as np
import torch
from torch.nn import Conv2d, Linear, MaxPool2d, Module, ReLU, Sequential, Softmax


class MiniERA(Module):
    """A small CNN: four 3x3 convolutions with two 2x2 max-pools, then two fully-connected layers."""

    def __init__(self):
        super().__init__()
        self.convs = Sequential(
            Conv2d(3, 32, 3),
            ReLU(),
            Conv2d(32, 32, 3),
            ReLU(),
            MaxPool2d(2, 2),
            Conv2d(32, 64, 3),
            ReLU(),
            Conv2d(64, 64, 3),
            ReLU(),
            MaxPool2d(2, 2),
        )
        # For a 3x32x32 input, the conv stack leaves a 64x5x5 feature map (1600 values).
        self.fcs = Sequential(Linear(1600, 256), ReLU(), Linear(256, 5))
        self.softmax = Softmax(1)

    def forward(self, input):
        outputs = self.convs(input)
        outputs = self.fcs(outputs.flatten(1, -1))
        return self.softmax(outputs)

    def load_legacy_hpvm_weights(self, prefix: Union[Path, str]):
        prefix = Path(prefix)
        # Load convolution weights: one flat float32 file per layer,
        # conv2d_{i}_w.bin / conv2d_{i}_b.bin, numbered from 1.
        count = 0
        for conv in self.convs:
            if not isinstance(conv, Conv2d):
                continue
    
            weight_np = np.fromfile(prefix / f"conv2d_{count+1}_w.bin", dtype=np.float32)
            bias_np = np.fromfile(prefix / f"conv2d_{count+1}_b.bin", dtype=np.float32)

            conv.weight.data = torch.tensor(weight_np).reshape(conv.weight.shape)
            conv.bias.data = torch.tensor(bias_np).reshape(conv.bias.shape)
            count += 1
        # Load fully-connected weights: one flat float32 file per layer,
        # dense_{i}_w.bin / dense_{i}_b.bin, numbered from 1.
        count = 0
        for linear in self.fcs:
            if not isinstance(linear, Linear):
                continue
    
            weight_np = np.fromfile(prefix / f"dense_{count+1}_w.bin", dtype=np.float32)
            bias_np = np.fromfile(prefix / f"dense_{count+1}_b.bin", dtype=np.float32)

            # The legacy files store dense weights as (in_features, out_features);
            # reshape and transpose into PyTorch's (out_features, in_features) layout.
            cout, cin = linear.weight.shape
            linear.weight.data = torch.tensor(weight_np).reshape(cin, cout).T
            linear.bias.data = torch.tensor(bias_np).reshape(linear.bias.shape)
            count += 1
        return self
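

# Minimal usage sketch, not from the original snippet: it assumes a 3x32x32 input,
# which is inferred from the layer shapes (the conv stack reduces 32x32 to 5x5 over
# 64 channels, matching the 1600-unit first Linear layer). The weights directory
# path in the commented-out call is hypothetical.
if __name__ == "__main__":
    model = MiniERA()
    # model.load_legacy_hpvm_weights("weights/miniera")  # hypothetical path

    # Forward a dummy batch of one image; output is a (1, 5) row of class probabilities.
    probs = model(torch.randn(1, 3, 32, 32))
    assert probs.shape == (1, 5)
    print(probs.argmax(dim=1))  # predicted class index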