diff --git a/distiller/utils.py b/distiller/utils.py index 9532a55feea604eade412a0a7e49c30bc93bbbb7..084eb56069e8a3c8bcda90e7a4732b9f8a575dd4 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -39,8 +39,11 @@ msglogger = logging.getLogger() def model_device(model): """Determine the device the model is allocated on.""" # Source: https://discuss.pytorch.org/t/how-to-check-if-model-is-on-cuda/180 - if next(model.parameters()).is_cuda: - return 'cuda' + try: + return str(next(model.parameters()).device) + except StopIteration: + # Model has no parameters + pass return 'cpu'