Skip to content
Snippets Groups Projects
Commit deb7dd15 authored by Neta Zmora's avatar Neta Zmora
Browse files

pytorch 0.4: begin the effort of centralizing the knowledge about which device to use

Eventually the application will pass a torch.device to the Scheduler.
Now we just create a default device in the constructor, and then use it.
parent 384f4740
No related branches found
No related tags found
No related merge requests found
......@@ -56,8 +56,9 @@ class CompressionScheduler(object):
"""Responsible for scheduling pruning and masking parameters.
"""
def __init__(self, model):
def __init__(self, model, device=torch.device("cuda")):
self.model = model
self.device = device
self.policies = {}
self.sched_metadata = {}
......@@ -103,7 +104,7 @@ class CompressionScheduler(object):
def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss):
# Last chance to compute the regularization loss, and optionally add it to the data loss
regularizer_loss = torch.tensor(0, dtype=torch.float, device='cuda')
regularizer_loss = torch.tensor(0, dtype=torch.float, device=self.device)
if epoch in self.policies:
for policy in self.policies[epoch]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment