Newer
Older
"""Make PyTorch checkpoint of MiniERA model from legacy HPVM weights."""
import site
from pathlib import Path
import torch
# Absolute path of the directory that contains this script.
self_folder = Path(__file__).parent.absolute()
# Register that directory as a site directory, which puts it on sys.path.
# Presumably this is what lets the `torch_dnn` import below resolve to a
# module sitting next to this script -- TODO confirm.
site.addsitedir(self_folder)
from torch_dnn import CIFAR, MiniERA
@torch.no_grad()
def main():
    """Convert legacy HPVM MiniERA weights into a PyTorch checkpoint.

    Loads the legacy weights found under ``assets/miniera`` next to this
    script, measures top-1 accuracy on the ``input.bin`` / ``labels.bin``
    samples stored in the same folder as a sanity check, prints that
    accuracy, and finally writes the model's ``state_dict`` to
    ``assets/miniera/miniera.pth``.

    The whole function runs under ``torch.no_grad()`` since no gradients are
    needed for the accuracy check.

    Raises:
        ValueError: If the test dataset yields no samples, so no accuracy can
            be computed. The checkpoint is not written in that case.
    """
    prefix = self_folder / "assets/miniera"
    model = MiniERA().load_legacy_hpvm_weights(prefix)
    # Evaluate in inference mode. If the network contains dropout or
    # batch-norm layers (not visible from this file), leaving it in training
    # mode would skew the reported accuracy and could alter batch-norm running
    # statistics before the state dict is saved.
    model.eval()

    # Test mini ERA
    dataset = CIFAR.from_file(prefix / "input.bin", prefix / "labels.bin")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    correct = 0
    total = 0
    for batch in dataloader:
        images, labels = batch[0], batch[1]
        outputs = model(images)
        # Index of the highest score along dim 1 is the predicted class.
        # Gradients are already disabled, so the legacy `.data` is not needed.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    if total == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(
            f"No test samples were loaded from {prefix}; cannot compute accuracy."
        )
    print(f"Accuracy of the network on the test images: {100 * correct / total} %")
    torch.save(model.state_dict(), prefix / "miniera.pth")
# Run the conversion only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()