diff --git a/distiller/models/cifar10/plain_cifar.py b/distiller/models/cifar10/plain_cifar.py
index c6a701980afb3096f071dcecc4e68cd126b68453..7f668a5e346d019beb171ff6b2fdbb3e4158053f 100755
--- a/distiller/models/cifar10/plain_cifar.py
+++ b/distiller/models/cifar10/plain_cifar.py
@@ -35,7 +35,7 @@ import torch.nn as nn
 import math
 
 
-__all__ = ['plain20_cifar']
+__all__ = ['plain20_cifar', 'plain20_cifar_nobn']
 
 NUM_CLASSES = 10
 
@@ -49,36 +49,36 @@ def conv3x3(in_planes, out_planes, stride=1):
 class BasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(self, inplanes, planes, stride=1):
+    def __init__(self, inplanes, planes, stride=1, batch_norm=True):
         super().__init__()
         self.conv1 = conv3x3(inplanes, planes, stride)
-        self.bn1 = nn.BatchNorm2d(planes)
+        self.bn1 = nn.BatchNorm2d(planes) if batch_norm else None
         self.relu1 = nn.ReLU(inplace=False)
         self.conv2 = conv3x3(planes, planes)
-        self.bn2 = nn.BatchNorm2d(planes)
+        self.bn2 = nn.BatchNorm2d(planes) if batch_norm else None
         self.relu2 = nn.ReLU(inplace=False)
 
     def forward(self, x):
         out = self.conv1(x)
-        out = self.bn1(out)
+        out = self.bn1(out) if self.bn1 is not None else out
         out = self.relu1(out)
 
         out = self.conv2(out)
-        out = self.bn2(out)
+        out = self.bn2(out) if self.bn2 is not None else out
         out = self.relu2(out)
         return out
 
 
 class PlainCifar(nn.Module):
-    def __init__(self, block, blks_per_layer, num_classes=NUM_CLASSES):
+    def __init__(self, block, blks_per_layer, num_classes=NUM_CLASSES, batch_norm=True):
         self.inplanes = 16
         super().__init__()
         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(self.inplanes)
+        self.bn1 = nn.BatchNorm2d(self.inplanes) if batch_norm else None
         self.relu = nn.ReLU(inplace=True)
-        self.layer1 = self._make_layer(block, 16, blks_per_layer[0], stride=1)
-        self.layer2 = self._make_layer(block, 32, blks_per_layer[1], stride=2)
-        self.layer3 = self._make_layer(block, 64, blks_per_layer[2], stride=2)
+        self.layer1 = self._make_layer(block, 16, blks_per_layer[0], stride=1, batch_norm=batch_norm)
+        self.layer2 = self._make_layer(block, 32, blks_per_layer[1], stride=2, batch_norm=batch_norm)
+        self.layer3 = self._make_layer(block, 64, blks_per_layer[2], stride=2, batch_norm=batch_norm)
 
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
@@ -91,22 +91,22 @@ class PlainCifar(nn.Module):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
 
-    def _make_layer(self, block, planes, num_blocks, stride):
+    def _make_layer(self, block, planes, num_blocks, stride, batch_norm=True):
         # Each layer is composed on 2*num_blocks blocks, and the first block usually
         # performs downsampling of the input, and doubling of the number of filters/feature-maps.
         blocks = []
         inplanes = self.inplanes
         # First block is special (downsamples and adds filters)
-        blocks.append(block(inplanes, planes, stride))
+        blocks.append(block(inplanes, planes, stride, batch_norm=batch_norm))
 
         self.inplanes = planes * block.expansion
         for i in range(num_blocks - 1):
-            blocks.append(block(self.inplanes, planes, stride=1))
+            blocks.append(block(self.inplanes, planes, stride=1, batch_norm=batch_norm))
         return nn.Sequential(*blocks)
 
     def forward(self, x):
         x = self.conv1(x)
-        x = self.bn1(x)
+        x = self.bn1(x) if self.bn1 is not None else x
         x = self.relu(x)
 
         x = self.layer1(x)
@@ -120,5 +120,13 @@ class PlainCifar(nn.Module):
 
 
 def plain20_cifar(**kwargs):
+    # Plain20 for CIFAR10: three stages of three BasicBlocks each; batch-norm is enabled unless overridden via kwargs
     model = PlainCifar(BasicBlock, [3, 3, 3], **kwargs)
     return model
+    # Alternative (disabled): return plain20_cifar_nobn(**kwargs) to build the variant without batch-norm
+
+
+def plain20_cifar_nobn(**kwargs):
+    # Plain20 for CIFAR10, without batch-normalization layers (batch_norm is fixed here; do not pass it in kwargs)
+    model = PlainCifar(BasicBlock, [3, 3, 3], batch_norm=False, **kwargs)
+    return model
\ No newline at end of file