From bf1e6a0d45d62e0d526cc0f0d3a378b283171ce2 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Thu, 16 May 2019 14:00:36 +0300 Subject: [PATCH] Refactoring: utils.get_dummy_input() Remove the multiple instances of code that generates dummy input per dataset. --- distiller/thinning.py | 11 ++--------- distiller/utils.py | 8 +++++++- tests/common.py | 8 -------- tests/test_model_summary.py | 4 ++-- tests/test_pruning.py | 2 +- tests/test_summarygraph.py | 18 ++++-------------- 6 files changed, 16 insertions(+), 35 deletions(-) diff --git a/distiller/thinning.py b/distiller/thinning.py index 0ca0e24..1cd7bd9 100755 --- a/distiller/thinning.py +++ b/distiller/thinning.py @@ -29,8 +29,8 @@ import math import logging from collections import namedtuple import torch -from .policy import ScheduledTrainingPolicy import distiller +from .policy import ScheduledTrainingPolicy from .summary_graph import SummaryGraph msglogger = logging.getLogger(__name__) @@ -63,14 +63,7 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers', def create_graph(dataset, model): - dummy_input = None - if dataset == 'imagenet': - dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False) - elif dataset == 'cifar10': - dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False) - assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset) - - dummy_input = dummy_input.to(distiller.model_device(model)) + dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model)) return SummaryGraph(model, dummy_input) diff --git a/distiller/utils.py b/distiller/utils.py index 3f8b825..580fa51 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -556,13 +556,19 @@ def has_children(module): return False -def get_dummy_input(dataset): +def get_dummy_input(dataset, device=None): + """Generate a representative dummy (random) input for the specified dataset. 
+ + If a device is specified, then the dummy_input is moved to that device. + """ if dataset == 'imagenet': dummy_input = torch.randn(1, 3, 224, 224) elif dataset == 'cifar10': dummy_input = torch.randn(1, 3, 32, 32) else: raise ValueError("dataset %s is not supported" % dataset) + if device: + dummy_input = dummy_input.to(device) return dummy_input diff --git a/tests/common.py b/tests/common.py index 792e033..728d62f 100755 --- a/tests/common.py +++ b/tests/common.py @@ -37,13 +37,5 @@ def find_module_by_name(model, module_to_find): return None -def get_dummy_input(dataset): - if dataset == "imagenet": - return torch.randn(1, 3, 224, 224).cuda() - elif dataset == "cifar10": - return torch.randn(1, 3, 32, 32).cuda() - raise ValueError("Trying to use an unknown dataset " + dataset) - - def almost_equal(a , b, max_diff=0.000001): return abs(a - b) <= max_diff diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py index b15badc..89c81e6 100755 --- a/tests/test_model_summary.py +++ b/tests/test_model_summary.py @@ -42,7 +42,7 @@ def test_compute_summary(): dataset = "cifar10" arch = "simplenet_cifar" model, _ = common.setup_test(arch, dataset, parallel=True) - df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset)) + df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input(dataset)) module_macs = df_compute.loc[:, 'MACs'].to_list() # [conv1, conv2, fc1, fc2, fc3] assert module_macs == [352800, 240000, 48000, 10080, 840] @@ -50,7 +50,7 @@ def test_compute_summary(): dataset = "imagenet" arch = "mobilenet" model, _ = common.setup_test(arch, dataset, parallel=True) - df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset)) + df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input(dataset)) module_macs = df_compute.loc[:, 'MACs'].to_list() expected_macs = [10838016, 3612672, 25690112, 1806336, 25690112, 3612672, 51380224, 903168, 25690112, 
1806336, 51380224, 451584, 25690112, 903168, 51380224, 903168, diff --git a/tests/test_pruning.py b/tests/test_pruning.py index 42a6ab7..ff78d1e 100755 --- a/tests/test_pruning.py +++ b/tests/test_pruning.py @@ -277,7 +277,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel): assert bn1.bias.size(0) == cnt_nnz_channels assert bn1.weight.size(0) == cnt_nnz_channels - dummy_input = common.get_dummy_input(config.dataset) + dummy_input = distiller.get_dummy_input(config.dataset, distiller.model_device(model)) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1) run_forward_backward(model, optimizer, dummy_input) diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index 46fedb3..55ce88b 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -32,18 +32,8 @@ logger = logging.getLogger() logger.addHandler(fh) -def get_input(dataset): - if dataset == 'imagenet': - return torch.randn((1, 3, 224, 224), requires_grad=False) - elif dataset == 'cifar10': - return torch.randn((1, 3, 32, 32)) - return None - - def create_graph(dataset, arch): - dummy_input = get_input(dataset) - assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset) - + dummy_input = distiller.get_dummy_input(dataset) model = create_model(False, dataset, arch, parallel=False) assert model is not None return SummaryGraph(model, dummy_input) @@ -163,7 +153,7 @@ def test_normalize_module_name(): def named_params_layers_test_aux(dataset, arch, dataparallel:bool): model = create_model(False, dataset, arch, parallel=dataparallel) - sgraph = SummaryGraph(model, get_input(dataset)) + sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset)) sgraph_layer_names = set(k for k, i, j in sgraph.named_params_layers()) for layer_name in sgraph_layer_names: assert sgraph.find_op(layer_name) is not None, '{} was not found in summary graph'.format(layer_name) @@ -202,7 +192,7 @@ def 
test_sg_macs(): sg = create_graph('imagenet', 'mobilenet') assert sg model, _ = common.setup_test('mobilenet', 'imagenet', parallel=False) - df_compute = distiller.model_performance_summary(model, common.get_dummy_input('imagenet')) + df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input('imagenet')) modules_macs = df_compute.loc[:, ['Name', 'MACs']] for name, mod in model.named_modules(): if isinstance(mod, (torch.nn.Conv2d, torch.nn.Linear)): @@ -214,7 +204,7 @@ def test_sg_macs(): def test_weights_size_attr(): def test(dataset, arch, dataparallel:bool): model = create_model(False, dataset, arch, parallel=dataparallel) - sgraph = SummaryGraph(model, get_input(dataset)) + sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset)) distiller.assign_layer_fq_names(model) for name, mod in model.named_modules(): -- GitLab