From 7b3ab5ef7c03d98698df3410ce57f7c90553537b Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Sun, 15 Jul 2018 11:44:27 +0300 Subject: [PATCH] apputils/model_summaries.py: cleanup PEP8 warnings Also add a warnning when swe can't find a node whose predecessors we're looking for. --- apputils/model_summaries.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index ec6a1e1..1faa8be 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -32,6 +32,8 @@ import pandas as pd from tabulate import tabulate import pydot import distiller +import logging +msglogger = logging.getLogger(__name__) def onnx_name_2_pytorch_name(name, op_type): @@ -41,8 +43,8 @@ def onnx_name_2_pytorch_name(name, op_type): # First see if there's an instance identifier instance = '' - if name.find('.')>0: - instance = name[name.find('.')+1 :] + if name.find('.') > 0: + instance = name[name.find('.')+1:] # Next, split by square brackets name_parts = re.findall('\[.*?\]', name) @@ -217,13 +219,12 @@ class SummaryGraph(object): op['attrs']['footprint'] = 0 if op['type'] in ['Conv', 'Gemm', 'MaxPool']: conv_out = op['outputs'][0] - conv_in = op['inputs'][0] + conv_in = op['inputs'][0] ofm_vol = self.param_volume(conv_out) ifm_vol = self.param_volume(conv_in) if op['type'] == 'Conv' or op['type'] == 'Gemm': conv_w = op['inputs'][1] weights_vol = self.param_volume(conv_w) - #print(ofm_vol , ifm_vol , weights_vol) op['attrs']['footprint'] = ofm_vol + ifm_vol + weights_vol op['attrs']['fm_vol'] = ofm_vol + ifm_vol op['attrs']['weights_vol'] = weights_vol @@ -238,10 +239,10 @@ class SummaryGraph(object): # integers are enough, and note that we also round up op['attrs']['ai'] = ((op['attrs']['MACs']+0.5*op['attrs']['footprint']) // op['attrs']['footprint']) - def get_attr(self, attr, f = lambda op: True): + def get_attr(self, attr, f=lambda op: True): return [op['attrs'][attr] for op in self.ops.values() if attr in op['attrs'] and f(op)] - def get_ops(self, attr, f = lambda op: True): + def get_ops(self, attr, f=lambda op: True): return [op for op in self.ops.values() if attr in op['attrs'] and f(op)] def find_op(self, lost_op_name): @@ -262,8 +263,8 @@ class SummaryGraph(object): edge.src not in done_list)] done_list += preds else: - preds = [edge.src for edge in self.edges if (edge.dst == op and - edge.src not in done_list)] + preds = [edge.src for edge in self.edges if (edge.dst == op and + edge.src not in done_list)] done_list += preds if depth == 1: @@ -271,7 +272,7 @@ class SummaryGraph(object): else: ret = [] for predecessor in preds: - ret += self.predecessors(predecessor, depth-1, done_list) #, logging) + ret += self.predecessors(predecessor, depth-1, done_list) return ret def predecessors_f(self, node_name, predecessors_types, done_list=None, logging=None): @@ -283,6 +284,7 @@ class SummaryGraph(object): node_is_an_op = False node = self.find_param(node_name) if node is None: + msglogger.warn("predecessors_f: Could not find node {}".format(node_name)) return [] if done_list is None: -- GitLab