From 4b1d0c89de9cb838ee77df2128d2c9f27f75a863 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Sun, 10 Feb 2019 12:52:07 +0200
Subject: [PATCH] Load different random subset of dataset on each epoch (#149)

* For CIFAR-10 / ImageNet only
* Refactor data_loaders.py, reduce code duplication
* Implemented custom sampler
* Integrated in image classification sample
* Since we now shuffle the test set, updated the expected results in the
  two full_flow_tests that run evaluation
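* Illustrative usage of the new arguments (example values, not from an
  actual run):

      train_loader, val_loader, test_loader, _ = apputils.load_data(
          'cifar10', 'data.cifar10', batch_size=256, workers=4,
          effective_train_size=0.5)
      # Each epoch now trains on a freshly-drawn random half of the
      # training split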
---
 apputils/data_loaders.py                      | 195 ++++++++++--------
 .../resnet20_filters.schedule_agp.yaml        |   2 +-
 .../resnet20_filters.schedule_agp_3.yaml      |   2 +-
 .../resnet20_filters.schedule_agp_4.yaml      |   2 +-
 .../resnet50.schedule_agp.1x1x8-blocks.yaml   |   2 +-
 .../resnet50.schedule_agp.filters.yaml        |   2 +-
 .../resnet50.schedule_agp.filters_2.yaml      |   2 +-
 .../resnet50.schedule_agp.filters_3.yaml      |   2 +-
 ...resnet50.schedule_agp.filters_with_FC.yaml |   2 +-
 .../compress_classifier.py                    |   6 +-
 examples/classifier_compression/parser.py     |  31 ++-
 .../resnet20.network_surgery.yaml             |   2 +-
 .../resnet50.network_surgery.yaml             |   2 +-
 .../resnet50.network_surgery2.yaml            |   2 +-
 .../resnet50.filters.activation_apoz_agp.yaml |   2 +-
 tests/full_flow_tests.py                      |   4 +-
 16 files changed, 147 insertions(+), 113 deletions(-)

diff --git a/apputils/data_loaders.py b/apputils/data_loaders.py
index 0771bef..96802af 100755
--- a/apputils/data_loaders.py
+++ b/apputils/data_loaders.py
@@ -23,13 +23,14 @@ import os
 import torch
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
-from torch.utils.data.sampler import SubsetRandomSampler
+from torch.utils.data.sampler import Sampler
 import numpy as np
 
 DATASETS_NAMES = ['imagenet', 'cifar10']
 
 
-def load_data(dataset, data_dir, batch_size, workers, valid_size=0.1, deterministic=False):
+def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, deterministic=False,
+              effective_train_size=1., effective_valid_size=1., effective_test_size=1.):
     """Load a dataset.
 
     Args:
@@ -37,33 +38,22 @@ def load_data(dataset, data_dir, batch_size, workers, valid_size=0.1, determinis
         data_dir: the directory where the dataset resides
         batch_size: the batch size
         workers: the number of worker threads to use for loading the data
-        valid_size: portion of training dataset to set aside for validation
+        validation_split: portion of training dataset to set aside for validation
         deterministic: set to True if you want the data loading process to be deterministic.
           Note that deterministic data loading suffers from poor performance.
+        effective_train/valid/test_size: portion of the datasets to load on each epoch.
+          The subset is chosen randomly each time. For the training and validation sets, this is applied AFTER
+          the split to those sets according to the validation_split parameter.
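+          For example (illustrative numbers): with validation_split=0.1 and effective_train_size=0.5,
+          each epoch samples 0.5 * 0.9 = 45% of the original training set.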
     """
-    assert dataset in DATASETS_NAMES
-    if dataset == 'cifar10':
-        return cifar10_load_data(data_dir, batch_size, workers, valid_size=valid_size, deterministic=deterministic)
-    if dataset == 'imagenet':
-        return imagenet_load_data(data_dir, batch_size, workers, valid_size=valid_size, deterministic=deterministic)
-    print("FATAL ERROR: load_data does not support dataset %s" % dataset)
-    exit(1)
+    if dataset not in DATASETS_NAMES:
+        raise ValueError("load_data does not support dataset %s" % dataset)
+    datasets_fn = cifar10_get_datasets if dataset == 'cifar10' else imagenet_get_datasets
+    return get_data_loaders(datasets_fn, data_dir, batch_size, workers, validation_split=validation_split,
+                            deterministic=deterministic, effective_train_size=effective_train_size,
+                            effective_valid_size=effective_valid_size, effective_test_size=effective_test_size)
 
 
-def __image_size(dataset):
-    # un-squeeze is used here to add the batch dimension (value=1), which is missing
-    return dataset[0][0].unsqueeze(0).size()
-
-
-def __deterministic_worker_init_fn(worker_id, seed=0):
-    import random
-    import numpy
-    random.seed(seed)
-    numpy.random.seed(seed)
-    torch.manual_seed(seed)
-
-
-def cifar10_load_data(data_dir, batch_size, num_workers, valid_size=0.1, deterministic=False):
+def cifar10_get_datasets(data_dir):
     """Load the CIFAR10 dataset.
 
     The original training dataset is split into training and validation sets (code is
@@ -81,116 +71,139 @@ def cifar10_load_data(data_dir, batch_size, num_workers, valid_size=0.1, determi
     [1] C.-Y. Lee, S. Xie, P. Gallagher, Z. Zhang, and Z. Tu. Deeply Supervised Nets.
     arXiv:1409.5185, 2014
     """
-    transform = transforms.Compose([
+    train_transform = transforms.Compose([
         transforms.RandomCrop(32, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])
 
-    transform_test = transforms.Compose([
+    train_dataset = datasets.CIFAR10(root=data_dir, train=True,
+                                     download=True, transform=train_transform)
+
+    test_transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])
 
-    train_dataset = datasets.CIFAR10(root=data_dir, train=True,
-                                     download=True, transform=transform)
+    test_dataset = datasets.CIFAR10(root=data_dir, train=False,
+                                    download=True, transform=test_transform)
 
-    num_train = len(train_dataset)
-    indices = list(range(num_train))
-    split = int(np.floor(valid_size * num_train))
+    return train_dataset, test_dataset
 
-    np.random.shuffle(indices)
 
-    train_idx, valid_idx = indices[split:], indices[:split]
-    train_sampler = SubsetRandomSampler(train_idx)
+def imagenet_get_datasets(data_dir):
+    """
+    Load the ImageNet dataset.
+    """
+    train_dir = os.path.join(data_dir, 'train')
+    test_dir = os.path.join(data_dir, 'val')
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
 
-    worker_init_fn = __deterministic_worker_init_fn if deterministic else None
+    train_transform = transforms.Compose([
+        transforms.RandomResizedCrop(224),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        normalize,
+    ])
 
-    train_loader = torch.utils.data.DataLoader(train_dataset,
-                                               batch_size=batch_size, sampler=train_sampler,
-                                               num_workers=num_workers, pin_memory=True,
-                                               worker_init_fn=worker_init_fn)
+    train_dataset = datasets.ImageFolder(train_dir, train_transform)
 
-    valid_loader = None
-    if split > 0:
-        valid_sampler = SubsetRandomSampler(valid_idx)
-        valid_loader = torch.utils.data.DataLoader(train_dataset,
-                                                   batch_size=batch_size, sampler=valid_sampler,
-                                                   num_workers=num_workers, pin_memory=True,
-                                                   worker_init_fn=worker_init_fn)
+    test_transform = transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        normalize,
+    ])
 
-    testset = datasets.CIFAR10(root=data_dir, train=False,
-                               download=True, transform=transform_test)
+    test_dataset = datasets.ImageFolder(test_dir, test_transform)
 
-    test_loader = torch.utils.data.DataLoader(
-            testset, batch_size=batch_size, shuffle=False,
-            num_workers=num_workers, pin_memory=True)
+    return train_dataset, test_dataset
 
-    input_shape = __image_size(train_dataset)
 
-    # If validation split was 0 we use the test set as the validation set
-    return train_loader, valid_loader or test_loader, test_loader, input_shape
+def __image_size(dataset):
+    # unsqueeze(0) adds the missing batch dimension (of size 1)
+    return dataset[0][0].unsqueeze(0).size()
+
+
+def __deterministic_worker_init_fn(worker_id, seed=0):
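+    # Seed the Python, NumPy and torch RNGs with a fixed seed in every worker process,
+    # so that multi-process data loading is reproducible (at a performance cost)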
+    import random
+    import numpy
+    random.seed(seed)
+    numpy.random.seed(seed)
+    torch.manual_seed(seed)
+
+
+def __split_list(l, ratio):
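+    # Illustrative: __split_list(list(range(10)), 0.1) -> ([0], [1, 2, ..., 9])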
+    split_idx = int(np.floor(ratio * len(l)))
+    return l[:split_idx], l[split_idx:]
 
 
-def imagenet_load_data(data_dir, batch_size, num_workers, valid_size=0.1, deterministic=False):
-    """Load the ImageNet dataset.
+class SwitchingSubsetRandomSampler(Sampler):
+    """Samples a random subset of elements from a data source, without replacement.
 
-    Somewhat unconventionally, we use the ImageNet validation dataset as our test dataset,
-    and split the training dataset for training and validation (90/10 by default).
+    The subset of elements is re-chosen randomly each time the sampler is enumerated.
+
+    Args:
+        data_source (sequence): data source to sample from (in this module, a list of dataset indices)
+        subset_size (float): value in (0..1], representing the portion of the data source to sample
+          at each enumeration.
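+
+    Example (illustrative):
+        >>> sampler = SwitchingSubsetRandomSampler(list(range(100)), subset_size=0.25)
+        >>> len(sampler)    # floor(100 * 0.25) elements drawn per enumeration
+        25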
     """
-    train_dir = os.path.join(data_dir, 'train')
-    test_dir = os.path.join(data_dir, 'val')
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
+    def __init__(self, data_source, subset_size):
+        if subset_size <= 0 or subset_size > 1:
+            raise ValueError('subset_size must be in (0..1]')
+        self.data_source = data_source
+        self.subset_length = int(np.floor(len(self.data_source) * subset_size))
+
+    def __iter__(self):
+        # Randomizing in the same way as in torch.utils.data.sampler.SubsetRandomSampler to maintain
+        # reproducibility with the previous data loaders implementation
+        indices = torch.randperm(len(self.data_source))
+        subset_indices = indices[:self.subset_length]
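+        # data_source holds dataset indices here, so this generator yields a fresh random
+        # subset of those indices each time a DataLoader starts a new epoch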
+        return (self.data_source[i] for i in subset_indices)
+
+    def __len__(self):
+        return self.subset_length
+
 
-    train_dataset = datasets.ImageFolder(
-        train_dir,
-        transforms.Compose([
-            transforms.RandomResizedCrop(224),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            normalize,
-        ]))
+def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_split=0.1, deterministic=False,
+                     effective_train_size=1., effective_valid_size=1., effective_test_size=1.):
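+    """Create train/validation/test data loaders over the datasets returned by datasets_fn.
+
+    Returns (train_loader, valid_loader, test_loader, input_shape); if the validation
+    split is 0, the test loader is also returned as the validation loader.
+    """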
+    train_dataset, test_dataset = datasets_fn(data_dir)
+
+    worker_init_fn = __deterministic_worker_init_fn if deterministic else None
 
     num_train = len(train_dataset)
     indices = list(range(num_train))
-    split = int(np.floor(valid_size * num_train))
 
-    # Note! We must shuffle the imagenet data because the files are ordered
-    # by class.  If we don't shuffle, the train and validation datasets will
-    # by mutually-exclusive
-    np.random.shuffle(indices)
+    # TODO: Switch to torch.utils.data.random_split()
 
-    train_idx, valid_idx = indices[split:], indices[:split]
-    train_sampler = SubsetRandomSampler(train_idx)
-
-    input_shape = __image_size(train_dataset)
+    # We shuffle the indices here in case the data is arranged by class, in which case the
+    # train and validation splits would otherwise cover mutually exclusive sets of classes
+    np.random.shuffle(indices)
 
-    worker_init_fn = __deterministic_worker_init_fn if deterministic else None
+    valid_indices, train_indices = __split_list(indices, validation_split)
 
+    train_sampler = SwitchingSubsetRandomSampler(train_indices, effective_train_size)
     train_loader = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=batch_size, sampler=train_sampler,
                                                num_workers=num_workers, pin_memory=True,
                                                worker_init_fn=worker_init_fn)
 
     valid_loader = None
-    if split > 0:
-        valid_sampler = SubsetRandomSampler(valid_idx)
+    if valid_indices:
+        valid_sampler = SwitchingSubsetRandomSampler(valid_indices, effective_valid_size)
         valid_loader = torch.utils.data.DataLoader(train_dataset,
                                                    batch_size=batch_size, sampler=valid_sampler,
                                                    num_workers=num_workers, pin_memory=True,
                                                    worker_init_fn=worker_init_fn)
 
-    test_loader = torch.utils.data.DataLoader(
-        datasets.ImageFolder(test_dir, transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            normalize,
-        ])),
-        batch_size=batch_size, shuffle=False,
-        num_workers=num_workers, pin_memory=True)
+    test_indices = list(range(len(test_dataset)))
+    test_sampler = SwitchingSubsetRandomSampler(test_indices, effective_test_size)
+    test_loader = torch.utils.data.DataLoader(test_dataset,
+                                              batch_size=batch_size, sampler=test_sampler,
+                                              num_workers=num_workers, pin_memory=True)
+
+    input_shape = __image_size(train_dataset)
 
     # If validation split was 0 we use the test set as the validation set
     return train_loader, valid_loader or test_loader, test_loader, input_shape
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
index 708c141..aa04ed5 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp.yaml
@@ -14,7 +14,7 @@
 #     Total sparsity: 41.10
 #     # of parameters: 120,000  (=55.7% of the baseline parameters)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-size=0
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-split=0
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml
index 9ac92a2..3879858 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_3.yaml
@@ -14,7 +14,7 @@
 #     Total sparsity: 56.41%
 #     # of parameters: 95922  (=35.4% of the baseline parameters ==> 64.6% sparsity)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_3.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-size=0
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.4 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_3.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-split=0
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
diff --git a/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
index 816ac04..3f85c2d 100755
--- a/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
+++ b/examples/agp-pruning/resnet20_filters.schedule_agp_4.yaml
@@ -14,7 +14,7 @@
 #     Total sparsity: 39.66
 #     # of parameters: 78,776  (=29.1% of the baseline parameters)
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-size=0
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml -j=1 --deterministic --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --validation-split=0
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
diff --git a/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml b/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml
index 342492c..c7cf5c0 100755
--- a/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml
+++ b/examples/agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml
@@ -4,7 +4,7 @@
 #
 # Best Top1: 76.358 (epoch 72) vs. 76.15 baseline (+0.2%)
 #
-# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=../agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml --validation-size=0 --num-best-scores=10
+# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=../agp-pruning/resnet50.schedule_agp.1x1x8-blocks.yaml --validation-split=0 --num-best-scores=10
 #
 # Parameters:
 # +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
diff --git a/examples/agp-pruning/resnet50.schedule_agp.filters.yaml b/examples/agp-pruning/resnet50.schedule_agp.filters.yaml
index a654ba6..04e3500 100755
--- a/examples/agp-pruning/resnet50.schedule_agp.filters.yaml
+++ b/examples/agp-pruning/resnet50.schedule_agp.filters.yaml
@@ -5,7 +5,7 @@
 # No. of Parameters: 12,335,296 (of 25,502,912) = 43.37% dense (56.63% sparse)
 # Total MACs: 1,822,031,872 (of 4,089,184,256) = 44.56% compute = 2.24x
 #
-# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters.yaml --validation-size=0   --num-best-scores=10 --name="resnet50_filters_v3.2"
+# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters.yaml --validation-split=0   --num-best-scores=10 --name="resnet50_filters_v3.2"
 #
 # Parameters:
 # +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
diff --git a/examples/agp-pruning/resnet50.schedule_agp.filters_2.yaml b/examples/agp-pruning/resnet50.schedule_agp.filters_2.yaml
index e571319..a9b6505 100755
--- a/examples/agp-pruning/resnet50.schedule_agp.filters_2.yaml
+++ b/examples/agp-pruning/resnet50.schedule_agp.filters_2.yaml
@@ -5,7 +5,7 @@
 # No. of Parameters: 12,671,168 (of 25,502,912) = 49.69% dense (50.31% sparse)
 # Total MACs: 2,037,186,560 (of 4,089,184,256) = 49.82% compute = 2.01x
 #
-# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters_2.yaml --validation-size=0 --num-best-scores=10
+# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters_2.yaml --validation-split=0 --num-best-scores=10
 #
 # Parameters:
 # +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
diff --git a/examples/agp-pruning/resnet50.schedule_agp.filters_3.yaml b/examples/agp-pruning/resnet50.schedule_agp.filters_3.yaml
index d68c3fc..8c21bbe 100755
--- a/examples/agp-pruning/resnet50.schedule_agp.filters_3.yaml
+++ b/examples/agp-pruning/resnet50.schedule_agp.filters_3.yaml
@@ -5,7 +5,7 @@
 # No. of Parameters: 17,329,344 (of 25,502,912) = 67.95% dense (32.05% sparse)
 # Total MACs: 2,753,298,432 (of 4,089,184,256) = 67.33% compute = 1.49x
 #
-# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters_3.yaml --validation-size=0 --num-best-scores=10
+# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters_3.yaml --validation-split=0 --num-best-scores=10
 #
 # Parameters:
 # +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
diff --git a/examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml b/examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml
index 15cc334..1b20fa1 100755
--- a/examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml
+++ b/examples/agp-pruning/resnet50.schedule_agp.filters_with_FC.yaml
@@ -6,7 +6,7 @@
 # Best Top1: 74.564 (epoch 84) vs. 76.15 baseline (-1.6%)
 # No. of Parameters: 10,901,696 (of 25,502,912) = 42.74% dense (57.26% sparse)
 # Total MACs: 1,822,031,872 (of 4,089,184,256) = 44.56% compute = 2.24x
-# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ~/datasets/imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters.yaml --validation-size=0   --num-best-scores=10
+# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ~/datasets/imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters.yaml --validation-split=0   --num-best-scores=10
 #
 # Parameters:
 # +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 076ddf9..98ec9de 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -194,7 +194,8 @@ def main():
     # substring "_cifar", then cifar10 is used.
     train_loader, val_loader, test_loader, _ = apputils.load_data(
         args.dataset, os.path.expanduser(args.data), args.batch_size,
-        args.workers, args.validation_size, args.deterministic)
+        args.workers, args.validation_split, args.deterministic,
+        args.effective_train_size, args.effective_valid_size, args.effective_test_size)
     msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
                    len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler))
 
@@ -644,7 +645,8 @@ def automated_deep_compression(model, criterion, optimizer, loggers, args):
 
     train_loader, val_loader, test_loader, _ = apputils.load_data(
         args.dataset, os.path.expanduser(args.data), args.batch_size,
-        args.workers, args.validation_size, args.deterministic)
+        args.workers, args.validation_split, args.deterministic,
+        args.effective_train_size, args.effective_valid_size, args.effective_test_size)
 
     args.display_confusion = True
     validate_fn = partial(validate, val_loader=test_loader, criterion=criterion,
diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py
index 435c8a7..0bad76b 100755
--- a/examples/classifier_compression/parser.py
+++ b/examples/classifier_compression/parser.py
@@ -15,6 +15,7 @@
 #
 
 import argparse
+import operator
 
 import distiller
 import models
@@ -79,8 +80,19 @@ def getParser():
                         'Flag set => overrides the --gpus flag')
     parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
     parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
-    parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
+    parser.add_argument('--validation-split', '--valid-size', '--vs', dest='validation_split',
+                        type=float_range(exc_max=True), default=0.1,
                         help='Portion of training dataset to set aside for validation')
+    parser.add_argument('--effective-train-size', '--etrs', type=float_range(exc_min=True), default=1.,
+                        help='Portion of training dataset to be used in each epoch. '
+                             'NOTE: If --validation-split is set, then the value of this argument is applied '
+                             'AFTER the train-validation split according to that argument')
+    parser.add_argument('--effective-valid-size', '--evs', type=float_range(exc_min=True), default=1.,
+                        help='Portion of validation dataset to be used in each epoch. '
+                             'NOTE: If --validation-split is set, then the value of this argument is applied '
+                             'AFTER the train-validation split according to that argument')
+    parser.add_argument('--effective-test-size', '--etes', type=float_range(exc_min=True), default=1.,
+                        help='Portion of test dataset to be used in each epoch')
     parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
     parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
     parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true',
@@ -134,8 +146,15 @@ def getParser():
     return parser
 
 
-def float_range(val_str):
-    val = float(val_str)
-    if val < 0 or val >= 1:
-        raise argparse.ArgumentTypeError('Must be >= 0 and < 1 (received {0})'.format(val_str))
-    return val
+def float_range(min_val=0., max_val=1., exc_min=False, exc_max=False):
+    if min_val >= max_val:
+        raise ValueError('min_val must be less than max_val')
+
+    def checker(val_str):
+        val = float(val_str)
+        min_op, min_op_str = (operator.gt, '>') if exc_min else (operator.ge, '>=')
+        max_op, max_op_str = (operator.lt, '<') if exc_max else (operator.le, '<=')
+        if min_op(val, min_val) and max_op(val, max_val):
+            return val
+        raise argparse.ArgumentTypeError(
+            'Value must be {} {} and {} {} (received {})'.format(min_op_str, min_val, max_op_str, max_val, val))
+
+    return checker
diff --git a/examples/network_surgery/resnet20.network_surgery.yaml b/examples/network_surgery/resnet20.network_surgery.yaml
index ba8740e..1188ec8 100755
--- a/examples/network_surgery/resnet20.network_surgery.yaml
+++ b/examples/network_surgery/resnet20.network_surgery.yaml
@@ -31,7 +31,7 @@
 #     Total sparsity: 69.1%
 #     # of parameters: 83,671
 #
-# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.01 --epochs=180 --compress=../network_surgery/resnet20.network_surgery.yaml -j=1 --deterministic  --validation-size=0 --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --masks-sparsity --num-best-scores=10
+# time python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.01 --epochs=180 --compress=../network_surgery/resnet20.network_surgery.yaml -j=1 --deterministic  --validation-split=0 --resume=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --masks-sparsity --num-best-scores=10
 #
 # Parameters:
 # +----+-------------------------------------+----------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
diff --git a/examples/network_surgery/resnet50.network_surgery.yaml b/examples/network_surgery/resnet50.network_surgery.yaml
index d7d0588..9822701 100755
--- a/examples/network_surgery/resnet50.network_surgery.yaml
+++ b/examples/network_surgery/resnet50.network_surgery.yaml
@@ -5,7 +5,7 @@
 # Top1 is 75.492 (on Epoch: 93) vs the published Top1: 76.15 (https://pytorch.org/docs/stable/torchvision/models.html)
 # Total sparsity: 80.05
 #
-# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.network_surgery.yaml --validation-size=0  --masks-sparsity --num-best-scores=10
+# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.network_surgery.yaml --validation-split=0  --masks-sparsity --num-best-scores=10
 #
 # Parameters:
 # +----+-------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
diff --git a/examples/network_surgery/resnet50.network_surgery2.yaml b/examples/network_surgery/resnet50.network_surgery2.yaml
index 17612b9..2f50b58 100755
--- a/examples/network_surgery/resnet50.network_surgery2.yaml
+++ b/examples/network_surgery/resnet50.network_surgery2.yaml
@@ -5,7 +5,7 @@
 # Top1 is 75.518 (on Epoch: 99) vs the published Top1: 76.15 (https://pytorch.org/docs/stable/torchvision/models.html)
 # Total sparsity: 82.6%
 #
-# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.001 --compress=resnet50.network_surgery2.yaml --validation-size=0  --masks-sparsity --num-best-scores=10
+# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.001 --compress=resnet50.network_surgery2.yaml --validation-split=0  --masks-sparsity --num-best-scores=10
 #
 #
 # Parameters:
diff --git a/examples/network_trimming/resnet50.filters.activation_apoz_agp.yaml b/examples/network_trimming/resnet50.filters.activation_apoz_agp.yaml
index 2a02ef5..ebced2a 100755
--- a/examples/network_trimming/resnet50.filters.activation_apoz_agp.yaml
+++ b/examples/network_trimming/resnet50.filters.activation_apoz_agp.yaml
@@ -1,7 +1,7 @@
 #
 # This schedule uses the average percentage of zeros (APoZ) in the activations, to rank filters.
 #
-# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters.yaml --validation-size=0   --num-best-scores=10 --name="resnet50_filters_v5_APoZ"   --act-stats=valid
+# time python3 compress_classifier.py -a=resnet50 --pretrained -p=50 ../../../data.imagenet/ -j=22 --epochs=100 --lr=0.0005 --compress=resnet50.schedule_agp.filters.yaml --validation-split=0   --num-best-scores=10 --name="resnet50_filters_v5_APoZ"   --act-stats=valid
 #
 # Results:
 #   Best Top1: 73.926 on Epoch: 88
diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py
index c78d0fc..e361a61 100755
--- a/tests/full_flow_tests.py
+++ b/tests/full_flow_tests.py
@@ -118,13 +118,13 @@ test_configs = [
     TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [48.340, 92.630]),
     TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate'.
                format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
-               DS_CIFAR, accuracy_checker, [91.580, 99.620]),
+               DS_CIFAR, accuracy_checker, [91.640, 99.610]),
     TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
                format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')),
                DS_CIFAR, accuracy_checker, [48.290, 94.460]),
     TestConfig('-a resnet20_cifar --resume {0} --sense=filter --sense-range 0 0.10 0.05'.
                format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
-               DS_CIFAR, collateral_checker, [('sensitivity.csv', 3188), ('sensitivity.png', 96158)])
+               DS_CIFAR, collateral_checker, [('sensitivity.csv', 3165), ('sensitivity.png', 96158)])
 ]
 
 
-- 
GitLab