From deb7dd1575e19e761a8c9d6bb1daefa40b2e610f Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Thu, 10 May 2018 11:04:41 +0300 Subject: [PATCH] pytorch 0.4: begin the effort of centralizing the knowledge about which device to use Eventually the application will pass a torch.device to the Scheduler. Now we just create a default device in the constructor, and then use it. --- distiller/scheduler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 6713601..bb61cb3 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -56,8 +56,9 @@ class CompressionScheduler(object): """Responsible for scheduling pruning and masking parameters. """ - def __init__(self, model): + def __init__(self, model, device=torch.device("cuda")): self.model = model + self.device = device self.policies = {} self.sched_metadata = {} @@ -103,7 +104,7 @@ class CompressionScheduler(object): def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss): # Last chance to compute the regularization loss, and optionally add it to the data loss - regularizer_loss = torch.tensor(0, dtype=torch.float, device='cuda') + regularizer_loss = torch.tensor(0, dtype=torch.float, device=self.device) if epoch in self.policies: for policy in self.policies[epoch]: -- GitLab