diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index a88ff2c812fa41ece4499637bb3b7aebc15a4526..f8761526650f6b5eec06405ca6a5820fd46d4c5b 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -73,7 +73,7 @@ class SummaryGraph(object):
         model_clone = distiller.make_non_parallel_copy(model)
         with torch.onnx.set_training(model_clone, False):
 
-            device = next(model_clone.parameters()).device
+            device = distiller.model_device(model_clone)
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
             trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
 
@@ -81,6 +81,9 @@
             # Gemm nodes get the scope name of the last non-Gemm node that came before them. This can make
             # it impossible, in some cases, to derive the connectivity of the model using the original
             # module names. So we save the scope names for these nodes from the un-optimized trace.
+            #
+            # Note that if the node prior to the Gemm node isn't the result of a dedicated module call,
+            # then this issue doesn't occur. For simplicity we just track all Gemms.
             aten_addmm_nodes_scope_names = [n.scopeName() for n in trace.graph().nodes() if n.kind() == 'aten::addmm']
             onnx_gemm_count = 0
 
@@ -110,7 +113,6 @@ class SummaryGraph(object):
 
                 # Convert the graph node's scope name to a PyTorch module name
                 module_name = onnx_name_2_pytorch_name(new_op['orig-name'])
-                new_op['module-name'] = module_name
                 if len(module_name) == 0:
                     # Special case where the module name is an empty string - this happens
                     # when the op is called from the "top-level" of the model
@@ -118,6 +120,11 @@ class SummaryGraph(object):
                 else:
                     new_op['name'] = module_name
 
+                # Save the calling module name in the op dict. Denormalize it so it can
+                # be directly matched with the actual model
+                module_name = distiller.denormalize_module_name(self._src_model, module_name)
+                new_op['module-name'] = module_name
+
                 # The node's scope name in the graph corresponds to the module from which the op was called.
                 # This means that when ops are invoked from the same module via functional calls or direct
                 # operations on tensors, these ops will have the SAME MODEL NAME associated with them.
@@ -260,7 +267,8 @@ class SummaryGraph(object):
             ofm_vol = self.param_volume(conv_out)
             try:
                 # MACs = volume(OFM) * (#IFM * K^2) / #Groups
-                op['attrs']['MACs'] = int(ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1] / groups)
+                op['attrs']['MACs'] = int(
+                    ofm_vol * SummaryGraph.volume(conv_w) * self.params[conv_in]['shape'][1] / groups)
             except IndexError:
                 # Todo: change the method for calculating MACs
                 msglogger.error("An input to a Convolutional layer is missing shape information "
@@ -318,7 +326,7 @@ class SummaryGraph(object):
     def find_param(self, data_name):
         return self.params.get(data_name, None)
 
-    def predecessors(self, node, depth, done_list=None):
+    def predecessors(self, node, depth, done_list=None, denorm_names=True):
         """Returns a list of <op>'s predecessors"""
         if done_list is None:
             done_list = []
@@ -333,11 +341,13 @@ class SummaryGraph(object):
         else:
             ret = []
             for predecessor in preds:
-                ret += self.predecessors(predecessor, depth-1, done_list)
+                ret += self.predecessors(predecessor, depth - 1, done_list, denorm_names)
 
-        return [distiller.denormalize_module_name(self._src_model, x) for x in ret]
+        if denorm_names:
+            ret = [distiller.denormalize_module_name(self._src_model, x) for x in ret]
+        return ret
 
-    def predecessors_f(self, node_name, predecessors_types, done_list=None, logging=None):
+    def predecessors_f(self, node_name, predecessors_types, done_list=None, logging=None, denorm_names=True):
         """Returns a list of <op>'s predecessors, if they match the <predecessors_types> criteria.
         """
         node_name = distiller.normalize_module_name(node_name)
@@ -362,7 +372,7 @@ class SummaryGraph(object):
             # We check if we found the type of node we're looking for,
            # and that this is not the first node in our search.
             if node['type'] in predecessors_types and len(done_list) > 1:
-                return [distiller.denormalize_module_name(self._src_model, node_name)]
+                return [distiller.denormalize_module_name(self._src_model, node_name) if denorm_names else node_name]
 
             # This is an operation node
             preds = [edge.src for edge in self.edges if (edge.dst == node_name and
@@ -373,11 +383,11 @@ class SummaryGraph(object):
                                                          edge.src not in done_list)]
         ret = []
         for predecessor in preds:
-            ret += self.predecessors_f(predecessor, predecessors_types, done_list, logging)
+            ret += self.predecessors_f(predecessor, predecessors_types, done_list, logging, denorm_names)
 
-        return [distiller.denormalize_module_name(self._src_model, node) for node in ret]
+        return ret
 
-    def successors(self, node, depth, done_list=None):
+    def successors(self, node, depth, done_list=None, denorm_names=True):
         """Returns a list of <op>'s successors"""
         if done_list is None:
             done_list = []
@@ -392,11 +402,13 @@ class SummaryGraph(object):
         else:
             ret = []
             for successor in succs:
-                ret += self.successors(successor, depth-1, done_list)
+                ret += self.successors(successor, depth - 1, done_list, denorm_names)
 
-        return [distiller.denormalize_module_name(self._src_model, x) for x in ret]
+        if denorm_names:
+            ret = [distiller.denormalize_module_name(self._src_model, x) for x in ret]
+        return ret
 
-    def successors_f(self, node_name, successors_types, done_list=None, logging=None):
+    def successors_f(self, node_name, successors_types, done_list=None, logging=None, denorm_names=True):
         """Returns a list of <op>'s successors, if they match the <successors_types> criteria.
 
         Traverse the graph, starting at node <node_name>, and search for successor
@@ -412,7 +424,7 @@ class SummaryGraph(object):
             node_is_an_op = False
             node = self.find_param(node_name)
             if node is None:
-                #raise ValueError("something went wrong")
+                msglogger.warning("successors_f: Could not find node {}".format(node_name))
                 return []
 
             if done_list is None:
@@ -427,7 +439,7 @@ class SummaryGraph(object):
             # We check if we found the type of node we're looking for,
             # and that this is not the first node in our search.
             if node['type'] in successors_types and len(done_list) > 1:
-                return [distiller.denormalize_module_name(self._src_model, node_name)]
+                return [distiller.denormalize_module_name(self._src_model, node_name) if denorm_names else node_name]
 
             # This is an operation node
             succs = [edge.dst for edge in self.edges if (edge.src == node_name and
@@ -438,9 +450,9 @@ class SummaryGraph(object):
                                                          edge.dst not in done_list)]
         ret = []
         for successor in succs:
-            ret += self.successors_f(successor, successors_types, done_list, logging)
+            ret += self.successors_f(successor, successors_types, done_list, logging, denorm_names)
 
-        return [distiller.denormalize_module_name(self._src_model, node) for node in ret]
+        return ret
 
     def named_params_layers(self):
         for param_name, param in self._src_model.named_parameters():
@@ -466,30 +478,57 @@ class SummaryGraph(object):
         functional calls, such as "F.relu()", and tensor operations, such as "t3 = t1 + t2".
         """
         adj_map = OrderedDict()
+        named_modules = OrderedDict(self._src_model.named_modules())
         for op_name, op in self.ops.items():
             def dedicated_module_check(n):
-                module_name = self.ops[distiller.normalize_module_name(n)]['module-name']
-                return len(self.module_ops_map[module_name]) == 1 or not dedicated_modules_only
+                if not dedicated_modules_only:
+                    return True
+                module_name = self.ops[n]['module-name']
+                module = named_modules[module_name]
+                return len(self.module_ops_map[module_name]) == 1 and not distiller.has_children(module)
+
+            def op_meta(n):
+                return OpSimpleMetadata(distiller.denormalize_module_name(self._src_model, n), self.ops[n]['type'])
 
             if not dedicated_module_check(op_name):
                 continue
 
-            entry = AdjacentsEntry()
+            entry = AdjacentsEntry(op_meta(op_name))
             # Find the immediate preceding and succeeding modules. Depth of 1 gets us the
             # input and output tensors, depth of 2 gets the actual modules
-            entry.predecessors = [n for n in self.predecessors(op, 2) if dedicated_module_check(n)]
-            entry.successors = [n for n in self.successors(op, 2) if dedicated_module_check(n)]
+            entry.predecessors = [op_meta(n) for n in self.predecessors(op, 2, denorm_names=False)
+                                  if dedicated_module_check(n)]
+            entry.successors = [op_meta(n) for n in self.successors(op, 2, denorm_names=False)
+                                if dedicated_module_check(n)]
 
-            adj_map[distiller.denormalize_module_name(self._src_model, op_name)] = entry
+            adj_map[entry.op_meta.name] = entry
 
         return adj_map
 
 
+class OpSimpleMetadata(object):
+    def __init__(self, name, type):
+        self.name = name
+        self.type = type
+
+    def __repr__(self):
+        return "Op('{}' | {})".format(self.name, self.type)
+
+    def __eq__(self, other):
+        return self.name == other.name and self.type == other.type
+
+
 class AdjacentsEntry(object):
-    def __init__(self):
+    def __init__(self, op_meta):
+        self.op_meta = op_meta
         self.predecessors = []
         self.successors = []
 
     def __repr__(self):
-        return 'Predecessors: {0} ; Successors: {1}'.format(self.predecessors, self.successors)
+        return 'OP: {0} ; PREDECESSORS: {1} ; SUCCESSORS: {2}'.format(self.op_meta, self.predecessors, self.successors)
+
+    def __eq__(self, other):
+        return self.op_meta == other.op_meta and \
+               self.predecessors == other.predecessors and \
+               self.successors == other.successors
 
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 596fc54f0885507ba998bdb7746fdde21ec94987..0e767aaac97f2ea5b91ea821346404d721fc0893 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -16,6 +16,7 @@
 
 import logging
 import torch
+import torch.nn as nn
 import pytest
 import distiller
 from distiller.models import ALL_MODEL_NAMES, create_model
@@ -23,7 +24,7 @@ from distiller.apputils import *
 from distiller import normalize_module_name, denormalize_module_name, \
     SummaryGraph, onnx_name_2_pytorch_name
 from distiller.model_summaries import connectivity_summary, connectivity_summary_verbose
-
+from distiller.summary_graph import AdjacentsEntry, OpSimpleMetadata
 
 # Logging configuration
 logging.basicConfig(level=logging.DEBUG)
@@ -32,9 +33,9 @@ logger = logging.getLogger()
 logger.addHandler(fh)
 
 
-def create_graph(dataset, arch):
+def create_graph(dataset, arch, parallel=False):
     dummy_input = distiller.get_dummy_input(dataset)
-    model = create_model(False, dataset, arch, parallel=False)
+    model = create_model(False, dataset, arch, parallel)
     assert model is not None
     return SummaryGraph(model, dummy_input)
 
@@ -44,10 +45,26 @@ def test_graph():
     assert g is not None
 
 
-def test_connectivity():
-    g = create_graph('cifar10', 'resnet20_cifar')
+@pytest.fixture(params=[False, True], ids=['sequential', 'parallel'])
+def parallel(request):
+    return request.param
+
+
+@pytest.fixture(params=[True, False], ids=['denorm_name', 'norm_name'])
+def denorm_names(request):
+    return request.param
+
+
+def prefix_strs(str_list, prefix):
+    return [prefix + s for s in str_list]
+
+
+def test_connectivity(parallel, denorm_names):
+    g = create_graph('cifar10', 'resnet20_cifar', parallel)
     assert g is not None
 
+    prefix = 'module.' if parallel and denorm_names else ''
+
     op_names = [op['name'] for op in g.ops.values()]
     assert len(op_names) == 80
 
@@ -55,55 +72,57 @@ def test_connectivity():
     assert edges[0].src == '0' and edges[0].dst == 'conv1'
 
     # Test two sequential calls to predecessors (this was a bug once)
-    preds = g.predecessors(g.find_op('bn1'), 1)
-    preds = g.predecessors(g.find_op('bn1'), 1)
+    preds = g.predecessors(g.find_op('bn1'), 1, denorm_names=denorm_names)
+    preds = g.predecessors(g.find_op('bn1'), 1, denorm_names=denorm_names)
     assert preds == ['129', '2', '3', '4', '5']
 
     # Test successors
-    succs = g.successors(g.find_op('bn1'), 2)
-    assert succs == ['relu']
+    succs = g.successors(g.find_op('bn1'), 2, denorm_names=denorm_names)
+    assert succs == prefix_strs(['relu'], prefix)
 
     op = g.find_op('layer1.0.relu2')
     assert op is not None
-    succs = g.successors(op, 4)
-    assert succs == ['layer1.1.bn1', 'layer1.1.relu2']
+    succs = g.successors(op, 4, denorm_names=denorm_names)
+    assert succs == prefix_strs(['layer1.1.bn1', 'layer1.1.relu2'], prefix)
 
-    preds = g.predecessors(g.find_op('bn1'), 10)
+    preds = g.predecessors(g.find_op('bn1'), 10, denorm_names=denorm_names)
     assert preds == []
-    preds = g.predecessors(g.find_op('bn1'), 3)
+    preds = g.predecessors(g.find_op('bn1'), 3, denorm_names=denorm_names)
     assert preds == ['0', '1']
 
 
-def test_layer_search():
-    g = create_graph('cifar10', 'resnet20_cifar')
+def test_layer_search(parallel, denorm_names):
+    g = create_graph('cifar10', 'resnet20_cifar', parallel)
     assert g is not None
 
+    prefix = 'module.' if parallel and denorm_names else ''
+
     op = g.find_op('layer1.0.conv1')
     assert op is not None
 
-    succs = g.successors_f('layer1.0.conv1', 'Conv', [], logging)
-    assert ['layer1.0.conv2'] == succs
+    succs = g.successors_f('layer1.0.conv1', 'Conv', [], logging, denorm_names=denorm_names)
+    assert succs == prefix_strs(['layer1.0.conv2'], prefix)
 
-    succs = g.successors_f('relu', 'Conv', [], logging)
-    assert succs == ['layer1.0.conv1', 'layer1.1.conv1', 'layer1.2.conv1', 'layer2.0.conv1', 'layer2.0.downsample.0']
+    succs = g.successors_f('relu', 'Conv', [], logging, denorm_names=denorm_names)
+    assert succs == prefix_strs(['layer1.0.conv1', 'layer1.1.conv1', 'layer1.2.conv1', 'layer2.0.conv1',
+                                 'layer2.0.downsample.0'], prefix)
 
-    succs = g.successors_f('relu', 'Gemm', [], logging)
-    assert succs == ['fc']
+    succs = g.successors_f('relu', 'Gemm', [], logging, denorm_names=denorm_names)
+    assert succs == prefix_strs(['fc'], prefix)
 
-    succs = g.successors_f('layer3.2', 'Conv', [], logging)
+    succs = g.successors_f('layer3.2', 'Conv', [], logging, denorm_names=denorm_names)
     assert succs == []
-    #logging.debug(succs)
-    preds = g.predecessors_f('conv1', 'Conv', [], logging)
+    preds = g.predecessors_f('conv1', 'Conv', [], logging, denorm_names=denorm_names)
     assert preds == []
 
-    preds = g.predecessors_f('layer1.0.conv2', 'Conv', [], logging)
-    assert preds == ['layer1.0.conv1']
+    preds = g.predecessors_f('layer1.0.conv2', 'Conv', [], logging, denorm_names=denorm_names)
+    assert preds == prefix_strs(['layer1.0.conv1'], prefix)
 
-    preds = g.predecessors_f('layer1.0.conv1', 'Conv', [], logging)
-    assert preds == ['conv1']
+    preds = g.predecessors_f('layer1.0.conv1', 'Conv', [], logging, denorm_names=denorm_names)
+    assert preds == prefix_strs(['conv1'], prefix)
 
-    preds = g.predecessors_f('layer1.1.conv1', 'Conv', [], logging)
-    assert preds == ['layer1.0.conv2', 'conv1']
+    preds = g.predecessors_f('layer1.1.conv1', 'Conv', [], logging, denorm_names=denorm_names)
+    assert preds == prefix_strs(['layer1.0.conv2', 'conv1'], prefix)
 
 
 def test_vgg():
@@ -153,22 +172,18 @@ def test_normalize_module_name():
     name_test('imagenet', 'alexnet')
 
 
-def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
-    model = create_model(False, dataset, arch, parallel=dataparallel)
+@pytest.mark.parametrize('dataset, arch', [('imagenet', 'vgg19'),
+                                           ('cifar10', 'resnet20_cifar'),
+                                           ('imagenet', 'alexnet'),
+                                           ('imagenet', 'resnext101_32x4d')])
+def test_named_params_layers(dataset, arch, parallel):
+    model = create_model(False, dataset, arch, parallel=parallel)
     sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
     sgraph_layer_names = set(k for k, i, j in sgraph.named_params_layers())
     for layer_name in sgraph_layer_names:
         assert sgraph.find_op(layer_name) is not None, '{} was not found in summary graph'.format(layer_name)
 
 
-def test_named_params_layers():
-    for dataParallelModel in (True, False):
-        named_params_layers_test_aux('imagenet', 'vgg19', dataParallelModel)
-        named_params_layers_test_aux('cifar10', 'resnet20_cifar', dataParallelModel)
-        named_params_layers_test_aux('imagenet', 'alexnet', dataParallelModel)
-        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', dataParallelModel)
-
-
 def test_onnx_name_2_pytorch_name():
     assert onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu]") == "layer3.0.relu"
     assert onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]') == "features.34"
@@ -196,28 +211,145 @@ def test_sg_macs():
     df_compute = distiller.model_performance_summary(model, distiller.get_dummy_input('imagenet'))
     modules_macs = df_compute.loc[:, ['Name', 'MACs']]
     for name, mod in model.named_modules():
-        if isinstance(mod, (torch.nn.Conv2d, torch.nn.Linear)):
+        if isinstance(mod, (nn.Conv2d, nn.Linear)):
             summary_macs = int(modules_macs.loc[modules_macs.Name == name].MACs)
             sg_macs = sg.find_op(name)['attrs']['MACs']
             assert summary_macs == sg_macs
-
-
-def test_weights_size_attr():
-    def test(dataset, arch, dataparallel:bool):
-        model = create_model(False, dataset, arch, parallel=dataparallel)
-        sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
-
-        distiller.assign_layer_fq_names(model)
-        for name, mod in model.named_modules():
-            if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear):
-                op = sgraph.find_op(name)
-                assert op is not None
-                assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
-
-    for data_parallel in (True, False):
-        test('cifar10', 'resnet20_cifar', data_parallel)
-        test('imagenet', 'alexnet', data_parallel)
-        test('imagenet', 'resnext101_32x4d', data_parallel)
+
+
+@pytest.mark.parametrize('dataset, arch', [('cifar10', 'resnet20_cifar'),
+                                           ('imagenet', 'alexnet'),
+                                           ('imagenet', 'resnext101_32x4d')])
+def test_weights_size_attr(dataset, arch, parallel):
+    model = create_model(False, dataset, arch, parallel=parallel)
+    sgraph = SummaryGraph(model, distiller.get_dummy_input(dataset))
+
+    distiller.assign_layer_fq_names(model)
+    for name, mod in model.named_modules():
+        if isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Linear):
+            op = sgraph.find_op(name)
+            assert op is not None
+            assert op['attrs']['weights_vol'] == distiller.volume(mod.weight)
+
+
+def test_merge_pad_avgpool():
+    class ModelWithAvgPool(nn.Module):
+        def __init__(self):
+            super(ModelWithAvgPool, self).__init__()
+            self.conv = nn.Conv2d(3, 10, 5)
+            self.avgpool = nn.AvgPool2d(2)
+
+        def forward(self, input):
+            return self.avgpool(self.conv(input))
+
+    m = ModelWithAvgPool()
+    sg = SummaryGraph(m, distiller.get_dummy_input(input_shape=(1, 3, 50, 50)))
+
+    avgpool_ops = [op_name for op_name in sg.ops if 'avgpool' in op_name]
+    assert len(avgpool_ops) == 1
+    assert sg.ops[avgpool_ops[0]]['name'] == 'avgpool'
+    assert sg.ops[avgpool_ops[0]]['type'] == 'AveragePool'
+
+
+def test_gemm_nodes_scope_names():
+    class ModelWithGemms(nn.Module):
+        def __init__(self):
+            super(ModelWithGemms, self).__init__()
+            self.drop1 = nn.Dropout()
+            self.fc1 = nn.Linear(100, 50)
+            self.relu1 = nn.ReLU(inplace=True)
+            self.drop2 = nn.Dropout()
+            self.fc2 = nn.Linear(50, 25)
+            self.relu2 = nn.ReLU(inplace=True)
+            self.fc3 = nn.Linear(25, 1)
+
+        def forward(self, x):
+            # Isn't this pretty...
+            return self.fc3(self.relu2(self.fc2(self.drop2(self.relu1(self.fc1(self.drop1(x)))))))
+
+    m = ModelWithGemms()
+    sg = SummaryGraph(m, distiller.get_dummy_input(input_shape=(1, 100)))
+
+    # For the model above we expect the ops to be named (in order):
+    # 'drop1', 'fc1', 'relu1', 'drop2', 'fc2', 'relu2', 'fc3'
+    # But without our workaround in place, they'll be named:
+    # 'drop1', 'drop1__1', 'relu1', 'drop2', 'drop2__1', 'relu2', 'relu2__1'
+    # (that is - each FC node gets the name of the node before)
+    names, types = zip(*[(op_name, op['type']) for op_name, op in sg.ops.items()])
+    assert names == ('drop1', 'fc1', 'relu1', 'drop2', 'fc2', 'relu2', 'fc3')
+    assert types == ('Dropout', 'Gemm', 'Relu', 'Dropout', 'Gemm', 'Relu', 'Gemm')
+
+
+@pytest.fixture(params=[False, True], ids=['dedicated_modules_off', 'dedicated_modules_on'])
+def dedicated_modules(request):
+    return request.param
+
+
+def test_adjacency_map(parallel, dedicated_modules):
+    class TestModel(nn.Module):
+        def __init__(self):
+            super(TestModel, self).__init__()
+            self.conv = nn.Conv2d(3, 10, 5)
+            self.bn = nn.BatchNorm2d(10)
+            self.relu = nn.ReLU()
+
+        def forward(self, x):
+            res = self.conv(x)
+            y = self.bn(res)
+            y = self.relu(y)
+            return y + res
+
+    def check_adj_entry(actual, expected):
+        assert actual.op_meta == expected.op_meta
+        assert actual.predecessors == expected.predecessors
+        assert actual.successors == expected.successors
+
+    prefix = 'module.' if parallel else ''
+
+    m = TestModel()
+    if parallel:
+        m = nn.DataParallel(m)
+    sg = SummaryGraph(m, distiller.get_dummy_input(input_shape=(1, 3, 10, 10)))
+    adj_map = sg.adjacency_map(dedicated_modules_only=dedicated_modules)
+
+    if dedicated_modules:
+        assert len(adj_map) == 3
+    else:
+        assert len(adj_map) == 4
+
+    conv_op_meta = OpSimpleMetadata(prefix + 'conv', 'Conv')
+    bn_op_meta = OpSimpleMetadata(prefix + 'bn', 'BatchNormalization')
+    relu_op_meta = OpSimpleMetadata(prefix + 'relu', 'Relu')
+    add_op_meta = OpSimpleMetadata('top_level_op', 'Add')
+
+    name = conv_op_meta.name
+    assert name in adj_map
+    expected = AdjacentsEntry(conv_op_meta)
+    expected.successors = [bn_op_meta] if dedicated_modules else [bn_op_meta, add_op_meta]
+    check_adj_entry(adj_map[name], expected)
+
+    name = bn_op_meta.name
+    assert name in adj_map
+    expected = AdjacentsEntry(bn_op_meta)
+    expected.predecessors = [conv_op_meta]
+    expected.successors = [relu_op_meta]
+    check_adj_entry(adj_map[name], expected)
+
+    name = relu_op_meta.name
+    assert name in adj_map
+    expected = AdjacentsEntry(relu_op_meta)
+    expected.predecessors = [bn_op_meta]
+    expected.successors = [] if dedicated_modules else [add_op_meta]
+    check_adj_entry(adj_map[name], expected)
+
+    name = add_op_meta.name
+    if dedicated_modules:
+        assert name not in adj_map
+    else:
+        assert name in adj_map
+        expected = AdjacentsEntry(add_op_meta)
+        expected.predecessors = [relu_op_meta, conv_op_meta]
+        check_adj_entry(adj_map[name], expected)
 
 
 if __name__ == '__main__':