Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • llvm/distiller
1 result
Show changes
Commits on Source (1)
...@@ -557,7 +557,7 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector): ...@@ -557,7 +557,7 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
if not tensor.is_contiguous(): if not tensor.is_contiguous():
tensor = tensor.contiguous() tensor = tensor.contiguous()
if (len(tensor.size()) == 0): if len(tensor.size()) == 0:
return return
act = tensor.view(tensor.size(0), -1) act = tensor.view(tensor.size(0), -1)
...@@ -846,7 +846,7 @@ class RawActivationsCollector(ActivationStatsCollector): ...@@ -846,7 +846,7 @@ class RawActivationsCollector(ActivationStatsCollector):
return dir_name return dir_name
def collect_quant_stats(model, trainer, dataloader, save_dir=None, classes=None, inplace_runtime_check=False, def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_runtime_check=False,
disable_inplace_attrs=False, inplace_attr_names=('inplace',), disable_inplace_attrs=False, inplace_attr_names=('inplace',),
modules_to_collect=None): modules_to_collect=None):
""" """
...@@ -877,13 +877,11 @@ def collect_quant_stats(model, trainer, dataloader, save_dir=None, classes=None, ...@@ -877,13 +877,11 @@ def collect_quant_stats(model, trainer, dataloader, save_dir=None, classes=None,
inplace_attr_names=inplace_attr_names) inplace_attr_names=inplace_attr_names)
with collector_context(quant_stats_collector, modules_to_collect): with collector_context(quant_stats_collector, modules_to_collect):
msglogger.info('Pass 1: Collecting min, max, avg_min, avg_max, mean') msglogger.info('Pass 1: Collecting min, max, avg_min, avg_max, mean')
trainer(model, dataloader) test_fn(model=model)
# trainer.test(model, test_dataloaders=dataloader)
# Collect Laplace distribution stats: # Collect Laplace distribution stats:
msglogger.info('Pass 2: Collecting b, std parameters') msglogger.info('Pass 2: Collecting b, std parameters')
quant_stats_collector.start_second_pass() quant_stats_collector.start_second_pass()
trainer(model, dataloader) test_fn(model=model)
# trainer.test(model, test_dataloaders=dataloader)
quant_stats_collector.stop_second_pass() quant_stats_collector.stop_second_pass()
msglogger.info('Stats collection complete') msglogger.info('Stats collection complete')
......