From 89b17434fc5d39d2e3138a330839c23539c7ff07 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Mon, 8 Jul 2019 14:07:22 +0300
Subject: [PATCH] More generic implementation of distiller.utils.model_device

---
 distiller/utils.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/distiller/utils.py b/distiller/utils.py
index 9532a55..084eb56 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -39,8 +39,11 @@ msglogger = logging.getLogger()
 def model_device(model):
     """Determine the device the model is allocated on."""
     # Source: https://discuss.pytorch.org/t/how-to-check-if-model-is-on-cuda/180
-    if next(model.parameters()).is_cuda:
-        return 'cuda'
+    try:
+        return str(next(model.parameters()).device)
+    except StopIteration:
+        # Model has no parameters
+        pass
     return 'cpu'
 
 
-- 
GitLab