diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 381a17cf0e04dfd37d06167cf5d4ce5c569c14e9..0a82c11628e30f0348bb33a26c1e14dd48406301 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -297,6 +297,7 @@ def main():
     if args.cpu or not torch.cuda.is_available():
         # Set GPU index to -1 if using CPU
         args.device = 'cpu'
+        args.gpus = -1
     else:
         args.device = 'cuda'
         if args.gpus is not None:
diff --git a/models/__init__.py b/models/__init__.py
index 193951d579dd36ac663998ad2387e41986a56109..93545ea56c0de532d441eb4bbb86efcbbe160e58 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -80,10 +80,13 @@
         print("FATAL ERROR: create_model does not support models for dataset %s" % dataset)
         exit()
 
-    if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel:
-        model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
-    elif parallel:
-        model = torch.nn.DataParallel(model, device_ids=device_ids)
-    if torch.cuda.is_available() and device_ids != -1:
-        model.cuda()
-    return model
+    # Use the GPU (and DataParallel wrapping) only when CUDA is present and
+    # the caller did not request CPU-only execution (device_ids == -1).
+    if torch.cuda.is_available() and device_ids != -1:
+        device = 'cuda'
+        if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel:
+            model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
+        elif parallel:
+            model = torch.nn.DataParallel(model, device_ids=device_ids)
+    else:
+        device = 'cpu'
+    return model.to(device)