NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Indentation of context-only lines, the wrap point inside `expected_macs`, and
the blank context line in the final hunk are inferred - verify against the
repository before applying. Changes relative to the collapsed copy:
get_dummy_input() tests `device is not None` (so device index 0 is honoured),
the docstring typo is fixed, and the model_performance_summary() tests place
the dummy input on the model's device (the removed tests/common helper always
returned CUDA tensors). The `index` blob hashes were not recomputed.

diff --git a/distiller/thinning.py b/distiller/thinning.py
index 0ca0e2499ce749dc8e8e65a37689f0c5a4b3941a..1cd7bd9d3500d1747e907d4a7929bbbcfab06e20 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -29,8 +29,8 @@ import math
 import logging
 from collections import namedtuple
 import torch
-from .policy import ScheduledTrainingPolicy
 import distiller
+from .policy import ScheduledTrainingPolicy
 from .summary_graph import SummaryGraph
 
 msglogger = logging.getLogger(__name__)
@@ -63,14 +63,7 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
 
 
 def create_graph(dataset, model):
-    dummy_input = None
-    if dataset == 'imagenet':
-        dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
-    elif dataset == 'cifar10':
-        dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False)
-    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
-
-    dummy_input = dummy_input.to(distiller.model_device(model))
+    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
     return SummaryGraph(model, dummy_input)
 
 
diff --git a/distiller/utils.py b/distiller/utils.py
index 3f8b825416424a11791a6647c8b2d22e1296c47f..580fa518ccb1b6ed5880fc99dbe113d148c2a2e4 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -556,13 +556,19 @@ def has_children(module):
         return False
 
 
-def get_dummy_input(dataset):
+def get_dummy_input(dataset, device=None):
+    """Generate a representative dummy (random) input for the specified dataset.
+
+    If a device is specified, then the dummy_input is moved to that device.
+    """
     if dataset == 'imagenet':
         dummy_input = torch.randn(1, 3, 224, 224)
     elif dataset == 'cifar10':
         dummy_input = torch.randn(1, 3, 32, 32)
     else:
         raise ValueError("dataset %s is not supported" % dataset)
+    if device is not None:  # explicit None test: device index 0 is falsy but valid
+        dummy_input = dummy_input.to(device)
     return dummy_input
 
 
diff --git a/tests/common.py b/tests/common.py
index 792e03343c8617f647026420f418f61092e4ddfa..728d62f01591f7ad61e199ceaf7b96f92446ccfa 100755
--- a/tests/common.py
+++ b/tests/common.py
@@ -37,13 +37,5 @@ def find_module_by_name(model, module_to_find):
     return None
 
 
-def get_dummy_input(dataset):
-    if dataset == "imagenet":
-        return torch.randn(1, 3, 224, 224).cuda()
-    elif dataset == "cifar10":
-        return torch.randn(1, 3, 32, 32).cuda()
-    raise ValueError("Trying to use an unknown dataset " + dataset)
-
-
 def almost_equal(a , b, max_diff=0.000001):
     return abs(a - b) <= max_diff
diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py
index b15badc5ffa7cc5010b9341552422344ae7c4aea..89c81e690303f839e4c0e3280369699ee913635e 100755
--- a/tests/test_model_summary.py
+++ b/tests/test_model_summary.py
@@ -42,7 +42,7 @@ def test_compute_summary():
     dataset = "cifar10"
     arch = "simplenet_cifar"
     model, _ = common.setup_test(arch, dataset, parallel=True)
-    df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset))
+    df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input(dataset, distiller.model_device(model)))
     module_macs = df_compute.loc[:, 'MACs'].to_list()
     # [conv1, conv2, fc1, fc2, fc3]
     assert module_macs == [352800, 240000, 48000, 10080, 840]
@@ -50,7 +50,7 @@ def test_compute_summary():
     dataset = "imagenet"
     arch = "mobilenet"
     model, _ = common.setup_test(arch, dataset, parallel=True)
-    df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset))
+    df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input(dataset, distiller.model_device(model)))
     module_macs = df_compute.loc[:, 'MACs'].to_list()
     expected_macs = [10838016, 3612672, 25690112, 1806336, 25690112, 3612672, 51380224, 903168,
                      25690112, 1806336, 51380224, 451584, 25690112, 903168, 51380224, 903168,
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 42a6ab7b39ef7fdc911649a5a05ee5f9128c39a2..ff78d1e1c1a7315b4e40544bc7e67c6e7bfdc900 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -277,7 +277,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
     assert bn1.bias.size(0) == cnt_nnz_channels
     assert bn1.weight.size(0) == cnt_nnz_channels
 
-    dummy_input = common.get_dummy_input(config.dataset)
+    dummy_input = distiller.get_dummy_input(config.dataset, distiller.model_device(model))
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
     run_forward_backward(model, optimizer, dummy_input)
 
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 46fedb3b6a99734b46d0bc801247e93502a1016d..55ce88b526ad121711dfaa7a1555e429cb694550 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -32,18 +32,8 @@ logger = logging.getLogger()
 logger.addHandler(fh)
 
 
-def get_input(dataset):
-    if dataset == 'imagenet':
-        return torch.randn((1, 3, 224, 224), requires_grad=False)
-    elif dataset == 'cifar10':
-        return torch.randn((1, 3, 32, 32))
-    return None
-
-
 def create_graph(dataset, arch):
-    dummy_input = get_input(dataset)
-    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
-
+    dummy_input = distiller.get_dummy_input(dataset)
     model = create_model(False, dataset, arch, parallel=False)
     assert model is not None
     return SummaryGraph(model, dummy_input)
@@ -163,7 +153,7 @@ def test_normalize_module_name():
 
 def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
     model = create_model(False, dataset, arch, parallel=dataparallel)
-    sgraph = SummaryGraph(model, get_input(dataset))
+    sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
     sgraph_layer_names = set(k for k, i, j in sgraph.named_params_layers())
     for layer_name in sgraph_layer_names:
         assert sgraph.find_op(layer_name) is not None, '{} was not found in summary graph'.format(layer_name)
@@ -202,7 +192,7 @@ def test_sg_macs():
     sg = create_graph('imagenet', 'mobilenet')
     assert sg
     model, _ = common.setup_test('mobilenet', 'imagenet', parallel=False)
-    df_compute = distiller.model_performance_summary(model, common.get_dummy_input('imagenet'))
+    df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input('imagenet', distiller.model_device(model)))
     modules_macs = df_compute.loc[:, ['Name', 'MACs']]
     for name, mod in model.named_modules():
         if isinstance(mod, (torch.nn.Conv2d, torch.nn.Linear)):
@@ -214,7 +204,7 @@ def test_sg_macs():
 def test_weights_size_attr():
     def test(dataset, arch, dataparallel:bool):
         model = create_model(False, dataset, arch, parallel=dataparallel)
-        sgraph = SummaryGraph(model, get_input(dataset))
+        sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
 
         distiller.assign_layer_fq_names(model)
         for name, mod in model.named_modules():