diff --git a/distiller/models/mnist/simplenet_mnist.py b/distiller/models/mnist/simplenet_mnist.py
index 39b01e9d24aae0602ed14d3d5e0e23af2896bdb2..07270432d1e3affaa0e290d18cca608e2c2f5680 100755
--- a/distiller/models/mnist/simplenet_mnist.py
+++ b/distiller/models/mnist/simplenet_mnist.py
@@ -23,27 +23,64 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-__all__ = ['simplenet_mnist']
+__all__ = ['simplenet_mnist', 'simplenet_v2_mnist']
 
 
 class Simplenet(nn.Module):
     def __init__(self):
         super().__init__()
         self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.relu1 = nn.ReLU(inplace=False)
+        self.pool1 = nn.MaxPool2d(2, 2)
         self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.relu2 = nn.ReLU(inplace=False)
+        self.pool2 = nn.MaxPool2d(2, 2)
         self.fc1 = nn.Linear(4*4*50, 500)
+        self.relu3 = nn.ReLU(inplace=False)
         self.fc2 = nn.Linear(500, 10)
         
     def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = F.max_pool2d(x, 2, 2)
-        x = F.relu(self.conv2(x))
-        x = F.max_pool2d(x, 2, 2)
+        x = self.pool1(self.relu1(self.conv1(x)))
+        x = self.pool2(self.relu2(self.conv2(x)))
         x = x.view(x.size(0), -1)
-        x = F.relu(self.fc1(x))
+        x = self.relu3(self.fc1(x))
         x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
-        
+        return x
+
+
+class Simplenet_v2(nn.Module):
+    """
+    This is Simplenet but with only one small Linear layer, instead of two Linear layers,
+    one of which is large.
+    26K parameters.
+    python compress_classifier.py ${MNIST_PATH} --arch=simplenet_mnist --vs=0 --lr=0.01
+
+    ==> Best [Top1: 98.970   Top5: 99.970   Sparsity:0.00   Params: 26000 on epoch: 54]
+    """
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 20, 5, 1)
+        self.relu1 = nn.ReLU(inplace=False)
+        self.pool1 = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(20, 50, 5, 1)
+        self.relu2 = nn.ReLU(inplace=False)
+        self.pool2 = nn.MaxPool2d(2, 2)
+        self.avgpool = nn.AvgPool2d(4, stride=1)
+        self.fc = nn.Linear(50, 10)
+
+    def forward(self, x):
+        x = self.pool1(self.relu1(self.conv1(x)))
+        x = self.pool2(self.relu2(self.conv2(x)))
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        return x
+
+
 def simplenet_mnist():
     model = Simplenet()
     return model
+
+def simplenet_v2_mnist():
+    model = Simplenet_v2()
+    return model
\ No newline at end of file