diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py index d3facd54e9ff58360b9f89b57c69872bc29c0ff6..d246fd58dc13ea4c6e2f10399e015a719f830890 100755 --- a/distiller/summary_graph.py +++ b/distiller/summary_graph.py @@ -21,7 +21,6 @@ import collections import torch import torch.jit as jit import logging -from collections import OrderedDict msglogger = logging.getLogger() @@ -100,17 +99,17 @@ class SummaryGraph(object): device = next(model_clone.parameters()).device dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device) - trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True) + trace, _ = jit.get_trace_graph(model_clone, dummy_input) # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes # composing a GEMM operation; etc. torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX) graph = trace.graph() - self.ops = OrderedDict() - self.params = OrderedDict() + self.ops = {} + self.params = {} self.edges = [] - self.temp = OrderedDict() + self.temp = {} in_out = list(graph.inputs()) + list(graph.outputs()) for param in in_out: @@ -149,7 +148,7 @@ class SummaryGraph(object): self.__add_output(new_op, output) self.edges.append(SummaryGraph.Edge(new_op['name'], output.uniqueName())) - new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()]) + new_op['attrs'] = {attr_name: node[attr_name] for attr_name in node.attributeNames()} self.add_macs_attr() self.add_footprint_attr() @@ -157,7 +156,7 @@ class SummaryGraph(object): del model_clone def __create_op(self, onnx_node): - op = OrderedDict() + op = {} op['name'] = onnx_node.scopeName() op['orig-name'] = onnx_node.scopeName() op['type'] = onnx_node.kind().lstrip('::onnx') @@ -189,7 +188,7 @@ class SummaryGraph(object): return param def __tensor_desc(self, n): - tensor = OrderedDict() + tensor = {} tensor['id'] = n.uniqueName() try: # try parsing the FM tensor type. For example: Float(1, 64, 8, 8) @@ -246,8 +245,14 @@ class SummaryGraph(object): ofm_vol = self.param_volume(conv_out) ifm_vol = self.param_volume(conv_in) if op['type'] == 'Conv' or op['type'] == 'Gemm': - conv_w = op['inputs'][1] - weights_vol = self.param_volume(conv_w) + if op['type'] == 'Conv': + kernel_size = self.volume(op['attrs']['kernel_shape']) + group = op['attrs']['group'] + else: + kernel_size, group = 1, 1 + n_ifm = self.param_shape(conv_in)[1] / group + n_ofm = self.param_shape(conv_out)[1] + weights_vol = kernel_size * n_ifm * n_ofm op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol op['attrs']['fm_vol'] = ofm_vol + ifm_vol op['attrs']['weights_vol'] = weights_vol diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index 8cf1e7dd2632f3c238d95de87e95afa4e5c454e3..f74bfdaafab546091b58a44fb39e54cbc7c43494 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -116,6 +116,24 @@ def test_layer_search(): assert preds == ['layer1.0.conv2', 'conv1'] +def test_weights_size_attr(): + def test(dataset, arch, dataparallel:bool): + model = create_model(False, dataset, arch, parallel=False) + sgraph = SummaryGraph(model, get_input(dataset)) + + distiller.assign_layer_fq_names(model) + for name, mod in model.named_modules(): + if isinstance(mod, torch.nn.Conv2d) or isinstance(mod, torch.nn.Linear): + op = sgraph.find_op(name) + assert op is not None + assert op['attrs']['weights_vol'] == distiller.volume(mod.weight) + + for data_parallel in (True, False): + test('cifar10', 'resnet20_cifar', data_parallel) + test('imagenet', 'alexnet', data_parallel) + test('imagenet', 'resnext101_32x4d', data_parallel) + + def test_vgg(): g = create_graph('imagenet', 'vgg19') assert g is not None @@ -170,11 +188,11 @@ def named_params_layers_test_aux(dataset, arch, dataparallel:bool): def test_named_params_layers(): - for dataParallelModel in (True, False): - named_params_layers_test_aux('imagenet', 'vgg19', dataParallelModel) - named_params_layers_test_aux('cifar10', 'resnet20_cifar', dataParallelModel) - named_params_layers_test_aux('imagenet', 'alexnet', dataParallelModel) - named_params_layers_test_aux('imagenet', 'resnext101_32x4d', dataParallelModel) + for data_parallel in (True, False): + named_params_layers_test_aux('imagenet', 'vgg19', data_parallel) + named_params_layers_test_aux('cifar10', 'resnet20_cifar', data_parallel) + named_params_layers_test_aux('imagenet', 'alexnet', data_parallel) + named_params_layers_test_aux('imagenet', 'resnext101_32x4d', data_parallel) def test_onnx_name_2_pytorch_name(): @@ -213,4 +231,4 @@ def test_sg_macs(): if __name__ == '__main__': #test_connectivity_summary() - test_sg_macs() \ No newline at end of file + test_sg_macs()