Skip to content
Snippets Groups Projects
Commit fbb8d486 authored by Guy Jacob's avatar Guy Jacob
Browse files

BN folding - do nothing if no BNs in model

parent a0b38e2d
No related branches found
No related tags found
No related merge requests found
...@@ -131,7 +131,9 @@ def fold_batch_norms(model, dummy_input=None, adjacency_map=None, inference=True ...@@ -131,7 +131,9 @@ def fold_batch_norms(model, dummy_input=None, adjacency_map=None, inference=True
foldables = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d) foldables = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, FrozenBatchNorm2d) batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, FrozenBatchNorm2d)
return fuse_modules(model, (foldables, batchnorms), fold_bn, dummy_input, adjacency_map) if any([isinstance(m, batchnorms) for m in model.modules()]):
return fuse_modules(model, (foldables, batchnorms), fold_bn, dummy_input, adjacency_map)
return model
def _fuse_sequence(sequence, named_modules, fuse_fn): def _fuse_sequence(sequence, named_modules, fuse_fn):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment