diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py index 4e7e3e76e312d230469e9971c23ebe181f9ca12d..b008746e45f781aa0837528538b99710280b8276 100755 --- a/distiller/apputils/image_classifier.py +++ b/distiller/apputils/image_classifier.py @@ -92,7 +92,7 @@ class ClassifierCompressor(object): loggers=[self.tflogger, self.pylogger], args=self.args) if verbose: distiller.log_weights_sparsity(self.model, epoch, [self.tflogger, self.pylogger]) - distiller.log_activation_statsitics(epoch, "train", loggers=[self.tflogger], + distiller.log_activation_statistics(epoch, "train", loggers=[self.tflogger], collector=collectors["sparsity"]) if self.args.masks_sparsity: msglogger.info(distiller.masks_sparsity_tbl_summary(self.model, @@ -118,7 +118,7 @@ class ClassifierCompressor(object): with collectors_context(self.activations_collectors["valid"]) as collectors: top1, top5, vloss = validate(self.val_loader, self.model, self.criterion, [self.pylogger], self.args, epoch) - distiller.log_activation_statsitics(epoch, "valid", loggers=[self.tflogger], + distiller.log_activation_statistics(epoch, "valid", loggers=[self.tflogger], collector=collectors["sparsity"]) save_collectors_data(collectors, msglogger.logdir) @@ -617,7 +617,7 @@ def test(test_loader, model, criterion, loggers, activations_collectors, args): activations_collectors = create_activation_stats_collectors(model, None) with collectors_context(activations_collectors["test"]) as collectors: top1, top5, lossses = _validate(test_loader, model, criterion, loggers, args) - distiller.log_activation_statsitics(-1, "test", loggers, collector=collectors['sparsity']) + distiller.log_activation_statistics(-1, "test", loggers, collector=collectors['sparsity']) save_collectors_data(collectors, msglogger.logdir) return top1, top5, lossses diff --git a/distiller/data_loggers/logger.py b/distiller/data_loggers/logger.py index 516208bed48c248269d5f4729db7e31dfd1ff306..bc99a0d5344c93babc0f71356a672a44a0bde1b9 100755 --- a/distiller/data_loggers/logger.py +++ b/distiller/data_loggers/logger.py @@ -52,7 +52,7 @@ class DataLogger(object): def log_training_progress(self, stats_dict, epoch, completed, total, freq): pass - def log_activation_statsitic(self, phase, stat_name, activation_stats, epoch): + def log_activation_statistic(self, phase, stat_name, activation_stats, epoch): pass def log_weights_sparsity(self, model, epoch): @@ -83,7 +83,7 @@ class PythonLogger(DataLogger): log = log + '{name} {val:.6f} '.format(name=name, val=val) self.pylogger.info(log) - def log_activation_statsitic(self, phase, stat_name, activation_stats, epoch): + def log_activation_statistic(self, phase, stat_name, activation_stats, epoch): data = [] for layer, statistic in activation_stats.items(): data.append([layer, statistic]) @@ -146,7 +146,7 @@ class TensorBoardLogger(DataLogger): self.tblogger.scalar_summary(prefix+tag, value, total_steps(total, epoch, completed)) self.tblogger.sync_to_file() - def log_activation_statsitic(self, phase, stat_name, activation_stats, epoch): + def log_activation_statistic(self, phase, stat_name, activation_stats, epoch): group = stat_name + '/activations/' + phase + "/" for tag, value in activation_stats.items(): self.tblogger.scalar_summary(group+tag, value, epoch) diff --git a/distiller/utils.py b/distiller/utils.py index 1b55f2529f7b30c523d2ef17b7b7e42f68b10fd7..ac0158e2868cf59ce31ddbd14ffea6c169e31520 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -528,12 +528,12 @@ def log_training_progress(stats_dict, params_dict, epoch, steps_completed, total logger.log_weights_distribution(params_dict, steps_completed) -def log_activation_statsitics(epoch, phase, loggers, collector): +def log_activation_statistics(epoch, phase, loggers, collector): """Log information about the sparsity of the activations""" if collector is None: return for logger in loggers: - logger.log_activation_statsitic(phase, collector.stat_name, collector.value(), epoch) + logger.log_activation_statistic(phase, collector.stat_name, collector.value(), epoch) def log_weights_sparsity(model, epoch, loggers):