diff --git a/distiller/thinning.py b/distiller/thinning.py index 2a89142401ed7f17351260fbea1dc6dfd4bd2f63..ab6cff157a1195cc068659eda6083cd37efd4f54 100755 --- a/distiller/thinning.py +++ b/distiller/thinning.py @@ -386,7 +386,7 @@ class FilterRemover(ScheduledTrainingPolicy): self.thinning_func(model, zeros_mask_dict, self.arch, self.dataset, optimizer=optimizer) self.done = True - def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, optimizer): + def on_minibatch_begin(self, model, epoch, minibatch_id, minibatches_per_epoch, zeros_mask_dict, meta, optimizer): # We hook onto the on_minibatch_begin because we want to run after the pruner which sparsified # the tensors. Pruners configure their pruning mask in on_epoch_begin, but apply the mask # only in on_minibatch_begin