from pathlib import Path
from typing import Union

import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader

# `CIFAR` (a Dataset with a `from_file` constructor) is provided elsewhere in this
# project; the import below assumes it lives in a sibling `datasets` module.
from .datasets import CIFAR

PathLike = Union[Path, str]
class MiniERA(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(3, 32, 3),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fcs = nn.Sequential(nn.Linear(1600, 256), nn.ReLU(), nn.Linear(256, 5))
        self.softmax = nn.Softmax(1)

    def forward(self, input):
        outputs = self.convs(input)
        outputs = self.fcs(outputs.flatten(1, -1))
        return self.softmax(outputs)
    def load_legacy_hpvm_weights(self, prefix: PathLike):
        prefix = Path(prefix)
        # Load convolution weights; skip the ReLU / MaxPool2d layers in the
        # Sequential, which have no parameters.
        count = 0
        for conv in self.convs:
            if not isinstance(conv, nn.Conv2d):
                continue
            weight_np = np.fromfile(
                prefix / f"conv2d_{count+1}_w.bin", dtype=np.float32
            )
            bias_np = np.fromfile(prefix / f"conv2d_{count+1}_b.bin", dtype=np.float32)
            conv.weight.data = torch.tensor(weight_np).reshape(conv.weight.shape)
            conv.bias.data = torch.tensor(bias_np).reshape(conv.bias.shape)
            count += 1
        # Load fully-connected weights; the legacy format stores them transposed
        # (in-features x out-features), hence the reshape followed by `.T`.
        count = 0
        for linear in self.fcs:
            if not isinstance(linear, nn.Linear):
                continue
            weight_np = np.fromfile(prefix / f"dense_{count+1}_w.bin", dtype=np.float32)
            bias_np = np.fromfile(prefix / f"dense_{count+1}_b.bin", dtype=np.float32)
            cout, cin = linear.weight.shape
            linear.weight.data = torch.tensor(weight_np).reshape(cin, cout).T
            linear.bias.data = torch.tensor(bias_np).reshape(linear.bias.shape)
            count += 1
        return self
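

# A minimal usage sketch. The "weights" directory name is illustrative (not part of
# this file); it is assumed to hold the legacy HPVM conv2d_*_w.bin / conv2d_*_b.bin
# and dense_*_w.bin / dense_*_b.bin files. The 3x32x32 input size follows from the
# 1600-feature (64 x 5 x 5) first fully-connected layer.
#
#     model = MiniERA().load_legacy_hpvm_weights("weights")
#     model.eval()
#     with torch.no_grad():
#         probs = model(torch.randn(1, 3, 32, 32))  # softmax over the 5 classes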
class MiniEraPL(pl.LightningModule):
    def __init__(self, dataset_path: PathLike):
        super().__init__()
        self.network = MiniERA()
        self.dataset_path = Path(dataset_path)

    def forward(self, image):
        prediction = self.network(image)
        return prediction

    @staticmethod
    def _get_loss(output, targets):
        from torch.nn.functional import cross_entropy

        return cross_entropy(output, targets)

    @staticmethod
    def _get_metric(output, targets):
        predicted = torch.argmax(output, 1)
        return (predicted == targets).sum().item() / len(targets)

    def validation_step(self, val_batch, batch_idx):
        images, target = val_batch
        prediction = self(images)
        loss = self._get_loss(prediction, target)
        accuracy = self._get_metric(prediction, target)
        self.log("val_loss", loss)
        self.log("val_acc", accuracy)
        return accuracy

    def test_step(self, test_batch, batch_idx):
        images, target = test_batch
        prediction = self(images)
        accuracy = self._get_metric(prediction, target)
        self.log("test_acc", accuracy)
        return accuracy

    def test_dataloader(self):
        dataset = CIFAR.from_file(
            self.dataset_path / "input.bin", self.dataset_path / "labels.bin"
        )
        return DataLoader(dataset, batch_size=128)
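

# A minimal evaluation sketch. Both directory names are illustrative assumptions:
# "data" is expected to contain the input.bin / labels.bin pair read by
# test_dataloader, and "weights" is the legacy HPVM weight directory used above.
#
#     module = MiniEraPL("data")
#     module.network.load_legacy_hpvm_weights("weights")
#     trainer = pl.Trainer(accelerator="cpu", logger=False)
#     trainer.test(module)  # reports test_acc from test_step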