diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index c25c03f03fe07332f6b6ef8841a00b50d084b7f6..500eeb57b613690974b55239d0c6a3376746dc6a 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -161,6 +161,12 @@ class ClassifierCompressorSampleApp(classifier.ClassifierCompressor): def __init__(self, args, script_dir): super().__init__(args, script_dir) early_exit_init(args) + # Save the randomly-initialized model before training (useful for lottery-ticket method) + if args.save_untrained_model: + ckpt_name = '_'.join((self.args.name or "", "untrained")) + apputils.save_checkpoint(0, self.args.arch, self.model, + name=ckpt_name, dir=msglogger.logdir) + def handle_subapps(self): return handle_subapps(self.model, self.criterion, self.optimizer,