Skip to content
Snippets Groups Projects
Commit deb7dd15 authored by Neta Zmora's avatar Neta Zmora
Browse files

pytorch 0.4: begin the effort of centralizing the knowledge about which device to use

Eventually the application will pass a torch.device to the Scheduler.
Now we just create a default device in the constructor, and then use it.
parent 384f4740
No related branches found
No related tags found
No related merge requests found
...@@ -56,8 +56,9 @@ class CompressionScheduler(object): ...@@ -56,8 +56,9 @@ class CompressionScheduler(object):
"""Responsible for scheduling pruning and masking parameters. """Responsible for scheduling pruning and masking parameters.
""" """
def __init__(self, model, device=None):
    """Create a compression scheduler for `model`.

    Args:
        model: the model whose pruning/masking policies this scheduler manages.
        device (torch.device, optional): device on which scheduler-created
            tensors (e.g. the regularization loss in `before_backward_pass`)
            are allocated.  When None, CUDA is used if available, otherwise
            the CPU — this avoids crashing on CPU-only hosts, where a
            hard-coded `torch.device("cuda")` default would make every
            later `torch.tensor(..., device=self.device)` call fail.
    """
    self.model = model
    # Resolve the default lazily so CPU-only machines still work; callers
    # that care should pass an explicit torch.device.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.device = device
    # epoch -> list of policies active in that epoch
    self.policies = {}
    # per-policy bookkeeping (populated elsewhere in the class)
    self.sched_metadata = {}
...@@ -103,7 +104,7 @@ class CompressionScheduler(object): ...@@ -103,7 +104,7 @@ class CompressionScheduler(object):
def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss): def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch, loss):
# Last chance to compute the regularization loss, and optionally add it to the data loss # Last chance to compute the regularization loss, and optionally add it to the data loss
regularizer_loss = torch.tensor(0, dtype=torch.float, device='cuda') regularizer_loss = torch.tensor(0, dtype=torch.float, device=self.device)
if epoch in self.policies: if epoch in self.policies:
for policy in self.policies[epoch]: for policy in self.policies[epoch]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment