"""Make PyTorch checkpoint of MiniERA model from legacy HPVM weights."""
import site
from pathlib import Path

import torch

self_folder = Path(__file__).parent.absolute()
# Make modules that live next to this script (``torch_dnn``) importable when it
# is run directly; this must happen before the import below.
site.addsitedir(str(self_folder))

from torch_dnn import CIFAR, MiniERA  # noqa: E402  (depends on addsitedir above)


@torch.no_grad()
def main():
    """Convert the legacy HPVM MiniERA weights into a PyTorch checkpoint.

    Loads the legacy weights from ``assets/miniera`` next to this script,
    prints the model's accuracy on the ``input.bin`` / ``labels.bin`` test set
    found in the same directory, and writes the model's ``state_dict`` to
    ``assets/miniera/miniera.pth``.
    """
    prefix = self_folder / "assets/miniera"
    model = MiniERA().load_legacy_hpvm_weights(prefix)
    # Evaluate in inference mode. In training mode, BatchNorm layers (if
    # MiniERA has any) update their running statistics on every forward pass
    # even under ``torch.no_grad()``. That would alter the checkpoint saved
    # below. Dropout would also perturb the measured accuracy.
    model.eval()

    # Test mini ERA
    dataset = CIFAR.from_file(prefix / "input.bin", prefix / "labels.bin")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    correct = 0
    total = 0
    for batch in dataloader:
        images, labels = batch[0], batch[1]
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Avoid ZeroDivisionError on an empty test set; the checkpoint is still saved.
    if total:
        print(f"Accuracy of the network on the test images: {100 * correct / total} %")
    else:
        print("No test images found; skipping accuracy report.")

    torch.save(model.state_dict(), prefix / "miniera.pth")


if __name__ == "__main__":
    main()