diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py index 3aa1fe7ca8976fd9a1f64115ba3c2b0a6cabfd4c..abc1567a7a7211c67b5a87d95053747cd198be8e 100755 --- a/distiller/apputils/image_classifier.py +++ b/distiller/apputils/image_classifier.py @@ -838,7 +838,7 @@ def acts_quant_stats_collection(model, criterion, loggers, args): .format(args.qe_calibration)) model = distiller.utils.make_non_parallel_copy(model) args.effective_test_size = args.qe_calibration - test_loader = load_data(args, load_train=False, load_val=False) + test_loader = load_data(args, fixed_subset=True, load_train=False, load_val=False) test_fn = partial(test, test_loader=test_loader, criterion=criterion, loggers=loggers, args=args, activations_collectors=None) collect_quant_stats(model, test_fn, save_dir=msglogger.logdir, diff --git a/distiller/quantization/ptq_greedy_search.py b/distiller/quantization/ptq_greedy_search.py index 1f361febff5580555f179ee951626ff86602dd8d..2b5d96f381955872eec7db1dceba57a43a1d9bea 100644 --- a/distiller/quantization/ptq_greedy_search.py +++ b/distiller/quantization/ptq_greedy_search.py @@ -439,7 +439,7 @@ if __name__ == "__main__": # quant calibration dataloader: args.effective_test_size = args.qe_calib_portion args.batch_size = args.qe_calib_batchsize - calib_data_loader = classifier.load_data(args, load_train=False, load_val=False) + calib_data_loader = classifier.load_data(args, fixed_subset=True, load_train=False, load_val=False) # logging logging.getLogger().setLevel(logging.WARNING) msglogger = logging.getLogger(__name__)