Skip to content
Snippets Groups Projects
make_ckpt.py 998 B
Newer Older
  • Learn to ignore specific revisions
  • """Make PyTorch checkpoint of MiniERA model from legacy HPVM weights."""
    import site
    from pathlib import Path
    
    import torch
    
    # Absolute path of the directory containing this script; main() resolves
    # the MiniERA asset files relative to it.
    self_folder = Path(__file__).parent.absolute()
    # Register this directory as a site directory so that `torch_dnn`
    # (presumably a sibling module/package in this folder -- verify) can be
    # imported even when this directory is not already on sys.path.
    # NOTE: site.addsitedir appends the directory to sys.path and also processes
    # any *.pth files it contains. It must run before the torch_dnn import below.
    site.addsitedir(self_folder)
    
    from torch_dnn import CIFAR, MiniERA
    
    
    @torch.no_grad()
    def main():
        """Convert the legacy HPVM MiniERA weights into a PyTorch checkpoint.

        Loads the legacy weights from ``assets/miniera`` next to this script and
        reports the model's accuracy on the bundled test set (``input.bin`` /
        ``labels.bin``) as a sanity check. It then writes the model's
        ``state_dict`` to ``assets/miniera/miniera.pth``.

        Raises:
            RuntimeError: if the test set yields no samples, so the converted
                weights cannot be validated before saving.
        """
        prefix = self_folder / "assets/miniera"
        model = MiniERA().load_legacy_hpvm_weights(prefix)
        # Switch to inference mode before evaluating. If MiniERA contains
        # dropout or batch-norm layers, training mode would both skew the
        # accuracy check and let the test pass overwrite batch-norm running
        # statistics, which would then be saved into the checkpoint below.
        model.eval()

        # Test mini ERA
        dataset = CIFAR.from_file(prefix / "input.bin", prefix / "labels.bin")
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
        correct = 0
        total = 0
        for batch in dataloader:
            # Index instead of unpacking so batches carrying extra fields still work.
            images, labels = batch[0], batch[1]
            # argmax returns the first maximal index, matching torch.max(x, 1).
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        if total == 0:
            # Avoid an opaque ZeroDivisionError below and make the failure explicit.
            raise RuntimeError(
                f"No test samples found under {prefix}; cannot validate the checkpoint."
            )
        print(f"Accuracy of the network on the test images: {100 * correct / total} %")
        torch.save(model.state_dict(), prefix / "miniera.pth")


    if __name__ == "__main__":
        main()