diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 6713601ebfadaafd76fcda6be81934d2dc53229c..bb61cb38449681895671b872db6271f96adaf158 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -56,8 +56,9 @@ class CompressionScheduler(object): """Responsible for scheduling pruning and masking parameters. """ - def __init__(self, model): + def __init__(self, model, device=torch.device("cuda")): self.model = model + self.device = device self.policies = {} self.sched_metadata = {} @@ -103,7 +104,7 @@ class CompressionScheduler(object): def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss): # Last chance to compute the regularization loss, and optionally add it to the data loss - regularizer_loss = torch.tensor(0, dtype=torch.float, device='cuda') + regularizer_loss = torch.tensor(0, dtype=torch.float, device=self.device) if epoch in self.policies: for policy in self.policies[epoch]: