Commit d08c3734 authored by Bar, committed by Neta Zmora

Relaxation of the SummaryGraph API (#212)

This commit simplifies the SummaryGraph API by
removing from the client the burden of handling the
differences between models with and without
DataParallel layers.

DataParallel layers in PyTorch change the fully-qualified
names (FQNs) of PyTorch modules.  A module's FQN
unambiguously identifies a module within a model by
encoding the path to the module from the root of the
model.  For example, ```module.layer2.1.conv1``` and
```module.layer2.0.conv1``` are the FQNs of two different
modules, both named ```conv1```, that live in different
sub-blocks of the model.
Because a module's FQN reflects the module's place in the
hierarchy, adding or removing a DataParallel node also
changes the FQNs of the modules nested under it.
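
For illustration only (not part of this commit), the snippet
below shows how wrapping a model in DataParallel changes the
FQNs that PyTorch reports; the torchvision model is just a
convenient stand-in:

```python
import torch
import torchvision

# A plain ResNet-18: module FQNs look like 'layer2.0.conv1'
model = torchvision.models.resnet18()
print([name for name, _ in model.named_modules()][1:4])

# Wrapping the model in DataParallel prepends 'module.' to every FQN,
# e.g. 'layer2.0.conv1' becomes 'module.layer2.0.conv1'
parallel_model = torch.nn.DataParallel(model)
print([name for name, _ in parallel_model.named_modules()][1:4])
```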

Distiller uses FQNs to refer to modules and parameters
(e.g. from YAML files), and non-functional changes to the
model hierarchy, such as wrapping modules in DataParallel,
are handled by converting FQNs with
```utils.{de,}normalize_module_name()```.
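
Conceptually, these helpers map between the two naming
schemes.  The following is a simplified sketch of their
behavior (the real implementations ship with Distiller and
handle more cases):

```python
def normalize_module_name(layer_name):
    # Drop the 'module' path components that DataParallel wrapping inserts,
    # e.g. 'module.layer2.1.conv1' -> 'layer2.1.conv1'
    return '.'.join(part for part in layer_name.split('.') if part != 'module')


def denormalize_module_name(parallel_model, normalized_name):
    # Recover the FQN that <parallel_model> actually uses: find a module whose
    # normalized name matches, falling back to the input name if none does.
    matches = [name for name, _ in parallel_model.named_modules()
               if normalize_module_name(name) == normalized_name]
    return matches[-1] if matches else normalized_name
```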

Before this commit, the SummaryGraph API assumed that the
API client would convert layer names using
```utils.normalize_module_name()``` before invoking the API.
This led to needlessly verbose client code, which was also
error-prone and harder to read and maintain.
This commit fixes these shortcomings by relaxing the API
and handling the FQN naming differences internally.
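
The effect on client code can be seen in the thinning
changes in this commit; roughly, the before/after looks
like this (excerpt adapted from the diff below):

```python
# Before: the client normalizes names going into SummaryGraph and
# converts the names coming back out
successors = sgraph.successors_f(normalize_module_name(layer_name), ['Conv', 'Gemm'])
successors = [denormalize_module_name(model, successor) for successor in successors]

# After: SummaryGraph accepts and returns the model's own layer names
successors = sgraph.successors_f(layer_name, ['Conv', 'Gemm'])
```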

The thinning implementation is simplified somewhat by
refactoring it to the new API's more lenient requirements.

Added a ```named_params_layers``` method to SummaryGraph
that yields 3-tuples of (layer name, parameter name,
parameter).  The layer names it yields match the model that
the SummaryGraph was initialized with (i.e. they are already
denormalized when that model uses DataParallel).
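
A minimal usage sketch of the new method (the model and
input shape here are illustrative):

```python
import torch
import torchvision
from distiller.summary_graph import SummaryGraph

# Works the same whether or not the model is wrapped in DataParallel
model = torch.nn.DataParallel(torchvision.models.resnet18())
sgraph = SummaryGraph(model, torch.randn(1, 3, 224, 224))

for layer_name, param_name, param in sgraph.named_params_layers():
    # layer_name matches the model passed to SummaryGraph (e.g.
    # 'module.layer1.0.conv1'), so it can be used directly with
    # model.named_modules() and model.named_parameters()
    print(layer_name, param_name, tuple(param.shape))
```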
parent ce082d5e
@@ -93,12 +93,13 @@ class SummaryGraph(object):
     Edge = collections.namedtuple('Edge', 'src dst')
 
     def __init__(self, model, dummy_input):
-        model = distiller.make_non_parallel_copy(model)
-        with torch.onnx.set_training(model, False):
-            device = next(model.parameters()).device
+        self._src_model = model
+        model_clone = distiller.make_non_parallel_copy(model)
+        with torch.onnx.set_training(model_clone, False):
+            device = next(model_clone.parameters()).device
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
-            trace, _ = jit.get_trace_graph(model, dummy_input)
+            trace, _ = jit.get_trace_graph(model_clone, dummy_input)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
@@ -152,7 +153,7 @@ class SummaryGraph(object):
         self.add_macs_attr()
         self.add_footprint_attr()
         self.add_arithmetic_intensity_attr()
-        del model
+        del model_clone
 
     def __create_op(self, onnx_node):
         op = {}
@@ -266,15 +267,13 @@ class SummaryGraph(object):
         return [op for op in self.ops.values() if attr in op['attrs'] and f(op)]
 
     def find_op(self, lost_op_name):
-        assert isinstance(lost_op_name, str)
-        return self.ops.get(lost_op_name, None)
+        return self.ops.get(distiller.normalize_module_name(lost_op_name), None)
 
     def find_param(self, data_name):
         return self.params.get(data_name, None)
 
     def predecessors(self, op, depth, done_list=None):
         """Returns a list of <op>'s predecessors"""
         if done_list is None:
             done_list = []
@@ -288,16 +287,18 @@ class SummaryGraph(object):
         done_list += preds
 
         if depth == 1:
-            return preds
+            ret = preds
         else:
             ret = []
             for predecessor in preds:
                 ret += self.predecessors(predecessor, depth-1, done_list)
-            return ret
+
+        return [distiller.denormalize_module_name(self._src_model, x) for x in ret]
 
     def predecessors_f(self, node_name, predecessors_types, done_list=None, logging=None):
         """Returns a list of <op>'s predecessors, if they match the <predecessors_types> criteria.
         """
+        node_name = distiller.normalize_module_name(node_name)
         node = self.find_op(node_name)
         node_is_an_op = True
         if node is None:
@@ -319,7 +320,7 @@ class SummaryGraph(object):
             # We check if we found the type of node we're looking for,
             # and that this is not the first node in our search.
             if node['type'] in predecessors_types and len(done_list) > 1:
-                return [node_name]
+                return [distiller.denormalize_module_name(self._src_model, node_name)]
 
             # This is an operation node
             preds = [edge.src for edge in self.edges if (edge.dst == node_name and
@@ -331,11 +332,11 @@ class SummaryGraph(object):
         ret = []
         for predecessor in preds:
             ret += self.predecessors_f(predecessor, predecessors_types, done_list, logging)
-        return ret
+        return [distiller.denormalize_module_name(self._src_model, node) for node in ret]
 
     def successors(self, node, depth, done_list=None):
         """Returns a list of <op>'s successors"""
         if done_list is None:
             done_list = []
@@ -351,12 +352,13 @@ class SummaryGraph(object):
         done_list += succs
 
         if depth == 1:
-            return succs
+            ret = succs
         else:
             ret = []
             for successor in succs:
                 ret += self.successors(successor, depth-1, done_list)
-            return ret
+
+        return [distiller.denormalize_module_name(self._src_model, x) for x in ret]
 
     def successors_f(self, node_name, successors_types, done_list=None, logging=None):
         """Returns a list of <op>'s successors, if they match the <successors_types> criteria.
@@ -367,7 +369,7 @@ class SummaryGraph(object):
         <node_name> and the returned list of successors are strings, because
         """
+        node_name = distiller.normalize_module_name(node_name)
         node = self.find_op(node_name)
         node_is_an_op = True
         if node is None:
@@ -389,7 +391,7 @@ class SummaryGraph(object):
             # We check if we found the type of node we're looking for,
             # and that this is not the first node in our search.
             if node['type'] in successors_types and len(done_list) > 1:
-                return [node_name]
+                return [distiller.denormalize_module_name(self._src_model, node_name)]
 
             # This is an operation node
             succs = [edge.dst for edge in self.edges if (edge.src == node_name and
@@ -401,4 +403,15 @@ class SummaryGraph(object):
         ret = []
         for successor in succs:
             ret += self.successors_f(successor, successors_types, done_list, logging)
-        return ret
+
+        return [distiller.denormalize_module_name(self._src_model, node) for node in ret]
+
+    def named_params_layers(self):
+        for param_name, param in self._src_model.named_parameters():
+            # remove the extension of param_name, and then normalize it
+            # to create a normalized layer name
+            normalized_layer_name = distiller.normalize_module_name(
+                '.'.join(param_name.split('.')[:-1]))
+            sgraph_layer_name = distiller.denormalize_module_name(
+                self._src_model, normalized_layer_name)
+            yield sgraph_layer_name, param_name, param
@@ -31,8 +31,6 @@ from collections import namedtuple
 import torch
 from .policy import ScheduledTrainingPolicy
 import distiller
-from distiller import normalize_module_name, denormalize_module_name
-from distiller.models import create_model
 from .summary_graph import SummaryGraph
 
 msglogger = logging.getLogger(__name__)
@@ -64,34 +62,23 @@ __all__ = ['ThinningRecipe', 'resnet_cifar_remove_layers',
            'execute_thinning_recipes_list', 'get_normalized_recipe']
 
 
-def create_graph(dataset, arch):
+def create_graph(dataset, model):
+    dummy_input = None
     if dataset == 'imagenet':
         dummy_input = torch.randn((1, 3, 224, 224), requires_grad=False)
     elif dataset == 'cifar10':
         dummy_input = torch.randn((1, 3, 32, 32), requires_grad=False)
     assert dummy_input is not None, "Unsupported dataset ({}) - aborting draw operation".format(dataset)
 
-    model = create_model(False, dataset, arch, parallel=False)
-    assert model is not None
     dummy_input = dummy_input.to(distiller.model_device(model))
     return SummaryGraph(model, dummy_input)
 
 
 def get_normalized_recipe(recipe):
-    new_recipe = ThinningRecipe(modules={normalize_module_name(k): v for k, v in recipe.modules.items()},
-                                parameters={normalize_module_name(k): v for k, v in recipe.parameters.items()})
-    return new_recipe
-
-
-def param_name_2_layer_name(param_name):
-    """Convert a weights tensor's name to the name of the layer using the tensor.
-    By convention, PyTorch modules name their weights parameters as self.weight
-    (see for example: torch.nn.modules.conv) which means that their fully-qualified
-    name when enumerating a model's parameters is the modules name followed by '.weight'.
-    We exploit this convention to convert a weights tensor name to the fully-qualified
-    module name."""
-    return param_name[:-len('.weight')]
+    return ThinningRecipe(
+        modules={distiller.normalize_module_name(k): v for k, v in recipe.modules.items()},
+        parameters={distiller.normalize_module_name(k): v for k, v in recipe.parameters.items()},
+    )
 
 
 def directives_equal(d1, d2):
@@ -120,9 +107,8 @@ def append_param_directive(thinning_recipe, param_name, directive):
     thinning_recipe.parameters[param_name] = param_directives
 
 
-def append_module_directive(model, thinning_recipe, module_name, key, val):
+def append_module_directive(thinning_recipe, module_name, key, val):
     msglogger.debug("\t[recipe] setting {}.{} = {}".format(module_name, key, val))
-    module_name = denormalize_module_name(model, module_name)
     mod_directive = thinning_recipe.modules.get(module_name, {})
     mod_directive[key] = val
     thinning_recipe.modules[module_name] = mod_directive
@@ -180,7 +166,7 @@ def resnet_cifar_remove_layers(model):
 
 def remove_channels(model, zeros_mask_dict, arch, dataset, optimizer):
-    sgraph = create_graph(dataset, arch)
+    sgraph = create_graph(dataset, model)
     thinning_recipe = create_thinning_recipe_channels(sgraph, model, zeros_mask_dict)
     apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
@@ -234,7 +220,7 @@ def apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer):
 
 def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
-    sgraph = create_graph(dataset, arch)
+    sgraph = create_graph(dataset, model)
     thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict)
     apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
     return model
@@ -256,7 +242,7 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
     # Traverse all of the model's parameters, search for zero-channels, and
     # create a thinning recipe that descibes the required changes to the model.
-    for param_name, param in model.named_parameters():
+    for layer_name, param_name, param in sgraph.named_params_layers():
         # We are only interested in 4D weights (of Convolution layers)
         if param.dim() != 4:
             continue
@@ -272,43 +258,35 @@ def create_thinning_recipe_channels(sgraph, model, zeros_mask_dict):
         # We are removing channels, so update the number of incoming channels (IFMs)
         # in the convolutional layer
-        layer_name = param_name_2_layer_name(param_name)
         assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
-        append_module_directive(model, thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels)
+        append_module_directive(thinning_recipe, layer_name, key='in_channels', val=num_nnz_channels)
 
         # Select only the non-zero filters
         indices = nonzero_channels.data.squeeze()
         append_param_directive(thinning_recipe, param_name, (1, indices))
 
         # Find all instances of Convolution layers that immediately preceed this layer
-        predecessors = sgraph.predecessors_f(normalize_module_name(layer_name), ['Conv'])
-        # Convert the layers names to PyTorch's convoluted naming scheme (when DataParallel is used)
-        predecessors = [normalize_module_name(predecessor) for predecessor in predecessors]
+        predecessors = sgraph.predecessors_f(layer_name, ['Conv'])
         if len(predecessors) == 0:
-            msglogger.info("Could not find predecessors for name={} normal={} {}".format(
-                layer_name, normalize_module_name(layer_name), denormalize_module_name(model, layer_name)))
+            msglogger.info("Could not find predecessors for name={}".format(layer_name))
         for predecessor in predecessors:
             # For each of the convolutional layers that preceed, we have to reduce the number of output channels.
-            append_module_directive(model, thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels)
+            append_module_directive(thinning_recipe, predecessor, key='out_channels', val=num_nnz_channels)
             # Now remove channels from the weights tensor of the predecessor conv
-            append_param_directive(thinning_recipe, denormalize_module_name(model, predecessor)+'.weight', (0, indices))
+            append_param_directive(thinning_recipe, predecessor+'.weight', (0, indices))
-            if layers[denormalize_module_name(model, predecessor)].bias is not None:
+            if layers[predecessor].bias is not None:
                 # This convolution has bias coefficients
-                append_param_directive(thinning_recipe, denormalize_module_name(model, predecessor)+'.bias', (0, indices))
+                append_param_directive(thinning_recipe, predecessor+'.bias', (0, indices))
 
         # Now handle the BatchNormalization layer that follows the convolution
-        bn_layers = sgraph.predecessors_f(normalize_module_name(layer_name), ['BatchNormalization'])
-        if len(bn_layers) > 0:
-            # if len(bn_layers) != 1:
-            #     raise RuntimeError("{} should have exactly one BN predecessors, but has {}".format(layer_name, len(bn_layers)))
-            for bn_layer in bn_layers:
-                # Thinning of the BN layer that follows the convolution
-                bn_layer_name = denormalize_module_name(model, bn_layer)
-                msglogger.debug("[recipe] {}: predecessor BN module = {}".format(layer_name, bn_layer_name))
-                append_bn_thinning_directive(thinning_recipe, layers, bn_layer_name,
-                                             len_thin_features=num_nnz_channels, thin_features=indices)
+        bn_layers = sgraph.predecessors_f(layer_name, ['BatchNormalization'])
+        for bn_layer in bn_layers:
+            # Thinning of the BN layer that follows the convolution
+            msglogger.debug("[recipe] {}: predecessor BN module = {}".format(layer_name, bn_layer))
+            append_bn_thinning_directive(thinning_recipe, layers, bn_layer,
+                                         len_thin_features=num_nnz_channels, thin_features=indices)
 
     msglogger.debug(thinning_recipe)
     return thinning_recipe
@@ -329,7 +307,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
     thinning_recipe = ThinningRecipe(modules={}, parameters={})
     layers = {mod_name: m for mod_name, m in model.named_modules()}
 
-    for param_name, param in model.named_parameters():
+    for layer_name, param_name, param in sgraph.named_params_layers():
         # We are only interested in 4D weights
         if param.dim() != 4:
             continue
@@ -343,7 +321,7 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
             raise ValueError("Trying to set zero filters for parameter %s is not allowed" % param_name)
         # If there are non-zero filters in this tensor then continue to next tensor
         if num_filters <= num_nnz_filters:
-            msglogger.debug("Skipping {} shape={}".format(param_name_2_layer_name(param_name), param.shape))
+            msglogger.debug("Skipping {} shape={}".format(param_name, param.shape))
             continue
 
         msglogger.info("In tensor %s found %d/%d zero filters", param_name,
@@ -351,9 +329,8 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
         # We are removing filters, so update the number of outgoing channels (OFMs)
         # in the convolutional layer
-        layer_name = param_name_2_layer_name(param_name)
         assert isinstance(layers[layer_name], torch.nn.modules.Conv2d)
-        append_module_directive(model, thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters)
+        append_module_directive(thinning_recipe, layer_name, key='out_channels', val=num_nnz_filters)
 
         # Select only the non-zero filters
         indices = nonzero_filters.data.squeeze()
@@ -364,24 +341,20 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
         append_param_directive(thinning_recipe, layer_name+'.bias', (0, indices))
 
         # Find all instances of Convolution or FC (GEMM) layers that immediately follow this layer
-        msglogger.debug("{} => {}".format(layer_name, normalize_module_name(layer_name)))
-        successors = sgraph.successors_f(normalize_module_name(layer_name), ['Conv', 'Gemm'])
-        # Convert the layers names to PyTorch's convoluted naming scheme (when DataParallel is used)
-        successors = [denormalize_module_name(model, successor) for successor in successors]
+        successors = sgraph.successors_f(layer_name, ['Conv', 'Gemm'])
         for successor in successors:
             if isinstance(layers[successor], torch.nn.modules.Conv2d):
                 # For each of the convolutional layers that follow, we have to reduce the number of input channels.
-                append_module_directive(model, thinning_recipe, successor, key='in_channels', val=num_nnz_filters)
+                append_module_directive(thinning_recipe, successor, key='in_channels', val=num_nnz_filters)
                 # Now remove channels from the weights tensor of the successor conv
-                append_param_directive(thinning_recipe, denormalize_module_name(model, successor)+'.weight', (1, indices))
+                append_param_directive(thinning_recipe, successor+'.weight', (1, indices))
             elif isinstance(layers[successor], torch.nn.modules.Linear):
                 # If a Linear (Fully-Connected) layer follows, we need to update it's in_features member
                 fm_size = layers[successor].in_features // layers[layer_name].out_channels
                 in_features = fm_size * num_nnz_filters
-                append_module_directive(model, thinning_recipe, successor, key='in_features', val=in_features)
+                append_module_directive(thinning_recipe, successor, key='in_features', val=in_features)
                 msglogger.debug("[recipe] Linear {}: fm_size = {} layers[{}].out_channels={}".format(
                     successor, in_features, layer_name, layers[layer_name].out_channels))
                 msglogger.debug("[recipe] {}: setting in_features = {}".format(successor, in_features))
@@ -391,18 +364,16 @@ def create_thinning_recipe_filters(sgraph, model, zeros_mask_dict):
                 fm_height = fm_width = int(math.sqrt(fm_size))
                 view_4D = (layers[successor].out_features, layers[layer_name].out_channels, fm_height, fm_width)
                 view_2D = (layers[successor].out_features, in_features)
-                append_param_directive(thinning_recipe,
-                                       denormalize_module_name(model, successor)+'.weight',
+                append_param_directive(thinning_recipe, successor+'.weight',
                                        (1, indices, view_4D, view_2D))
 
         # Now handle the BatchNormalization layer that follows the convolution
-        bn_layers = sgraph.successors_f(normalize_module_name(layer_name), ['BatchNormalization'])
+        bn_layers = sgraph.successors_f(layer_name, ['BatchNormalization'])
         if len(bn_layers) > 0:
             assert len(bn_layers) == 1
             # Thinning of the BN layer that follows the convolution
-            bn_layer_name = denormalize_module_name(model, bn_layers[0])
-            append_bn_thinning_directive(thinning_recipe, layers, bn_layer_name,
-                                         len_thin_features=num_nnz_filters, thin_features=indices)
+            append_bn_thinning_directive(thinning_recipe, layers, bn_layers[0],
+                                         len_thin_features=num_nnz_filters, thin_features=indices)
 
     return thinning_recipe
@@ -161,6 +161,23 @@ def test_normalize_module_name():
     name_test('imagenet', 'alexnet')
 
 
+def named_params_layers_test_aux(dataset, arch, dataparallel:bool):
+    model = create_model(False, dataset, arch, parallel=dataparallel)
+    sgraph = SummaryGraph(model, get_input(dataset))
+
+    sgraph_layer_names = set(k for k, i, j in sgraph.named_params_layers())
+
+    for layer_name in sgraph_layer_names:
+        assert (sgraph.find_op(layer_name) is not None,
+                '{} was not found in summary graph'.format(layer_name))
+
+
+def test_named_params_layers():
+    for dataParallelModel in (True, False):
+        named_params_layers_test_aux('imagenet', 'vgg19', dataParallelModel)
+        named_params_layers_test_aux('cifar10', 'resnet20_cifar', dataParallelModel)
+        named_params_layers_test_aux('imagenet', 'alexnet', dataParallelModel)
+        named_params_layers_test_aux('imagenet', 'resnext101_32x4d', dataParallelModel)
+
+
 def test_onnx_name_2_pytorch_name():
     assert "layer3.0.relu1" == onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1", 'Relu')
     assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]', 'Conv')