From deb7dd1575e19e761a8c9d6bb1daefa40b2e610f Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 10 May 2018 11:04:41 +0300
Subject: [PATCH] pytorch 0.4: begin centralizing knowledge of which device
 to use

Eventually the application will pass a torch.device to the Scheduler.
For now, we create a default device in the constructor and use it
wherever the scheduler allocates tensors.
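
A minimal usage sketch of the new constructor parameter (the nn.Linear
model is just a stand-in for a real network, and selecting the device
with torch.cuda.is_available() is one reasonable caller-side policy,
not part of this patch):

    import torch
    import torch.nn as nn
    from distiller.scheduler import CompressionScheduler

    # Choose the device once, at application start-up, instead of
    # hardcoding 'cuda' inside the scheduler.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = nn.Linear(10, 2).to(device)  # stand-in for a real model
    scheduler = CompressionScheduler(model, device=device)

    # Tensors the scheduler allocates internally (e.g. the regularizer
    # loss in before_backward_pass) now land on the caller's device.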
---
 distiller/scheduler.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 6713601..bb61cb3 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -56,8 +56,9 @@ class CompressionScheduler(object):
     """Responsible for scheduling pruning and masking parameters.
 
     """
-    def __init__(self, model):
+    def __init__(self, model, device=torch.device("cuda")):
         self.model = model
+        self.device = device
         self.policies = {}
         self.sched_metadata = {}
 
@@ -103,7 +104,7 @@ class CompressionScheduler(object):
 
     def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss):
         # Last chance to compute the regularization loss, and optionally add it to the data loss
-        regularizer_loss = torch.tensor(0, dtype=torch.float, device='cuda')
+        regularizer_loss = torch.tensor(0, dtype=torch.float, device=self.device)
 
         if epoch in self.policies:
             for policy in self.policies[epoch]:
-- 
GitLab