Commit 8cf7900d authored by Guy Jacob
SummaryGraph: Changes in adjacency_map and predecessors/successors

* Add op name and type to adjacency map
* Make module name de-normalization optional in the predecessors and
  successors functions (including the _f variants)
* More tests
parent 8e14ef0b
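Before the diff, a quick usage sketch of what changes for callers. The model and op
names follow the tests further down; the exact printed values are illustrative, not
taken verbatim from a run:

    import distiller
    from distiller.models import create_model

    model = create_model(False, 'cifar10', 'resnet20_cifar')
    sg = distiller.SummaryGraph(model, distiller.get_dummy_input('cifar10'))

    # De-normalization of module names is now optional in the traversal functions:
    op = sg.find_op('layer1.0.conv2')
    preds = sg.predecessors(op, 2, denorm_names=False)

    # Adjacency-map entries now carry the op's name *and* type:
    entry = sg.adjacency_map()['layer1.0.conv2']
    print(entry.op_meta.name, entry.op_meta.type)   # e.g. layer1.0.conv2 Conv
    print([p.name for p in entry.predecessors])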
@@ -73,7 +73,7 @@ class SummaryGraph(object):
         model_clone = distiller.make_non_parallel_copy(model)
         with torch.onnx.set_training(model_clone, False):
-            device = next(model_clone.parameters()).device
+            device = distiller.model_device(model_clone)
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
             trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
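The device lookup now goes through distiller.model_device instead of peeking at the
first parameter directly. The helper's implementation isn't shown in this diff; a
minimal sketch of what such a function plausibly does (the parameter-less fallback
is an assumption, not taken from this commit):

    import torch
    import torch.nn as nn

    def model_device_sketch(model: nn.Module) -> torch.device:
        # Device of the model's first parameter; fall back to CPU for
        # models that have no parameters at all (assumed behavior).
        try:
            return next(model.parameters()).device
        except StopIteration:
            return torch.device('cpu')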
@@ -81,6 +81,9 @@ class SummaryGraph(object):
             # Gemm nodes get the scope name of the last non-Gemm node that came before them. This can make
             # it impossible, in some cases, to derive the connectivity of the model using the original
             # module names. So we save the scope names for these nodes from the un-optimized trace.
+            #
+            # Note that if the node prior to the Gemm node isn't the result of a dedicated module call,
+            # then this issue doesn't occur. For simplicity we just track all Gemms.
             aten_addmm_nodes_scope_names = [n.scopeName() for n in trace.graph().nodes() if n.kind() == 'aten::addmm']
             onnx_gemm_count = 0
@@ -110,7 +113,6 @@ class SummaryGraph(object):
             # Convert the graph node's scope name to a PyTorch module name
             module_name = onnx_name_2_pytorch_name(new_op['orig-name'])
-            new_op['module-name'] = module_name
             if len(module_name) == 0:
                 # Special case where the module name is an empty string - this happens
                 # when the op is called from the "top-level" of the model
@@ -118,6 +120,11 @@ class SummaryGraph(object):
             else:
                 new_op['name'] = module_name

+            # Save the calling module name in the op dict. Denormalize it so it can
+            # be directly matched with the actual model
+            module_name = distiller.denormalize_module_name(self._src_model, module_name)
+            new_op['module-name'] = module_name
+
             # The node's scope name in the graph corresponds to the module from which the op was called.
             # This means that when ops are invoked from the same module via functional calls or direct
             # operations on tensors, these ops will have the SAME MODEL NAME associated with them.
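For context, de-normalization maps the graph's internal (normalized) names back to
the source model's own module names; with a DataParallel-wrapped model this re-adds
the 'module.' prefix. A sketch of the convention, inferred from the tests below:

    # model here is assumed to be wrapped in nn.DataParallel
    distiller.normalize_module_name('module.layer1.0.conv1')    # -> 'layer1.0.conv1'
    distiller.denormalize_module_name(model, 'layer1.0.conv1')  # -> 'module.layer1.0.conv1'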
@@ -260,7 +267,8 @@ class SummaryGraph(object):
                 ofm_vol = self.param_volume(conv_out)
                 try:
                     # MACs = volume(OFM) * (#IFM * K^2) / #Groups
-                    op['attrs']['MACs'] = int(ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1] / groups)
+                    op['attrs']['MACs'] = int(
+                        ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1] / groups)
                 except IndexError:
                     # Todo: change the method for calculating MACs
                     msglogger.error("An input to a Convolutional layer is missing shape information "
@@ -318,7 +326,7 @@ class SummaryGraph(object):
     def find_param(self, data_name):
         return self.params.get(data_name, None)

-    def predecessors(self, node, depth, done_list=None):
+    def predecessors(self, node, depth, done_list=None, denorm_names=True):
         """Returns a list of <op>'s predecessors"""
         if done_list is None:
             done_list = []
@@ -333,11 +341,13 @@ class SummaryGraph(object):
         else:
             ret = []
             for predecessor in preds:
-                ret += self.predecessors(predecessor, depth-1, done_list)
+                ret += self.predecessors(predecessor, depth - 1, done_list, denorm_names)

-        return [distiller.denormalize_module_name(self._src_model, x) for x in ret]
+        if denorm_names:
+            ret = [distiller.denormalize_module_name(self._src_model, x) for x in ret]
+        return ret

-    def predecessors_f(self, node_name, predecessors_types, done_list=None, logging=None):
+    def predecessors_f(self, node_name, predecessors_types, done_list=None, logging=None, denorm_names=True):
         """Returns a list of <op>'s predecessors, if they match the <predecessors_types> criteria.
         """
         node_name = distiller.normalize_module_name(node_name)
@@ -362,7 +372,7 @@ class SummaryGraph(object):
             # We check if we found the type of node we're looking for,
             # and that this is not the first node in our search.
             if node['type'] in predecessors_types and len(done_list) > 1:
-                return [distiller.denormalize_module_name(self._src_model, node_name)]
+                return [distiller.denormalize_module_name(self._src_model, node_name) if denorm_names else node_name]
             # This is an operation node
             preds = [edge.src for edge in self.edges if (edge.dst == node_name and
@@ -373,11 +383,11 @@ class SummaryGraph(object):
                                                          edge.src not in done_list)]
         ret = []
         for predecessor in preds:
-            ret += self.predecessors_f(predecessor, predecessors_types, done_list, logging)
+            ret += self.predecessors_f(predecessor, predecessors_types, done_list, logging, denorm_names)

-        return [distiller.denormalize_module_name(self._src_model, node) for node in ret]
+        return ret
-    def successors(self, node, depth, done_list=None):
+    def successors(self, node, depth, done_list=None, denorm_names=True):
         """Returns a list of <op>'s successors"""
         if done_list is None:
             done_list = []
@@ -392,11 +402,13 @@ class SummaryGraph(object):
         else:
             ret = []
             for successor in succs:
-                ret += self.successors(successor, depth-1, done_list)
+                ret += self.successors(successor, depth - 1, done_list, denorm_names)

-        return [distiller.denormalize_module_name(self._src_model, x) for x in ret]
+        if denorm_names:
+            ret = [distiller.denormalize_module_name(self._src_model, x) for x in ret]
+        return ret

-    def successors_f(self, node_name, successors_types, done_list=None, logging=None):
+    def successors_f(self, node_name, successors_types, done_list=None, logging=None, denorm_names=True):
         """Returns a list of <op>'s successors, if they match the <successors_types> criteria.
         Traverse the graph, starting at node <node_name>, and search for successor
...@@ -412,7 +424,7 @@ class SummaryGraph(object): ...@@ -412,7 +424,7 @@ class SummaryGraph(object):
node_is_an_op = False node_is_an_op = False
node = self.find_param(node_name) node = self.find_param(node_name)
if node is None: if node is None:
#raise ValueError("something went wrong") msglogger.warning("successors_f: Could not find node {}".format(node_name))
return [] return []
if done_list is None: if done_list is None:
...@@ -427,7 +439,7 @@ class SummaryGraph(object): ...@@ -427,7 +439,7 @@ class SummaryGraph(object):
# We check if we found the type of node we're looking for, # We check if we found the type of node we're looking for,
# and that this is not the first node in our search. # and that this is not the first node in our search.
if node['type'] in successors_types and len(done_list) > 1: if node['type'] in successors_types and len(done_list) > 1:
return [distiller.denormalize_module_name(self._src_model, node_name)] return [distiller.denormalize_module_name(self._src_model, node_name) if denorm_names else node_name]
# This is an operation node # This is an operation node
succs = [edge.dst for edge in self.edges if (edge.src == node_name and succs = [edge.dst for edge in self.edges if (edge.src == node_name and
...@@ -438,9 +450,9 @@ class SummaryGraph(object): ...@@ -438,9 +450,9 @@ class SummaryGraph(object):
edge.dst not in done_list)] edge.dst not in done_list)]
ret = [] ret = []
for successor in succs: for successor in succs:
ret += self.successors_f(successor, successors_types, done_list, logging) ret += self.successors_f(successor, successors_types, done_list, logging, denorm_names)
return [distiller.denormalize_module_name(self._src_model, node) for node in ret] return ret
def named_params_layers(self): def named_params_layers(self):
for param_name, param in self._src_model.named_parameters(): for param_name, param in self._src_model.named_parameters():
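With all four traversal functions updated, the effect of denorm_names is most visible
on DataParallel models: the graph stores normalized names (no 'module.' prefix), and
de-normalization maps them back to the wrapped model's names. A sketch whose expected
results are taken from test_layer_search below (logging as imported in the test file):

    model = create_model(False, 'cifar10', 'resnet20_cifar', parallel=True)
    sg = SummaryGraph(model, distiller.get_dummy_input('cifar10'))

    sg.successors_f('layer1.0.conv1', 'Conv', [], logging)
    # ['module.layer1.0.conv2'] - names match the parallel model's modules

    sg.successors_f('layer1.0.conv1', 'Conv', [], logging, denorm_names=False)
    # ['layer1.0.conv2'] - normalized, prefix-free names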
@@ -466,30 +478,57 @@ class SummaryGraph(object):
         functional calls, such as "F.relu()", and tensor operations, such as "t3 = t1 + t2".
         """
         adj_map = OrderedDict()
+        named_modules = OrderedDict(self._src_model.named_modules())
         for op_name, op in self.ops.items():
             def dedicated_module_check(n):
-                module_name = self.ops[distiller.normalize_module_name(n)]['module-name']
-                return len(self.module_ops_map[module_name]) == 1 or not dedicated_modules_only
+                if not dedicated_modules_only:
+                    return True
+                module_name = self.ops[n]['module-name']
+                module = named_modules[module_name]
+                return len(self.module_ops_map[module_name]) == 1 and not distiller.has_children(module)
+
+            def op_meta(n):
+                return OpSimpleMetadata(distiller.denormalize_module_name(self._src_model, n), self.ops[n]['type'])

             if not dedicated_module_check(op_name):
                 continue

-            entry = AdjacentsEntry()
+            entry = AdjacentsEntry(op_meta(op_name))
             # Find the immediate preceding and succeeding modules. Depth of 1 gets us the
             # input and output tensors, depth of 2 gets the actual modules
-            entry.predecessors = [n for n in self.predecessors(op, 2) if dedicated_module_check(n)]
-            entry.successors = [n for n in self.successors(op, 2) if dedicated_module_check(n)]
+            entry.predecessors = [op_meta(n) for n in self.predecessors(op, 2, denorm_names=False)
+                                  if dedicated_module_check(n)]
+            entry.successors = [op_meta(n) for n in self.successors(op, 2, denorm_names=False)
+                                if dedicated_module_check(n)]

-            adj_map[distiller.denormalize_module_name(self._src_model, op_name)] = entry
+            adj_map[entry.op_meta.name] = entry

         return adj_map


+class OpSimpleMetadata(object):
+    def __init__(self, name, type):
+        self.name = name
+        self.type = type
+
+    def __repr__(self):
+        return "Op('{}' | {})".format(self.name, self.type)
+
+    def __eq__(self, other):
+        return self.name == other.name and self.type == other.type
+
+
 class AdjacentsEntry(object):
-    def __init__(self):
+    def __init__(self, op_meta):
+        self.op_meta = op_meta
         self.predecessors = []
         self.successors = []

     def __repr__(self):
-        return 'Predecessors: {0} ; Successors: {1}'.format(self.predecessors, self.successors)
+        return 'OP: {0} ; PREDECESSORS: {1} ; SUCCESSORS: {2}'.format(self.op_meta, self.predecessors, self.successors)
+
+    def __eq__(self, other):
+        return self.op_meta == other.op_meta and \
+               self.predecessors == other.predecessors and \
+               self.successors == other.successors
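The reworked adjacency_map is exercised in detail by test_adjacency_map in the test
file below; in short, keys are de-normalized op names and each value pairs the op's
own metadata with its neighbors' metadata. For the Conv-BN-ReLU TestModel defined in
that test (non-parallel), the 'bn' entry looks like:

    adj_map = sg.adjacency_map(dedicated_modules_only=False)
    entry = adj_map['bn']
    entry.op_meta        # Op('bn' | BatchNormalization)
    entry.predecessors   # [Op('conv' | Conv)]
    entry.successors     # [Op('relu' | Relu)]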
@@ -16,6 +16,7 @@

 import logging
 import torch
+import torch.nn as nn
 import pytest
 import distiller
 from distiller.models import ALL_MODEL_NAMES, create_model
@@ -23,7 +24,7 @@ from distiller.apputils import *
 from distiller import normalize_module_name, denormalize_module_name, \
     SummaryGraph, onnx_name_2_pytorch_name
 from distiller.model_summaries import connectivity_summary, connectivity_summary_verbose
+from distiller.summary_graph import AdjacentsEntry, OpSimpleMetadata

 # Logging configuration
 logging.basicConfig(level=logging.DEBUG)
@@ -32,9 +33,9 @@ logger = logging.getLogger()
 logger.addHandler(fh)

-def create_graph(dataset, arch):
+def create_graph(dataset, arch, parallel=False):
     dummy_input = distiller.get_dummy_input(dataset)
-    model = create_model(False, dataset, arch, parallel=False)
+    model = create_model(False, dataset, arch, parallel)
     assert model is not None
     return SummaryGraph(model, dummy_input)
@@ -44,10 +45,26 @@ def test_graph():
     assert g is not None

-def test_connectivity():
-    g = create_graph('cifar10', 'resnet20_cifar')
+@pytest.fixture(params=[False, True], ids=['sequential', 'parallel'])
+def parallel(request):
+    return request.param
+
+
+@pytest.fixture(params=[True, False], ids=['denorm_name', 'norm_name'])
+def denorm_names(request):
+    return request.param
+
+
+def prefix_strs(str_list, prefix):
+    return [prefix + s for s in str_list]
+
+
+def test_connectivity(parallel, denorm_names):
+    g = create_graph('cifar10', 'resnet20_cifar', parallel)
     assert g is not None
+    prefix = 'module.' if parallel and denorm_names else ''

     op_names = [op['name'] for op in g.ops.values()]
     assert len(op_names) == 80
@@ -55,55 +72,57 @@ def test_connectivity():
     assert edges[0].src == '0' and edges[0].dst == 'conv1'

     # Test two sequential calls to predecessors (this was a bug once)
-    preds = g.predecessors(g.find_op('bn1'), 1)
-    preds = g.predecessors(g.find_op('bn1'), 1)
+    preds = g.predecessors(g.find_op('bn1'), 1, denorm_names=denorm_names)
+    preds = g.predecessors(g.find_op('bn1'), 1, denorm_names=denorm_names)
     assert preds == ['129', '2', '3', '4', '5']

     # Test successors
-    succs = g.successors(g.find_op('bn1'), 2)
-    assert succs == ['relu']
+    succs = g.successors(g.find_op('bn1'), 2, denorm_names=denorm_names)
+    assert succs == prefix_strs(['relu'], prefix)

     op = g.find_op('layer1.0.relu2')
     assert op is not None
-    succs = g.successors(op, 4)
-    assert succs == ['layer1.1.bn1', 'layer1.1.relu2']
+    succs = g.successors(op, 4, denorm_names=denorm_names)
+    assert succs == prefix_strs(['layer1.1.bn1', 'layer1.1.relu2'], prefix)

-    preds = g.predecessors(g.find_op('bn1'), 10)
+    preds = g.predecessors(g.find_op('bn1'), 10, denorm_names=denorm_names)
     assert preds == []
-    preds = g.predecessors(g.find_op('bn1'), 3)
+    preds = g.predecessors(g.find_op('bn1'), 3, denorm_names=denorm_names)
     assert preds == ['0', '1']


-def test_layer_search():
-    g = create_graph('cifar10', 'resnet20_cifar')
+def test_layer_search(parallel, denorm_names):
+    g = create_graph('cifar10', 'resnet20_cifar', parallel)
     assert g is not None
+    prefix = 'module.' if parallel and denorm_names else ''

     op = g.find_op('layer1.0.conv1')
     assert op is not None

-    succs = g.successors_f('layer1.0.conv1', 'Conv', [], logging)
-    assert ['layer1.0.conv2'] == succs
+    succs = g.successors_f('layer1.0.conv1', 'Conv', [], logging, denorm_names=denorm_names)
+    assert succs == prefix_strs(['layer1.0.conv2'], prefix)

-    succs = g.successors_f('relu', 'Conv', [], logging)
-    assert succs == ['layer1.0.conv1', 'layer1.1.conv1', 'layer1.2.conv1', 'layer2.0.conv1', 'layer2.0.downsample.0']
+    succs = g.successors_f('relu', 'Conv', [], logging, denorm_names=denorm_names)
+    assert succs == prefix_strs(['layer1.0.conv1', 'layer1.1.conv1', 'layer1.2.conv1', 'layer2.0.conv1',
+                                 'layer2.0.downsample.0'], prefix)

-    succs = g.successors_f('relu', 'Gemm', [], logging)
-    assert succs == ['fc']
+    succs = g.successors_f('relu', 'Gemm', [], logging, denorm_names=denorm_names)
+    assert succs == prefix_strs(['fc'], prefix)

-    succs = g.successors_f('layer3.2', 'Conv', [], logging)
+    succs = g.successors_f('layer3.2', 'Conv', [], logging, denorm_names=denorm_names)
     assert succs == []
-    #logging.debug(succs)

-    preds = g.predecessors_f('conv1', 'Conv', [], logging)
+    preds = g.predecessors_f('conv1', 'Conv', [], logging, denorm_names=denorm_names)
     assert preds == []

-    preds = g.predecessors_f('layer1.0.conv2', 'Conv', [], logging)
-    assert preds == ['layer1.0.conv1']
+    preds = g.predecessors_f('layer1.0.conv2', 'Conv', [], logging, denorm_names=denorm_names)
+    assert preds == prefix_strs(['layer1.0.conv1'], prefix)

-    preds = g.predecessors_f('layer1.0.conv1', 'Conv', [], logging)
-    assert preds == ['conv1']
+    preds = g.predecessors_f('layer1.0.conv1', 'Conv', [], logging, denorm_names=denorm_names)
+    assert preds == prefix_strs(['conv1'], prefix)

-    preds = g.predecessors_f('layer1.1.conv1', 'Conv', [], logging)
-    assert preds == ['layer1.0.conv2', 'conv1']
+    preds = g.predecessors_f('layer1.1.conv1', 'Conv', [], logging, denorm_names=denorm_names)
+    assert preds == prefix_strs(['layer1.0.conv2', 'conv1'], prefix)
 def test_vgg():
@@ -153,22 +172,18 @@ def test_normalize_module_name():
     name_test('imagenet', 'alexnet')

-def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
-    model = create_model(False, dataset, arch, parallel=dataparallel)
+@pytest.mark.parametrize('dataset, arch', [('imagenet', 'vgg19'),
+                                           ('cifar10', 'resnet20_cifar'),
+                                           ('imagenet', 'alexnet'),
+                                           ('imagenet', 'resnext101_32x4d')])
+def test_named_params_layers(dataset, arch, parallel):
+    model = create_model(False, dataset, arch, parallel=parallel)
     sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))

     sgraph_layer_names = set(k for k, i, j in sgraph.named_params_layers())

     for layer_name in sgraph_layer_names:
         assert sgraph.find_op(layer_name) is not None, '{} was not found in summary graph'.format(layer_name)

-def test_named_params_layers():
-    for dataParallelModel in (True, False):
-        named_params_layers_test_aux('imagenet', 'vgg19', dataParallelModel)
-        named_params_layers_test_aux('cifar10', 'resnet20_cifar', dataParallelModel)
-        named_params_layers_test_aux('imagenet', 'alexnet', dataParallelModel)
-        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', dataParallelModel)

 def test_onnx_name_2_pytorch_name():
     assert onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu]") == "layer3.0.relu"
     assert onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]') == "features.34"
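The assertions above pin down the scope-name conversion: the PyTorch module path is
recovered from the bracketed instance names in the trace scope string. A minimal
re-implementation sketch that satisfies both assertions (not the actual distiller
code):

    import re

    def onnx_name_2_pytorch_name_sketch(onnx_name):
        # 'ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu]' -> 'layer3.0.relu'
        return '.'.join(re.findall(r'\[(.*?)\]', onnx_name))

    assert onnx_name_2_pytorch_name_sketch(
        'VGG/[features]/Sequential/Conv2d[34]') == 'features.34'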
@@ -196,28 +211,145 @@ def test_sg_macs():
     df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input('imagenet'))
     modules_macs = df_compute.loc[:, ['Name', 'MACs']]
     for name, mod in model.named_modules():
-        if isinstance(mod, (torch.nn.Conv2d, torch.nn.Linear)):
+        if isinstance(mod, (nn.Conv2d, nn.Linear)):
             summary_macs = int(modules_macs.loc[modules_macs.Name == name].MACs)
             sg_macs = sg.find_op(name)['attrs']['MACs']
             assert summary_macs == sg_macs

-def test_weights_size_attr():
-    def test(dataset, arch, dataparallel:bool):
-        model = create_model(False, dataset, arch, parallel=dataparallel)
-        sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
-
-        distiller.assign_layer_fq_names(model)
-        for name, mod in model.named_modules():
-            if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear):
-                op = sgraph.find_op(name)
-                assert op is not None
-                assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
-
-    for data_parallel in (True, False):
-        test('cifar10', 'resnet20_cifar', data_parallel)
-        test('imagenet', 'alexnet', data_parallel)
-        test('imagenet', 'resnext101_32x4d', data_parallel)
+@pytest.mark.parametrize('dataset, arch', [('cifar10', 'resnet20_cifar'),
+                                           ('imagenet', 'alexnet'),
+                                           ('imagenet', 'resnext101_32x4d')])
+def test_weights_size_attr(dataset, arch, parallel):
+    model = create_model(False, dataset, arch, parallel=parallel)
+    sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
+
+    distiller.assign_layer_fq_names(model)
+    for name, mod in model.named_modules():
+        if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
+            op = sgraph.find_op(name)
+            assert op is not None
+            assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
+
+
+def test_merge_pad_avgpool():
+    class ModelWithAvgPool(nn.Module):
+        def __init__(self):
+            super(ModelWithAvgPool, self).__init__()
+            self.conv = nn.Conv2d(3, 10, 5)
+            self.avgpool = nn.AvgPool2d(2)
+
+        def forward(self, input):
+            return self.avgpool(self.conv(input))
+
+    m = ModelWithAvgPool()
+    sg = SummaryGraph(m, distiller.get_dummy_input(input_shape=(1, 3, 50, 50)))
+
+    avgpool_ops = [op_name for op_name in sg.ops if 'avgpool' in op_name]
+    assert len(avgpool_ops) == 1
+    assert sg.ops[avgpool_ops[0]]['name'] == 'avgpool'
+    assert sg.ops[avgpool_ops[0]]['type'] == 'AveragePool'
+
+
+def test_gemm_nodes_scope_names():
+    class ModelWithGemms(nn.Module):
+        def __init__(self):
+            super(ModelWithGemms, self).__init__()
+            self.drop1 = nn.Dropout()
+            self.fc1 = nn.Linear(100, 50)
+            self.relu1 = nn.ReLU(inplace=True)
+            self.drop2 = nn.Dropout()
+            self.fc2 = nn.Linear(50, 25)
+            self.relu2 = nn.ReLU(inplace=True)
+            self.fc3 = nn.Linear(25, 1)
+
+        def forward(self, x):
+            # Isn't this pretty...
+            return self.fc3(self.relu2(self.fc2(self.drop2(self.relu1(self.fc1(self.drop1(x)))))))
+
+    m = ModelWithGemms()
+    sg = SummaryGraph(m, distiller.get_dummy_input(input_shape=(1, 100)))
+
+    # For the model above we expect the ops to be named (in order):
+    #   'drop1', 'fc1', 'relu1', 'drop2', 'fc2', 'relu2', 'fc3'
+    # But without our workaround in place, they'll be named:
+    #   'drop1', 'drop1__1', 'relu1', 'drop2', 'drop2__1', 'relu2', 'relu2__1'
+    # (that is - each FC node gets the name of the node before it)
+    names, types = zip(*[(op_name, op['type']) for op_name, op in sg.ops.items()])
+    assert names == ('drop1', 'fc1', 'relu1', 'drop2', 'fc2', 'relu2', 'fc3')
+    assert types == ('Dropout', 'Gemm', 'Relu', 'Dropout', 'Gemm', 'Relu', 'Gemm')
+
+
+@pytest.fixture(params=[False, True], ids=['dedicated_modules_off', 'dedicated_modules_on'])
+def dedicated_modules(request):
+    return request.param
+
+
+def test_adjacency_map(parallel, dedicated_modules):
+    class TestModel(nn.Module):
+        def __init__(self):
+            super(TestModel, self).__init__()
+            self.conv = nn.Conv2d(3, 10, 5)
+            self.bn = nn.BatchNorm2d(10)
+            self.relu = nn.ReLU()
+
+        def forward(self, x):
+            res = self.conv(x)
+            y = self.bn(res)
+            y = self.relu(y)
+            return y + res
+
+    def check_adj_entry(actual, expected):
+        assert actual.op_meta == expected.op_meta
+        assert actual.predecessors == expected.predecessors
+        assert actual.successors == expected.successors
+
+    prefix = 'module.' if parallel else ''
+
+    m = TestModel()
+    if parallel:
+        m = nn.DataParallel(m)
+    sg = SummaryGraph(m, distiller.get_dummy_input(input_shape=(1, 3, 10, 10)))
+    adj_map = sg.adjacency_map(dedicated_modules_only=dedicated_modules)
+
+    if dedicated_modules:
+        assert len(adj_map) == 3
+    else:
+        assert len(adj_map) == 4
+
+    conv_op_meta = OpSimpleMetadata(prefix + 'conv', 'Conv')
+    bn_op_meta = OpSimpleMetadata(prefix + 'bn', 'BatchNormalization')
+    relu_op_meta = OpSimpleMetadata(prefix + 'relu', 'Relu')
+    add_op_meta = OpSimpleMetadata('top_level_op', 'Add')
+
+    name = conv_op_meta.name
+    assert name in adj_map
+    expected = AdjacentsEntry(conv_op_meta)
+    expected.successors = [bn_op_meta] if dedicated_modules else [bn_op_meta, add_op_meta]
+    check_adj_entry(adj_map[name], expected)
+
+    name = bn_op_meta.name
+    assert name in adj_map
+    expected = AdjacentsEntry(bn_op_meta)
+    expected.predecessors = [conv_op_meta]
+    expected.successors = [relu_op_meta]
+    check_adj_entry(adj_map[name], expected)
+
+    name = relu_op_meta.name
+    assert name in adj_map
+    expected = AdjacentsEntry(relu_op_meta)
+    expected.predecessors = [bn_op_meta]
+    expected.successors = [] if dedicated_modules else [add_op_meta]
+    check_adj_entry(adj_map[name], expected)
+
+    name = add_op_meta.name
+    if dedicated_modules:
+        assert name not in adj_map
+    else:
+        assert name in adj_map
+        expected = AdjacentsEntry(add_op_meta)
+        expected.predecessors = [relu_op_meta, conv_op_meta]
+        check_adj_entry(adj_map[name], expected)

 if __name__ == '__main__':
...