diff --git a/distiller/model_transforms.py b/distiller/model_transforms.py
index 647d2ed5872121b68ed220c90d54c852914f0787..015a1ffa380e33021e7a8c10a706ad491082bd4e 100644
--- a/distiller/model_transforms.py
+++ b/distiller/model_transforms.py
@@ -131,7 +131,9 @@ def fold_batch_norms(model, dummy_input=None, adjacency_map=None, inference=True
 
     foldables = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
     batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, FrozenBatchNorm2d)
-    return fuse_modules(model, (foldables, batchnorms), fold_bn, dummy_input, adjacency_map)
+    if any(isinstance(m, batchnorms) for m in model.modules()):
+        return fuse_modules(model, (foldables, batchnorms), fold_bn, dummy_input, adjacency_map)
+    return model
 
 
 def _fuse_sequence(sequence, named_modules, fuse_fn):