diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 6713601ebfadaafd76fcda6be81934d2dc53229c..bb61cb38449681895671b872db6271f96adaf158 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -56,8 +56,9 @@ class CompressionScheduler(object):
     """Responsible for scheduling pruning and masking parameters.
 
     """
-    def __init__(self, model):
+    def __init__(self, model, device=torch.device("cuda")):
         self.model = model
+        self.device = device
         self.policies = {}
         self.sched_metadata = {}
 
@@ -103,7 +104,7 @@ class CompressionScheduler(object):
 
     def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss):
         # Last chance to compute the regularization loss, and optionally add it to the data loss
-        regularizer_loss = torch.tensor(0, dtype=torch.float, device='cuda')
+        regularizer_loss = torch.tensor(0, dtype=torch.float, device=self.device)
 
         if epoch in self.policies:
             for policy in self.policies[epoch]: