From bf1e6a0d45d62e0d526cc0f0d3a378b283171ce2 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 16 May 2019 14:00:36 +0300
Subject: [PATCH] Refactoring: utils.get_dummy_input()

Consolidate the duplicated per-dataset dummy-input generation code
into a single helper, distiller.get_dummy_input(), which can
optionally move the generated tensor to a given device.
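
For reference, a minimal usage sketch of the consolidated helper
(the toy Conv2d model below is illustrative only; any nn.Module
whose device can be queried with distiller.model_device() works):

    import torch
    import distiller

    # CPU tensor shaped for CIFAR-10: (1, 3, 32, 32)
    x = distiller.get_dummy_input('cifar10')
    assert x.shape == torch.Size([1, 3, 32, 32])

    # Generate an ImageNet-shaped input on the device the model lives on,
    # e.g. before tracing it with SummaryGraph.
    model = torch.nn.Conv2d(3, 16, kernel_size=3)   # placeholder model
    x = distiller.get_dummy_input('imagenet', distiller.model_device(model))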
---
 distiller/thinning.py       | 11 ++---------
 distiller/utils.py          |  8 +++++++-
 tests/common.py             |  8 --------
 tests/test_model_summary.py |  4 ++--
 tests/test_pruning.py       |  2 +-
 tests/test_summarygraph.py  | 18 ++++--------------
 6 files changed, 16 insertions(+), 35 deletions(-)

diff --git a/distiller/thinning.py b/distiller/thinning.py
index 0ca0e24..1cd7bd9 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -29,8 +29,8 @@ import math
 import logging
 from collections import namedtuple
 import torch
-from .policy import ScheduledTrainingPolicy
 import distiller
+from .policy import ScheduledTrainingPolicy
 from .summary_graph import SummaryGraph
 msglogger = logging.getLogger(__name__)
 
@@ -63,14 +63,7 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
 
 
 def create_graph(dataset, model):
-    dummy_input = None
-    if dataset == 'imagenet':
-        dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
-    elif dataset == 'cifar10':
-        dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False)
-    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
-
-    dummy_input = dummy_input.to(distiller.model_device(model))
+    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
     return SummaryGraph(model, dummy_input)
 
 
diff --git a/distiller/utils.py b/distiller/utils.py
index 3f8b825..580fa51 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -556,13 +556,19 @@ def has_children(module):
         return False
 
 
-def get_dummy_input(dataset):
+def get_dummy_input(dataset, device=None):
+    """Generate a representative dummy (random) input for the specified dataset.
+
+    If a device is specified, the dummy_input is moved to that device.
+    """
     if dataset == 'imagenet':
         dummy_input = torch.randn(1, 3, 224, 224)
     elif dataset == 'cifar10':
         dummy_input = torch.randn(1, 3, 32, 32)
     else:
         raise ValueError("dataset %s is not supported" % dataset)
+    if device:
+        dummy_input = dummy_input.to(device)
     return dummy_input
 
 
diff --git a/tests/common.py b/tests/common.py
index 792e033..728d62f 100755
--- a/tests/common.py
+++ b/tests/common.py
@@ -37,13 +37,5 @@ def find_module_by_name(model, module_to_find):
     return None
 
 
-def get_dummy_input(dataset):
-    if dataset == "imagenet":
-        return torch.randn(1, 3, 224, 224).cuda()
-    elif dataset == "cifar10":
-        return torch.randn(1, 3, 32, 32).cuda()
-    raise ValueError("Trying to use an unknown dataset " + dataset)
-
-
 def almost_equal(a , b, max_diff=0.000001):
     return abs(a - b) <= max_diff
diff --git a/tests/test_model_summary.py b/tests/test_model_summary.py
index b15badc..89c81e6 100755
--- a/tests/test_model_summary.py
+++ b/tests/test_model_summary.py
@@ -42,7 +42,7 @@ def test_compute_summary():
     dataset = "cifar10"
     arch = "simplenet_cifar"
     model, _ = common.setup_test(arch, dataset, parallel=True)
-    df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset))
+    df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input(dataset))
     module_macs = df_compute.loc[:, 'MACs'].to_list()
     #                     [conv1,  conv2,  fc1,   fc2,   fc3]
     assert module_macs == [352800, 240000, 48000, 10080, 840]
@@ -50,7 +50,7 @@ def test_compute_summary():
     dataset = "imagenet"
     arch = "mobilenet"
     model, _ = common.setup_test(arch, dataset, parallel=True)
-    df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset))
+    df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input(dataset))
     module_macs = df_compute.loc[:, 'MACs'].to_list()
     expected_macs = [10838016, 3612672, 25690112, 1806336, 25690112, 3612672, 51380224, 903168, 
                      25690112, 1806336, 51380224, 451584, 25690112, 903168, 51380224, 903168, 
diff --git a/tests/test_pruning.py b/tests/test_pruning.py
index 42a6ab7..ff78d1e 100755
--- a/tests/test_pruning.py
+++ b/tests/test_pruning.py
@@ -277,7 +277,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
         assert bn1.bias.size(0) == cnt_nnz_channels
         assert bn1.weight.size(0) == cnt_nnz_channels
 
-    dummy_input = common.get_dummy_input(config.dataset)
+    dummy_input = distiller.get_dummy_input(config.dataset, distiller.model_device(model))
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
     run_forward_backward(model, optimizer, dummy_input)
 
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 46fedb3..55ce88b 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -32,18 +32,8 @@ logger = logging.getLogger()
 logger.addHandler(fh)
 
 
-def get_input(dataset):
-    if dataset == 'imagenet':
-        return torch.randn((1, 3, 224, 224), requires_grad=False)
-    elif dataset == 'cifar10':
-        return torch.randn((1, 3, 32, 32))
-    return None
-
-
 def create_graph(dataset, arch):
-    dummy_input = get_input(dataset)
-    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
-
+    dummy_input = distiller.get_dummy_input(dataset)
     model = create_model(False, dataset, arch, parallel=False)
     assert model is not None
     return SummaryGraph(model, dummy_input)
@@ -163,7 +153,7 @@ def test_normalize_module_name():
 
 def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
     model = create_model(False, dataset, arch, parallel=dataparallel)
-    sgraph = SummaryGraph(model, get_input(dataset))
+    sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
     sgraph_layer_names = set(k for k, i, j in sgraph.named_params_layers())
     for layer_name in sgraph_layer_names:
         assert sgraph.find_op(layer_name) is not None, '{} was not found in summary graph'.format(layer_name)
@@ -202,7 +192,7 @@ def test_sg_macs():
     sg = create_graph('imagenet', 'mobilenet')
     assert sg
     model, _ = common.setup_test('mobilenet', 'imagenet', parallel=False)
-    df_compute = distiller.model_performance_summary(model, common.get_dummy_input('imagenet'))
+    df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input('imagenet'))
     modules_macs = df_compute.loc[:, ['Name', 'MACs']]
     for name, mod in model.named_modules():
         if isinstance(mod, (torch.nn.Conv2d, torch.nn.Linear)):
@@ -214,7 +204,7 @@ def test_sg_macs():
 def test_weights_size_attr():
     def test(dataset, arch, dataparallel:bool):
         model = create_model(False, dataset, arch, parallel=dataparallel)
-        sgraph = SummaryGraph(model, get_input(dataset))
+        sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
 
         distiller.assign_layer_fq_names(model)
         for name, mod in model.named_modules():
-- 
GitLab