From 81cb77d23b1b9b2fa163b02914ded81076009653 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Tue, 15 Jan 2019 23:55:35 +0200 Subject: [PATCH] Fix for CPU evaluation use-case Fix a mismatch between the location of the model and the computation. --- .../classifier_compression/compress_classifier.py | 1 + models/__init__.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 381a17c..0a82c11 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -297,6 +297,7 @@ def main(): if args.cpu or not torch.cuda.is_available(): # Set GPU index to -1 if using CPU args.device = 'cpu' + args.gpus = -1 else: args.device = 'cuda' if args.gpus is not None: diff --git a/models/__init__.py b/models/__init__.py index 193951d..93545ea 100755 --- a/models/__init__.py +++ b/models/__init__.py @@ -80,12 +80,13 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): print("FATAL ERROR: create_model does not support models for dataset %s" % dataset) exit() - if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel: - model.features = torch.nn.DataParallel(model.features, device_ids=device_ids) - elif parallel: - model = torch.nn.DataParallel(model, device_ids=device_ids) - if torch.cuda.is_available() and device_ids != -1: - model.cuda() + device = 'cuda' + if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel: + model.features = torch.nn.DataParallel(model.features, device_ids=device_ids) + elif parallel: + model = torch.nn.DataParallel(model, device_ids=device_ids) + else: + device = 'cpu' - return model + return model.to(device) -- GitLab