Unverified Commit 608af2b4 authored by Neta Zmora, committed by GitHub

Fix Issue #53 (#55)

When a schedule contains epochs with nothing scheduled for them, apply_mask() is not invoked at the end of mini-batches, so optimizer weight updates can un-mask pruned weights (i.e., make them non-zero again).

See explanation in issue #53 discussion
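As a minimal illustration of the failure mode described above (plain Python, not Distiller's API): an optimizer such as SGD with momentum updates a weight from its momentum buffer even when the current gradient is zero, so a weight that was masked to zero drifts away from zero on the next step unless the mask is re-applied.

```python
# Sketch of SGD with momentum: buf = momentum * buf + grad; w -= lr * buf.
# This is a hypothetical helper for illustration, not Distiller or PyTorch code.

def sgd_momentum_step(w, grad, buf, lr=0.1, momentum=0.9):
    """One SGD-with-momentum update; returns the new weight and buffer."""
    buf = momentum * buf + grad
    return w - lr * buf, buf

w, buf = 1.0, 0.0
# A few ordinary training steps build up a non-zero momentum buffer.
for _ in range(3):
    w, buf = sgd_momentum_step(w, grad=1.0, buf=buf)

w = 0.0  # prune: mask the weight to zero
# Next mini-batch: even with a zero gradient, the momentum buffer
# alone moves the "pruned" weight away from zero.
w, buf = sgd_momentum_step(w, grad=0.0, buf=buf)
print(w)  # non-zero unless apply_mask() is called again
```

This is why the fix below calls apply_mask() unconditionally at the end of every mini-batch, instead of only when some policy is scheduled for the current epoch.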
parent 14531cbf
@@ -133,16 +133,16 @@ class CompressionScheduler(object):
         return overall_loss

     def on_minibatch_end(self, epoch, minibatch_id, minibatches_per_epoch, optimizer=None):
-        # When we get to this point, the weights are no longer maksed. This is because during the backward
-        # pass, the weights are updated. So we choose to lazily apply the pruning mask, only if some
-        # component is being called-back.
-        weights_are_masked = False
+        # When we get to this point, the weights are no longer masked. This is because during the backward
+        # pass, the weights may have been updated. This is true even when the gradients are zero, for some
+        # optimization algorithms such as SGD with momentum. See the Note in PyTorch's SGD documentation:
+        # https://pytorch.org/docs/stable/optim.html#torch.optim.SGD
+        #
+        # Therefore we choose to always apply the pruning mask. In the future we may optimize this by applying
+        # the mask only if some policy is actually using the mask.
+        self.apply_mask()
         if epoch in self.policies:
             for policy in self.policies[epoch]:
-                if not weights_are_masked:
-                    self.apply_mask()
-                    weights_are_masked = True
                 policy.on_minibatch_end(self.model, epoch, minibatch_id, minibatches_per_epoch,
                                         self.zeros_mask_dict, optimizer)