From 81cb77d23b1b9b2fa163b02914ded81076009653 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 15 Jan 2019 23:55:35 +0200
Subject: [PATCH] Fix for CPU evaluation use-case

Fix a mismatch between the location of the model and the
computation.
---
 .../classifier_compression/compress_classifier.py |  1 +
 models/__init__.py                                | 15 ++++++++-------
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 381a17c..0a82c11 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -297,6 +297,7 @@ def main():
     if args.cpu or not torch.cuda.is_available():
         # Set GPU index to -1 if using CPU
         args.device = 'cpu'
+        args.gpus = -1
     else:
         args.device = 'cuda'
         if args.gpus is not None:
diff --git a/models/__init__.py b/models/__init__.py
index 193951d..93545ea 100755
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -80,12 +80,13 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
         print("FATAL ERROR: create_model does not support models for dataset %s" % dataset)
         exit()
 
-    if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel:
-        model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
-    elif parallel:
-        model = torch.nn.DataParallel(model, device_ids=device_ids)
-
     if torch.cuda.is_available() and device_ids != -1:
-        model.cuda()
+        device = 'cuda'
+        if (arch.startswith('alexnet') or arch.startswith('vgg')) and parallel:
+            model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
+        elif parallel:
+            model = torch.nn.DataParallel(model, device_ids=device_ids)
+    else:
+        device = 'cpu'
 
-    return model
+    return model.to(device)
-- 
GitLab