diff --git a/distiller/model_transforms.py b/distiller/model_transforms.py index 647d2ed5872121b68ed220c90d54c852914f0787..015a1ffa380e33021e7a8c10a706ad491082bd4e 100644 --- a/distiller/model_transforms.py +++ b/distiller/model_transforms.py @@ -131,7 +131,9 @@ def fold_batch_norms(model, dummy_input=None, adjacency_map=None, inference=True foldables = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d) batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, FrozenBatchNorm2d) - return fuse_modules(model, (foldables, batchnorms), fold_bn, dummy_input, adjacency_map) + if any([isinstance(m, batchnorms) for m in model.modules()]): + return fuse_modules(model, (foldables, batchnorms), fold_bn, dummy_input, adjacency_map) + return model def _fuse_sequence(sequence, named_modules, fuse_fn):