"""model.py: the MiniERA CNN and its PyTorch Lightning wrapper, MiniEraPL."""
from pathlib import Path
from typing import Union

import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader
from .dataset import CIFAR
# Type accepted wherever a filesystem location is expected: a plain string or a ``Path``.
PathLike = Union[str, Path]


class MiniERA(nn.Module):
    """Small CNN classifier whose output is a softmax over 5 classes.

    Layout: two blocks of (3x3 conv, ReLU, 3x3 conv, ReLU, 2x2 max-pool),
    then Linear(1600, 256) -> ReLU -> Linear(256, 5) -> Softmax.
    The 1600-wide first dense layer equals 64 * 5 * 5, which is what the
    conv stack produces for 3x32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(3, 32, 3),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fcs = nn.Sequential(nn.Linear(1600, 256), nn.ReLU(), nn.Linear(256, 5))
        self.softmax = nn.Softmax(1)

    def forward(self, input):
        """Return per-class probabilities (softmax over dim 1) for a batch of images."""
        outputs = self.convs(input)
        # Flatten everything except the batch dimension before the dense layers.
        outputs = self.fcs(outputs.flatten(1, -1))
        return self.softmax(outputs)

    def load_legacy_hpvm_weights(self, prefix: Union[Path, str]):
        """Load raw float32 weights exported by the legacy HPVM pipeline.

        For the i-th ``nn.Conv2d`` in ``self.convs`` (1-based, counting only
        conv layers) reads ``conv2d_{i}_w.bin`` / ``conv2d_{i}_b.bin`` from
        ``prefix``; for the i-th ``nn.Linear`` in ``self.fcs`` reads
        ``dense_{i}_w.bin`` / ``dense_{i}_b.bin``.

        Convolution weights are reshaped directly to the layer's weight shape.
        Dense weights are stored as ``(in_features, out_features)`` and are
        transposed to PyTorch's ``(out_features, in_features)`` on load.

        Values are copied into the existing parameters, so their device and
        dtype are preserved.

        Args:
            prefix: directory containing the ``.bin`` files.

        Returns:
            ``self``, to allow chaining.

        Raises:
            FileNotFoundError: if an expected file is missing.
            RuntimeError: if a file's element count does not match the layer.
        """
        prefix = Path(prefix)

        def read_f32(name):
            # Raw little-endian float32 dump, no header.
            return torch.from_numpy(np.fromfile(prefix / name, dtype=np.float32))

        convs = [layer for layer in self.convs if isinstance(layer, nn.Conv2d)]
        linears = [layer for layer in self.fcs if isinstance(layer, nn.Linear)]

        with torch.no_grad():
            for idx, conv in enumerate(convs, start=1):
                weight = read_f32(f"conv2d_{idx}_w.bin")
                bias = read_f32(f"conv2d_{idx}_b.bin")
                conv.weight.copy_(weight.reshape(conv.weight.shape))
                conv.bias.copy_(bias.reshape(conv.bias.shape))

            for idx, linear in enumerate(linears, start=1):
                weight = read_f32(f"dense_{idx}_w.bin")
                bias = read_f32(f"dense_{idx}_b.bin")
                cout, cin = linear.weight.shape
                # Stored in (in, out) order; transpose to (out, in).
                linear.weight.copy_(weight.reshape(cin, cout).T)
                linear.bias.copy_(bias.reshape(linear.bias.shape))
        return self


class MiniEraPL(pl.LightningModule):
    """PyTorch Lightning wrapper that evaluates a ``MiniERA`` network.

    Test data is loaded with ``CIFAR.from_file`` from
    ``dataset_path / "input.bin"`` and ``dataset_path / "labels.bin"``.
    """

    def __init__(self, dataset_path: PathLike):
        super().__init__()
        self.network = MiniERA()
        self.dataset_path = Path(dataset_path)

    def forward(self, image):
        """Return the wrapped network's class probabilities for ``image``."""
        return self.network(image)

    @staticmethod
    def _get_loss(output, targets):
        """Mean negative log-likelihood of ``targets`` under ``output``.

        ``MiniERA.forward`` already ends in a softmax, so ``output`` holds
        probabilities, not logits. ``cross_entropy`` would apply a second
        log-softmax on top of them, so take the log explicitly and use
        ``nll_loss``. The result is the true cross-entropy of the predictions.
        """
        from torch.nn.functional import nll_loss

        # Clamp so a zero probability gives a large finite loss instead of inf.
        tiny = torch.finfo(output.dtype).tiny
        return nll_loss(torch.log(output.clamp_min(tiny)), targets)

    @staticmethod
    def _get_metric(output, targets):
        """Fraction of samples whose arg-max prediction equals the target label."""
        predicted = torch.argmax(output, 1)
        return (predicted == targets).sum().item() / len(targets)

    def validation_step(self, val_batch, batch_idx):
        """Log ``val_loss`` and ``val_acc`` for one ``(images, target)`` batch.

        Returns the batch accuracy.
        """
        # NOTE(review): this class defines no ``val_dataloader``; validation
        # presumably relies on a loader supplied by the caller -- confirm.
        images, target = val_batch
        prediction = self(images)
        loss = self._get_loss(prediction, target)
        accuracy = self._get_metric(prediction, target)
        self.log("val_loss", loss)
        self.log("val_acc", accuracy)
        return accuracy

    def test_step(self, test_batch, batch_idx):
        """Log ``test_acc`` for one ``(images, target)`` batch; return the accuracy."""
        images, target = test_batch
        prediction = self(images)
        accuracy = self._get_metric(prediction, target)
        self.log("test_acc", accuracy)
        return accuracy

    def test_dataloader(self):
        """Build an unshuffled DataLoader (batch size 128) over the test files."""
        dataset = CIFAR.from_file(
            self.dataset_path / "input.bin", self.dataset_path / "labels.bin"
        )
        return DataLoader(dataset, batch_size=128)