Commit bf1e6a0d authored by Neta Zmora

Refactoring: utils.get_dummy_input()

Remove the multiple instances of code that generates
dummy input per dataset.
parent af5c7219
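The commit consolidates the per-file dummy-input helpers into a single distiller.get_dummy_input(dataset, device=None) utility. As a minimal sketch of the call that now replaces the scattered torch.randn(...) snippets (assuming the distiller package is importable and exports get_dummy_input at the package level, as the hunks below use it):

    import distiller

    # One shared helper instead of a per-dataset torch.randn(...) in each file:
    # 'cifar10' -> (1, 3, 32, 32), 'imagenet' -> (1, 3, 224, 224)
    dummy_input = distiller.get_dummy_input('cifar10')
    assert dummy_input.shape == (1, 3, 32, 32)

Unsupported dataset names raise a ValueError instead of silently returning None, as the new implementation below shows.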
@@ -29,8 +29,8 @@ import math
 import logging
 from collections import namedtuple
 import torch
-from .policy import ScheduledTrainingPolicy
 import distiller
+from .policy import ScheduledTrainingPolicy
 from .summary_graph import SummaryGraph
 msglogger = logging.getLogger(__name__)
@@ -63,14 +63,7 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
 def create_graph(dataset, model):
-    dummy_input = None
-    if dataset == 'imagenet':
-        dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
-    elif dataset == 'cifar10':
-        dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False)
-    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
-    dummy_input = dummy_input.to(distiller.model_device(model))
+    dummy_input = distiller.get_dummy_input(dataset, distiller.model_device(model))
     return SummaryGraph(model, dummy_input)
@@ -556,13 +556,19 @@ def has_children(module):
     return False
 
-def get_dummy_input(dataset):
+def get_dummy_input(dataset, device=None):
+    """Generate a representative dummy (random) input for the specified dataset.
+
+    If a device is specified, the dummy_input is moved to that device.
+    """
     if dataset == 'imagenet':
         dummy_input = torch.randn(1, 3, 224, 224)
     elif dataset == 'cifar10':
         dummy_input = torch.randn(1, 3, 32, 32)
     else:
         raise ValueError("dataset %s is not supported" % dataset)
+    if device:
+        dummy_input = dummy_input.to(device)
     return dummy_input
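The new device argument lets callers land the dummy input on the same device as the model, which is how the thinning and pruning hunks in this commit call it. A small sketch of that pattern, assuming a CPU-only run; the one-layer model is only a stand-in and distiller.model_device() is used exactly as in the hunks:

    import torch
    import distiller

    model = torch.nn.Linear(3 * 32 * 32, 10)       # stand-in model
    device = distiller.model_device(model)          # device of the model's parameters
    dummy_input = distiller.get_dummy_input('cifar10', device)
    assert dummy_input.device == next(model.parameters()).device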
@@ -37,13 +37,5 @@ def find_module_by_name(model, module_to_find):
     return None
 
-def get_dummy_input(dataset):
-    if dataset == "imagenet":
-        return torch.randn(1, 3, 224, 224).cuda()
-    elif dataset == "cifar10":
-        return torch.randn(1, 3, 32, 32).cuda()
-    raise ValueError("Trying to use an unknown dataset " + dataset)
-
 def almost_equal(a , b, max_diff=0.000001):
     return abs(a - b) <= max_diff
@@ -42,7 +42,7 @@ def test_compute_summary():
     dataset = "cifar10"
     arch = "simplenet_cifar"
     model, _ = common.setup_test(arch, dataset, parallel=True)
-    df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset))
+    df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input(dataset))
     module_macs = df_compute.loc[:, 'MACs'].to_list()
     # [conv1, conv2, fc1, fc2, fc3]
     assert module_macs == [352800, 240000, 48000, 10080, 840]
@@ -50,7 +50,7 @@ def test_compute_summary():
     dataset = "imagenet"
     arch = "mobilenet"
     model, _ = common.setup_test(arch, dataset, parallel=True)
-    df_compute = distiller.model_performance_summary(model, common.get_dummy_input(dataset))
+    df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input(dataset))
     module_macs = df_compute.loc[:, 'MACs'].to_list()
     expected_macs = [10838016, 3612672, 25690112, 1806336, 25690112, 3612672, 51380224, 903168,
                      25690112, 1806336, 51380224, 451584, 25690112, 903168, 51380224, 903168,
@@ -277,7 +277,7 @@ def arbitrary_channel_pruning(config, channels_to_remove, is_parallel):
     assert bn1.bias.size(0) == cnt_nnz_channels
     assert bn1.weight.size(0) == cnt_nnz_channels
 
-    dummy_input = common.get_dummy_input(config.dataset)
+    dummy_input = distiller.get_dummy_input(config.dataset, distiller.model_device(model))
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.1)
     run_forward_backward(model, optimizer, dummy_input)
@@ -32,18 +32,8 @@ logger = logging.getLogger()
 logger.addHandler(fh)
 
-def get_input(dataset):
-    if dataset == 'imagenet':
-        return torch.randn((1, 3, 224, 224), requires_grad=False)
-    elif dataset == 'cifar10':
-        return torch.randn((1, 3, 32, 32))
-    return None
-
 def create_graph(dataset, arch):
-    dummy_input = get_input(dataset)
-    assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
+    dummy_input = distiller.get_dummy_input(dataset)
     model = create_model(False, dataset, arch, parallel=False)
     assert model is not None
     return SummaryGraph(model, dummy_input)
@@ -163,7 +153,7 @@ def test_normalize_module_name():
 def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
     model = create_model(False, dataset, arch, parallel=dataparallel)
-    sgraph = SummaryGraph(model, get_input(dataset))
+    sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
     sgraph_layer_names = set(k for k, i, j in sgraph.named_params_layers())
     for layer_name in sgraph_layer_names:
         assert sgraph.find_op(layer_name) is not None, '{} was not found in summary graph'.format(layer_name)
@@ -202,7 +192,7 @@ def test_sg_macs():
     sg = create_graph('imagenet', 'mobilenet')
     assert sg
     model, _ = common.setup_test('mobilenet', 'imagenet', parallel=False)
-    df_compute = distiller.model_performance_summary(model, common.get_dummy_input('imagenet'))
+    df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input('imagenet'))
     modules_macs = df_compute.loc[:, ['Name', 'MACs']]
     for name, mod in model.named_modules():
         if isinstance(mod, (torch.nn.Conv2d, torch.nn.Linear)):
@@ -214,7 +204,7 @@ def test_sg_macs():
 def test_weights_size_attr():
     def test(dataset, arch, dataparallel:bool):
         model = create_model(False, dataset, arch, parallel=dataparallel)
-        sgraph = SummaryGraph(model, get_input(dataset))
+        sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
         distiller.assign_layer_fq_names(model)
         for name, mod in model.named_modules():