diff --git a/hpvm/test/epoch_dnn/main.py b/hpvm/test/epoch_dnn/main.py
index 221c1e701915ab115748a7a1c0a0fbdaf792e0d5..b89f847d68fc06db9f1bca3542060637a3af6c56 100644
--- a/hpvm/test/epoch_dnn/main.py
+++ b/hpvm/test/epoch_dnn/main.py
@@ -10,7 +10,8 @@ from torch2hpvm import BinDataset, ModelExporter
 self_folder = Path(__file__).parent.absolute()
 site.addsitedir(self_folder.as_posix())
 
-from torch_dnn import CIFAR, MiniERA, quantize
+from torch_dnn import quantize
+from torch_dnn.miniera import CIFAR, MiniERA
 
 # Consts (don't change)
 BUFFER_NAME = "hpvm-mod.nvdla"
diff --git a/hpvm/test/epoch_dnn/torch_dnn/__init__.py b/hpvm/test/epoch_dnn/torch_dnn/__init__.py
index a6fb99465f19b9e49e2b656d696a984eca04df42..89b65549bac66433f39965c5958c382786734637 100644
--- a/hpvm/test/epoch_dnn/torch_dnn/__init__.py
+++ b/hpvm/test/epoch_dnn/torch_dnn/__init__.py
@@ -1,3 +1 @@
-from .datasets import CIFAR
-from .miniera import MiniERA
 from .quantizer import quantize
diff --git a/hpvm/test/epoch_dnn/torch_dnn/miniera/__init__.py b/hpvm/test/epoch_dnn/torch_dnn/miniera/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d135d1bd401e89b8852e49d0d965b8ba47c87bf8
--- /dev/null
+++ b/hpvm/test/epoch_dnn/torch_dnn/miniera/__init__.py
@@ -0,0 +1,2 @@
+from .dataset import CIFAR
+from .model import MiniERA
diff --git a/hpvm/test/epoch_dnn/torch_dnn/datasets.py b/hpvm/test/epoch_dnn/torch_dnn/miniera/dataset.py
similarity index 89%
rename from hpvm/test/epoch_dnn/torch_dnn/datasets.py
rename to hpvm/test/epoch_dnn/torch_dnn/miniera/dataset.py
index ac519bfe26ba48429cf40d6f16fdc3587e606b37..b235a7ca5c2ffb5e953a8fbacb907c5c72b4f39c 100644
--- a/hpvm/test/epoch_dnn/torch_dnn/datasets.py
+++ b/hpvm/test/epoch_dnn/torch_dnn/miniera/dataset.py
@@ -7,39 +7,16 @@ import torch
 from torch.utils.data.dataset import Dataset
 
 RetT = Tuple[torch.Tensor, torch.Tensor]
-msg_logger = logging.getLogger(__name__)
-
 PathLike = Union[Path, str]
+msg_logger = logging.getLogger(__name__)
 
 
 class SingleFileDataset(Dataset):
+    image_shape = None
+
     def __init__(self, inputs: torch.Tensor, outputs: torch.Tensor):
         self.inputs, self.outputs = inputs, outputs
 
-    @classmethod
-    def from_file(cls, *args, **kwargs):
-        pass
-
-    @property
-    def sample_input(self):
-        inputs, outputs = next(iter(self))
-        return inputs
-
-    def __len__(self) -> int:
-        return len(self.inputs)
-
-    def __getitem__(self, idx) -> RetT:
-        return self.inputs[idx], self.outputs[idx]
-
-    def __iter__(self) -> Iterator[RetT]:
-        for i in range(len(self)):
-            yield self[i]
-
-
-class DNNDataset(SingleFileDataset):
-    image_shape = None
-    label_ty = np.int32
-
     @classmethod
     def from_file(
         cls,
@@ -61,8 +38,8 @@ class DNNDataset(SingleFileDataset):
         labels = read_tensor_from_file(
             labels_file,
             -1,
-            read_ty=cls.label_ty,
-            cast_ty=np.long,
+            read_ty=np.int32,
+            cast_ty=np.int64,
             count=count,
             offset=offset,
         )
@@ -71,16 +48,31 @@ class DNNDataset(SingleFileDataset):
         msg_logger.info(f"%d entries loaded from dataset.", inputs.shape[0])
         return cls(inputs, labels)
 
+    @property
+    def sample_input(self):
+        inputs, outputs = next(iter(self))
+        return inputs
+
+    def __len__(self) -> int:
+        return len(self.inputs)
+
+    def __getitem__(self, idx) -> RetT:
+        return self.inputs[idx], self.outputs[idx]
+
+    def __iter__(self) -> Iterator[RetT]:
+        for i in range(len(self)):
+            yield self[i]
+
 
-class MNIST(DNNDataset):
+class MNIST(SingleFileDataset):
     image_shape = 1, 28, 28
 
 
-class CIFAR(DNNDataset):
+class CIFAR(SingleFileDataset):
     image_shape = 3, 32, 32
 
 
-class ImageNet(DNNDataset):
+class ImageNet(SingleFileDataset):
     image_shape = 3, 224, 224
 
 
diff --git a/hpvm/test/epoch_dnn/torch_dnn/miniera.py b/hpvm/test/epoch_dnn/torch_dnn/miniera/model.py
similarity index 94%
rename from hpvm/test/epoch_dnn/torch_dnn/miniera.py
rename to hpvm/test/epoch_dnn/torch_dnn/miniera/model.py
index 3e200673189185b0a5e3c19d037ca47e4ba0154a..0f6fe455d8ebf56f68886d54749c5e9b29bdf9c8 100644
--- a/hpvm/test/epoch_dnn/torch_dnn/miniera.py
+++ b/hpvm/test/epoch_dnn/torch_dnn/miniera/model.py
@@ -36,7 +36,9 @@ class MiniERA(Module):
         for conv in self.convs:
             if not isinstance(conv, Conv2d):
                 continue
-            weight_np = np.fromfile(prefix / f"conv2d_{count+1}_w.bin", dtype=np.float32)
+            weight_np = np.fromfile(
+                prefix / f"conv2d_{count+1}_w.bin", dtype=np.float32
+            )
             bias_np = np.fromfile(prefix / f"conv2d_{count+1}_b.bin", dtype=np.float32)
             conv.weight.data = torch.tensor(weight_np).reshape(conv.weight.shape)
             conv.bias.data = torch.tensor(bias_np).reshape(conv.bias.shape)
diff --git a/hpvm/test/epoch_dnn/torch_dnn/quantizer.py b/hpvm/test/epoch_dnn/torch_dnn/quantizer.py
index 729e8bd35671df179525e17d3028ec6a7a5bc05d..85ee9c479b114a8d14ad27a5f0d511311fc9af1f 100644
--- a/hpvm/test/epoch_dnn/torch_dnn/quantizer.py
+++ b/hpvm/test/epoch_dnn/torch_dnn/quantizer.py
@@ -1,17 +1,17 @@
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Union
 from shutil import move
+from typing import Union
 
 import distiller
 import torch
-from torch.utils.data.dataset import Dataset
 import yaml
 from distiller.data_loggers import collect_quant_stats
 from distiller.quantization import PostTrainLinearQuantizer
 from torch import nn
 from torch.utils.data import DataLoader
+from torch.utils.data.dataset import Dataset
 
 PathLike = Union[str, Path]
 STATS_FILENAME = "acts_quantization_stats.yaml"
@@ -40,7 +40,7 @@ def quantize(
     strat: str = "NONE",
     working_dir: PathLike = ".",
     output_name: str = "calib.txt",
-    eval_batchsize: int = 128
+    eval_batchsize: int = 128,
 ):
     # possible quant strats ['NONE', 'AVG', 'N_STD', 'GAUSS', 'LAPLACE']
     print("Quantizing...")