From 89b17434fc5d39d2e3138a330839c23539c7ff07 Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Mon, 8 Jul 2019 14:07:22 +0300 Subject: [PATCH] More generic implementation of distiller.utils.model_device --- distiller/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/distiller/utils.py b/distiller/utils.py index 9532a55..084eb56 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -39,8 +39,11 @@ msglogger = logging.getLogger() def model_device(model): """Determine the device the model is allocated on.""" # Source: https://discuss.pytorch.org/t/how-to-check-if-model-is-on-cuda/180 - if next(model.parameters()).is_cuda: - return 'cuda' + try: + return str(next(model.parameters()).device) + except StopIteration: + # Model has no parameters + pass return 'cpu' -- GitLab