from pathlib import Path
from typing import Union

import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader

from .dataset import CIFAR

PathLike = Union[Path, str]


class MiniERA(nn.Module):
    """Small CNN classifier producing softmax probabilities over 5 classes.

    Four 3x3 convolutions (with two 2x2 max-pools) followed by two linear
    layers. The first linear layer takes 1600 flattened features; with the
    layers below that corresponds to a 3x32x32 input
    (32 -> 30 -> 28 -> pool 14 -> 12 -> 10 -> pool 5, and 64 * 5 * 5 = 1600).
    """

    def __init__(self):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(3, 32, 3),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fcs = nn.Sequential(nn.Linear(1600, 256), nn.ReLU(), nn.Linear(256, 5))
        self.softmax = nn.Softmax(1)

    def forward(self, input):
        """Return class probabilities (softmax over dim 1) of shape ``(batch, 5)``."""
        outputs = self.convs(input)
        outputs = self.fcs(outputs.flatten(1, -1))
        return self.softmax(outputs)

    def load_legacy_hpvm_weights(self, prefix: PathLike):
        """Load weights exported by the legacy HPVM pipeline from ``prefix``.

        For the i-th ``Conv2d`` (1-based) reads raw float32 files
        ``conv2d_{i}_w.bin`` / ``conv2d_{i}_b.bin``; for the i-th ``Linear``
        reads ``dense_{i}_w.bin`` / ``dense_{i}_b.bin``. Conv data is reshaped
        straight to the parameter's shape. Dense weight files hold an
        ``(in, out)`` matrix, which is transposed to PyTorch's ``(out, in)``.

        Values are copied into the existing parameters in place, so they keep
        their device and dtype (the previous ``.data = ...`` assignment replaced
        them with CPU float32 tensors).

        :param prefix: directory containing the ``.bin`` files.
        :returns: ``self``, for chaining.
        :raises RuntimeError: if a file's element count does not match the
            layer (raised by ``reshape``). A missing file raises from
            ``np.fromfile``.
        """
        prefix = Path(prefix)

        def read(name: str) -> torch.Tensor:
            # Raw float32 dump, no header.
            return torch.from_numpy(np.fromfile(prefix / name, dtype=np.float32))

        convs = [m for m in self.convs if isinstance(m, nn.Conv2d)]
        linears = [m for m in self.fcs if isinstance(m, nn.Linear)]

        with torch.no_grad():
            for idx, conv in enumerate(convs, start=1):
                conv.weight.copy_(read(f"conv2d_{idx}_w.bin").reshape(conv.weight.shape))
                conv.bias.copy_(read(f"conv2d_{idx}_b.bin").reshape(conv.bias.shape))
            for idx, linear in enumerate(linears, start=1):
                cout, cin = linear.weight.shape
                # File layout is (in, out); PyTorch Linear stores (out, in).
                linear.weight.copy_(read(f"dense_{idx}_w.bin").reshape(cin, cout).T)
                linear.bias.copy_(read(f"dense_{idx}_b.bin").reshape(linear.bias.shape))
        return self


class MiniEraPL(pl.LightningModule):
    """Lightning wrapper around :class:`MiniERA` for validation and testing.

    :param dataset_path: directory containing ``input.bin`` and ``labels.bin``,
        which ``test_dataloader`` loads via ``CIFAR.from_file``.
    """

    def __init__(self, dataset_path: PathLike):
        super().__init__()
        self.network = MiniERA()
        self.dataset_path = Path(dataset_path)

    def forward(self, image):
        """Return the wrapped network's softmax probabilities for ``image``."""
        prediction = self.network(image)
        return prediction

    @staticmethod
    def _get_loss(output, targets):
        """Mean negative log-likelihood of ``targets`` under probabilities ``output``.

        ``output`` comes from ``MiniERA.forward`` and is already
        softmax-normalized. Passing it to ``cross_entropy`` (which applies
        ``log_softmax`` itself) would softmax twice and distort the loss, so
        take the log directly and use ``nll_loss``. Probabilities are clamped
        away from zero so the log stays finite.
        """
        from torch.nn.functional import nll_loss

        eps = torch.finfo(output.dtype).tiny
        return nll_loss(torch.log(output.clamp(min=eps)), targets)

    @staticmethod
    def _get_metric(output, targets):
        """Fraction of samples whose argmax prediction equals the target label."""
        predicted = torch.argmax(output, 1)
        return (predicted == targets).sum().item() / len(targets)

    def validation_step(self, val_batch, batch_idx):
        """Log ``val_loss`` and ``val_acc`` for one batch; return the accuracy."""
        images, target = val_batch
        prediction = self(images)
        loss = self._get_loss(prediction, target)
        accuracy = self._get_metric(prediction, target)
        self.log("val_loss", loss)
        self.log("val_acc", accuracy)
        return accuracy

    def test_step(self, test_batch, batch_idx):
        """Log ``test_acc`` for one batch; return the accuracy."""
        images, target = test_batch
        prediction = self(images)
        accuracy = self._get_metric(prediction, target)
        self.log("test_acc", accuracy)
        return accuracy

    def test_dataloader(self):
        """Build the test loader from ``input.bin`` / ``labels.bin`` under ``dataset_path``."""
        dataset = CIFAR.from_file(
            self.dataset_path / "input.bin", self.dataset_path / "labels.bin"
        )
        return DataLoader(dataset, batch_size=128)