diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py index 7ad42de3c42d75fec2b568fc427026cfefde88ce..5ee06fd93cc52096e45597b3c57ec1db38c7a123 100755 --- a/distiller/model_summaries.py +++ b/distiller/model_summaries.py @@ -165,8 +165,7 @@ def conv_visitor(self, input, output, df, model, memo): assert isinstance(self, torch.nn.Conv2d) if self in memo: return - - weights_vol = self.out_channels * self.in_channels * self.kernel_size[0] * self.kernel_size[1] + weights_vol = distiller.volume(self.weight) # Multiply-accumulate operations: MACs = volume(OFM) * (#IFM * K^2) / #Groups # Bias is ignored @@ -183,7 +182,7 @@ def fc_visitor(self, input, output, df, model, memo): # Multiply-accumulate operations: MACs = #IFM * #OFM # Bias is ignored - weights_vol = macs = self.in_features * self.out_features + weights_vol = macs = distiller.volume(self.weight) module_visitor(self, input, output, df, model, weights_vol, macs) @@ -436,10 +435,12 @@ def draw_img_classifier_to_file(model, png_fname, dataset=None, display_param_no styles['conv1'] = {'shape': 'oval', 'fillcolor': 'gray', 'style': 'rounded, filled'} - input_shape (tuple): List of integers representing the input shape. Used only if 'dataset' is None + input_shape (tuple): List of integers representing the input shape. + Used only if 'dataset' is None """ dummy_input = distiller.get_dummy_input(dataset=dataset, - device=distiller.model_device(model), input_shape=input_shape) + device=distiller.model_device(model), + input_shape=input_shape) try: non_para_model = distiller.make_non_parallel_copy(model) g = SummaryGraph(non_para_model, dummy_input)