From 697b3cfee5b23b205cc843377ce1a054c6bed0c4 Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Mon, 2 Dec 2019 16:38:59 +0200
Subject: [PATCH] Object Detection Compression (#343)

Add an example of compressing PyTorch object detection (OD) models.

In this example we compress torchvision's object detection models - FasterRCNN / MaskRCNN / KeypointRCNN.
We've modified the reference code for object detection to allow easy compression scheduling with YAML configuration.
---
 distiller/model_transforms.py                 |   3 +-
 distiller/models/__init__.py                  |   4 +-
 distiller/quantization/sim_bn_fold.py         |  21 +-
 .../object_detection_compression/README.md    |  57 +++
 .../object_detection_compression/__init__.py  |   0
 .../object_detection_compression/coco_eval.py | 351 +++++++++++++++++
 .../coco_utils.py                             | 253 +++++++++++++
 .../compress_detector.py                      | 356 ++++++++++++++++++
 .../data/download_dataset.sh                  |   9 +
 .../object_detection_compression/engine.py    | 125 ++++++
 .../group_by_aspect_ratio.py                  | 189 ++++++++++
 .../object_detection_compression/logging.conf |  54 +++
 .../maskrcnn.scheduler_agp.non_parallel.yaml  | 210 +++++++++++
 .../maskrcnn.scheduler_agp.yaml               | 267 +++++++++++++
 .../requirements.txt                          |   2 +
 .../transforms.py                             |  52 +++
 .../object_detection_compression/utils.py     | 328 ++++++++++++++++
 17 files changed, 2274 insertions(+), 7 deletions(-)
 create mode 100644 examples/object_detection_compression/README.md
 create mode 100644 examples/object_detection_compression/__init__.py
 create mode 100644 examples/object_detection_compression/coco_eval.py
 create mode 100644 examples/object_detection_compression/coco_utils.py
 create mode 100644 examples/object_detection_compression/compress_detector.py
 create mode 100644 examples/object_detection_compression/data/download_dataset.sh
 create mode 100644 examples/object_detection_compression/engine.py
 create mode 100644 examples/object_detection_compression/group_by_aspect_ratio.py
 create mode 100644 examples/object_detection_compression/logging.conf
 create mode 100644 examples/object_detection_compression/maskrcnn.scheduler_agp.non_parallel.yaml
 create mode 100644 examples/object_detection_compression/maskrcnn.scheduler_agp.yaml
 create mode 100644 examples/object_detection_compression/requirements.txt
 create mode 100644 examples/object_detection_compression/transforms.py
 create mode 100644 examples/object_detection_compression/utils.py

diff --git a/distiller/model_transforms.py b/distiller/model_transforms.py
index 5761431..647d2ed 100644
--- a/distiller/model_transforms.py
+++ b/distiller/model_transforms.py
@@ -16,6 +16,7 @@
 
 import torch
 import torch.nn as nn
+from torchvision.ops.misc import FrozenBatchNorm2d
 from collections import OrderedDict
 import distiller
 import distiller.modules
@@ -129,7 +130,7 @@ def fold_batch_norms(model, dummy_input=None, adjacency_map=None, inference=True
         return folded_module
 
     foldables = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
-    batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
+    batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, FrozenBatchNorm2d)
     return fuse_modules(model, (foldables, batchnorms), fold_bn, dummy_input, adjacency_map)
 
 
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 8f94bc9..366edd4 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -20,13 +20,15 @@ import copy
 from functools import partial
 import torch
 import torchvision.models as torch_models
+from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
+from torchvision.ops.misc import FrozenBatchNorm2d
 import torch.nn as nn
 from . import cifar10 as cifar10_models
 from . import mnist as mnist_models
 from . import imagenet as imagenet_extra_models
 import pretrainedmodels
 
-from distiller.utils import set_model_input_shape_attr
+from distiller.utils import set_model_input_shape_attr, model_setattr
 from distiller.modules import Mean, EltwiseAdd
 
 import logging
diff --git a/distiller/quantization/sim_bn_fold.py b/distiller/quantization/sim_bn_fold.py
index 8c3a6f3..5fa17ea 100644
--- a/distiller/quantization/sim_bn_fold.py
+++ b/distiller/quantization/sim_bn_fold.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torchvision.ops.misc import FrozenBatchNorm2d
 
 __all__ = ['SimulatedFoldedBatchNorm']
 
@@ -28,22 +29,32 @@ class SimulatedFoldedBatchNorm(nn.Module):
         Wrapper for simulated folding of BatchNorm into convolution / linear layers during training
         Args:
             param_module (nn.Linear or nn.Conv1d or nn.Conv2d or nn.Conv3d): the wrapped parameter module
-            bn (nn.BatchNorm1d or nn.BatchNorm2d or nn.BatchNorm3d): batch normalization module
+            bn (nn.BatchNorm1d or nn.BatchNorm2d or nn.BatchNorm3d or FrozenBatchNorm2d): batch normalization module
             freeze_bn_delay (int): number of steps before freezing the batch-norm running stats
             param_quantization_fn (function): function to be used for weight/bias quantization
         Note:
             The quantized version was implemented according to https://arxiv.org/pdf/1806.08342.pdf Section 3.2.2.
         """
-        SimulatedFoldedBatchNorm.verify_module_types(param_module, bn)
-        if not bn.track_running_stats:
-            raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats")
         super(SimulatedFoldedBatchNorm, self).__init__()
         self.param_module = param_module
         self.bn = bn
         self.freeze_bn_delay = freeze_bn_delay
         self.frozen = False
         self._has_bias = (self.param_module.bias is not None)
         self.param_quant_fn = param_quantization_fn
+        if isinstance(bn, FrozenBatchNorm2d):
+            if not isinstance(param_module, nn.Conv2d):
+                error_msg = "Can't fold sequence of {} --> {}. ".format(
+                    param_module.__class__.__name__, bn.__class__.__name__
+                )
+                raise TypeError(error_msg + 'FrozenBatchNorm2d must follow an nn.Conv2d.')
+            # This torchvision op is frozen from the beginning, so we fold it
+            # directly into the convolution that precedes it.
+            self.freeze()
+            return
+        SimulatedFoldedBatchNorm.verify_module_types(param_module, bn)
+        if not bn.track_running_stats:
+            raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats")
         if isinstance(param_module, nn.Linear):
             self.param_forward_fn = self._linear_layer_forward
             self.param_module_type = "fc"
diff --git a/examples/object_detection_compression/README.md b/examples/object_detection_compression/README.md
new file mode 100644
index 0000000..9ac368a
--- /dev/null
+++ b/examples/object_detection_compression/README.md
@@ -0,0 +1,57 @@
+# Object Detection Compression
+
+In this example we compress torchvision's object detection models - FasterRCNN / MaskRCNN / KeypointRCNN.  
+We've modified the [reference code for object detection](https://github.com/pytorch/vision/tree/master/references/detection)
+to allow easy compression scheduling with YAML configuration.
+
+## Setup
+Install the dependencies
+(most of which are already installed as Distiller dependencies; the rest are `Cython` and `pycocotools`):
+    
+    cd <distiller root>/examples/object_detection_compression/
+    pip3 install -r requirements.txt 
+
+The dataset can be downloaded from the [COCO dataset website](http://cocodataset.org/#download).
+Please keep in mind that the COCO dataset takes up 18 GB of storage.  
+In this example we'll use the 2017 training+validation sets, which you can download using the command line:
+
+    cd data
+    bash download_dataset.sh
+
+## Running the Example
+The command line for running this example is closely related to
+[`compress_classifier.py`](../classifier_compression/compress_classifier.py), i.e. the
+compression scheduler format and most of the Distiller-related arguments are the same.  
+However, running in a multi-GPU environment is different from `compress_classifier.py`, because this script is a modified
+[`train.py` from the torchvision references](https://github.com/pytorch/vision/tree/master/references/detection/train.py),
+which uses `torch.distributed.launch` for multi-GPU (and, more generally, multi-node) training.
+
+**Note:** `torch.distributed.launch` spawns multiple processes; each process holds its own copy of
+the model and its weights, wrapped in an instance of
+[`torch.nn.parallel.DistributedDataParallel`](https://pytorch.org/docs/stable/nn.html#distributeddataparallel).
+During the backward pass, the gradients from all processes are averaged and shared with every process,
+which guarantees that the weights stay identical across processes.
+This also guarantees that the pruning masks remain identical on all processes.
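+
+Below is a minimal sketch (for illustration only, not part of this example's code) of the per-process
+model wrapping that `compress_detector.py` performs; it assumes the script was launched with
+`torch.distributed.launch --use_env`, which sets `LOCAL_RANK` and the other rendezvous environment variables:
+
+    # Illustrative sketch only - assumes launch via `torch.distributed.launch --use_env`.
+    import os
+    import torch
+    import torch.distributed as dist
+    import torchvision.models.detection as detection
+
+    def wrap_for_ddp():
+        local_rank = int(os.environ["LOCAL_RANK"])   # set by the launcher when --use_env is given
+        dist.init_process_group(backend="nccl")      # MASTER_ADDR/PORT, RANK, WORLD_SIZE come from the launcher
+        device = torch.device("cuda", local_rank)
+        model = detection.maskrcnn_resnet50_fpn(pretrained=True).to(device)
+        # Each process holds a full replica; backward() averages the gradients across processes,
+        # so the weights - and therefore the pruning masks - stay identical everywhere.
+        return torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])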
+ 
+Example single-GPU command line:
+ 
+    python compress_detector.py --data-path /path/to/dataset.COCO --pretrained --compress maskrcnn.scheduler_agp.non_parallel.yaml
+
+Example multi-GPU command line:
+ 
+    python -m torch.distributed.launch --nproc_per_node=4 --use_env compress_detector.py --data-path /path/to/dataset.COCO \
+     --compress maskrcnn.scheduler_agp.yaml --pretrained --world-size 4 --batch-size 2 --epochs 80
+
+Since the dataset is large and the FasterRCNN models are compute-heavy, we strongly recommend
+running the script in a multi-GPU environment. Keep in mind that the multi-GPU case runs
+multiple processes via `torch.distributed.launch`, and killing one of the processes
+might break all of them and leave them in an undefined state (in that case you'll have to kill
+them manually). Also, even though multi-GPU training distributes the memory load over all the GPUs, the
+model is quite memory-intensive, so using a large batch size will quickly cause out-of-memory (OOM) errors on the GPU.
+Our GPUs are TITAN X (Pascal) cards with 12GB of memory, and a batch size of 3 is the largest we could run without
+memory errors.
+
+
+The default model is `torchvision.models.detection.maskrcnn_resnet50_fpn`. You can specify
+any model that is part of `torchvision.models.detection` using
+the `--model` argument, e.g. `--model maskrcnn_resnet50_fpn`.
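+
+As a quick illustration (not part of this example's code), you can list the names that `--model`
+accepts - the script simply looks the name up in `torchvision.models.detection.__dict__`:
+
+    import torchvision.models.detection as detection
+    print(sorted(name for name in detection.__dict__
+                 if name.islower() and not name.startswith('_') and callable(detection.__dict__[name])))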
\ No newline at end of file
diff --git a/examples/object_detection_compression/__init__.py b/examples/object_detection_compression/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/object_detection_compression/coco_eval.py b/examples/object_detection_compression/coco_eval.py
new file mode 100644
index 0000000..5b68514
--- /dev/null
+++ b/examples/object_detection_compression/coco_eval.py
@@ -0,0 +1,351 @@
+# This code is originally from:
+#   https://github.com/pytorch/vision/tree/v0.4.2/references/detection
+import json
+import tempfile
+
+import numpy as np
+import copy
+import time
+import torch
+import torch._six
+
+from pycocotools.cocoeval import COCOeval
+from pycocotools.coco import COCO
+import pycocotools.mask as mask_util
+
+from collections import defaultdict
+
+import utils
+
+
+class CocoEvaluator(object):
+    def __init__(self, coco_gt, iou_types):
+        assert isinstance(iou_types, (list, tuple))
+        coco_gt = copy.deepcopy(coco_gt)
+        self.coco_gt = coco_gt
+
+        self.iou_types = iou_types
+        self.coco_eval = {}
+        for iou_type in iou_types:
+            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
+
+        self.img_ids = []
+        self.eval_imgs = {k: [] for k in iou_types}
+
+    def update(self, predictions):
+        img_ids = list(np.unique(list(predictions.keys())))
+        self.img_ids.extend(img_ids)
+
+        for iou_type in self.iou_types:
+            results = self.prepare(predictions, iou_type)
+            coco_dt = loadRes(self.coco_gt, results) if results else COCO()
+            coco_eval = self.coco_eval[iou_type]
+
+            coco_eval.cocoDt = coco_dt
+            coco_eval.params.imgIds = list(img_ids)
+            img_ids, eval_imgs = evaluate(coco_eval)
+
+            self.eval_imgs[iou_type].append(eval_imgs)
+
+    def synchronize_between_processes(self):
+        for iou_type in self.iou_types:
+            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
+            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
+
+    def accumulate(self):
+        for coco_eval in self.coco_eval.values():
+            coco_eval.accumulate()
+
+    def summarize(self):
+        for iou_type, coco_eval in self.coco_eval.items():
+            print("IoU metric: {}".format(iou_type))
+            coco_eval.summarize()
+
+    def prepare(self, predictions, iou_type):
+        if iou_type == "bbox":
+            return self.prepare_for_coco_detection(predictions)
+        elif iou_type == "segm":
+            return self.prepare_for_coco_segmentation(predictions)
+        elif iou_type == "keypoints":
+            return self.prepare_for_coco_keypoint(predictions)
+        else:
+            raise ValueError("Unknown iou type {}".format(iou_type))
+
+    def prepare_for_coco_detection(self, predictions):
+        coco_results = []
+        for original_id, prediction in predictions.items():
+            if len(prediction) == 0:
+                continue
+
+            boxes = prediction["boxes"]
+            boxes = convert_to_xywh(boxes).tolist()
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+
+            coco_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        "bbox": box,
+                        "score": scores[k],
+                    }
+                    for k, box in enumerate(boxes)
+                ]
+            )
+        return coco_results
+
+    def prepare_for_coco_segmentation(self, predictions):
+        coco_results = []
+        for original_id, prediction in predictions.items():
+            if len(prediction) == 0:
+                continue
+
+            scores = prediction["scores"]
+            labels = prediction["labels"]
+            masks = prediction["masks"]
+
+            masks = masks > 0.5
+
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+
+            rles = [
+                mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
+                for mask in masks
+            ]
+            for rle in rles:
+                rle["counts"] = rle["counts"].decode("utf-8")
+
+            coco_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        "segmentation": rle,
+                        "score": scores[k],
+                    }
+                    for k, rle in enumerate(rles)
+                ]
+            )
+        return coco_results
+
+    def prepare_for_coco_keypoint(self, predictions):
+        coco_results = []
+        for original_id, prediction in predictions.items():
+            if len(prediction) == 0:
+                continue
+
+            boxes = prediction["boxes"]
+            boxes = convert_to_xywh(boxes).tolist()
+            scores = prediction["scores"].tolist()
+            labels = prediction["labels"].tolist()
+            keypoints = prediction["keypoints"]
+            keypoints = keypoints.flatten(start_dim=1).tolist()
+
+            coco_results.extend(
+                [
+                    {
+                        "image_id": original_id,
+                        "category_id": labels[k],
+                        'keypoints': keypoint,
+                        "score": scores[k],
+                    }
+                    for k, keypoint in enumerate(keypoints)
+                ]
+            )
+        return coco_results
+
+
+def convert_to_xywh(boxes):
+    xmin, ymin, xmax, ymax = boxes.unbind(1)
+    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
+
+
+def merge(img_ids, eval_imgs):
+    all_img_ids = utils.all_gather(img_ids)
+    all_eval_imgs = utils.all_gather(eval_imgs)
+
+    merged_img_ids = []
+    for p in all_img_ids:
+        merged_img_ids.extend(p)
+
+    merged_eval_imgs = []
+    for p in all_eval_imgs:
+        merged_eval_imgs.append(p)
+
+    merged_img_ids = np.array(merged_img_ids)
+    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
+
+    # keep only unique (and in sorted order) images
+    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
+    merged_eval_imgs = merged_eval_imgs[..., idx]
+
+    return merged_img_ids, merged_eval_imgs
+
+
+def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
+    img_ids, eval_imgs = merge(img_ids, eval_imgs)
+    img_ids = list(img_ids)
+    eval_imgs = list(eval_imgs.flatten())
+
+    coco_eval.evalImgs = eval_imgs
+    coco_eval.params.imgIds = img_ids
+    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
+
+
+#################################################################
+# From pycocotools, just removed the prints and fixed
+# a Python3 bug about unicode not defined
+#################################################################
+
+# Ideally, pycocotools wouldn't have hard-coded prints
+# so that we could avoid copy-pasting those two functions
+
+def createIndex(self):
+    # create index
+    # print('creating index...')
+    anns, cats, imgs = {}, {}, {}
+    imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
+    if 'annotations' in self.dataset:
+        for ann in self.dataset['annotations']:
+            imgToAnns[ann['image_id']].append(ann)
+            anns[ann['id']] = ann
+
+    if 'images' in self.dataset:
+        for img in self.dataset['images']:
+            imgs[img['id']] = img
+
+    if 'categories' in self.dataset:
+        for cat in self.dataset['categories']:
+            cats[cat['id']] = cat
+
+    if 'annotations' in self.dataset and 'categories' in self.dataset:
+        for ann in self.dataset['annotations']:
+            catToImgs[ann['category_id']].append(ann['image_id'])
+
+    # print('index created!')
+
+    # create class members
+    self.anns = anns
+    self.imgToAnns = imgToAnns
+    self.catToImgs = catToImgs
+    self.imgs = imgs
+    self.cats = cats
+
+
+maskUtils = mask_util
+
+
+def loadRes(self, resFile):
+    """
+    Load result file and return a result api object.
+    :param   resFile (str)     : file name of result file
+    :return: res (obj)         : result api object
+    """
+    res = COCO()
+    res.dataset['images'] = [img for img in self.dataset['images']]
+
+    # print('Loading and preparing results...')
+    # tic = time.time()
+    if isinstance(resFile, torch._six.string_classes):
+        anns = json.load(open(resFile))
+    elif type(resFile) == np.ndarray:
+        anns = self.loadNumpyAnnotations(resFile)
+    else:
+        anns = resFile
+    assert type(anns) == list, 'results in not an array of objects'
+    annsImgIds = [ann['image_id'] for ann in anns]
+    assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
+        'Results do not correspond to current coco set'
+    if 'caption' in anns[0]:
+        imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
+        res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
+        for id, ann in enumerate(anns):
+            ann['id'] = id + 1
+    elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
+        res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
+        for id, ann in enumerate(anns):
+            bb = ann['bbox']
+            x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
+            if 'segmentation' not in ann:
+                ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
+            ann['area'] = bb[2] * bb[3]
+            ann['id'] = id + 1
+            ann['iscrowd'] = 0
+    elif 'segmentation' in anns[0]:
+        res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
+        for id, ann in enumerate(anns):
+            # now only support compressed RLE format as segmentation results
+            ann['area'] = maskUtils.area(ann['segmentation'])
+            if 'bbox' not in ann:
+                ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
+            ann['id'] = id + 1
+            ann['iscrowd'] = 0
+    elif 'keypoints' in anns[0]:
+        res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
+        for id, ann in enumerate(anns):
+            s = ann['keypoints']
+            x = s[0::3]
+            y = s[1::3]
+            x1, x2, y1, y2 = np.min(x), np.max(x), np.min(y), np.max(y)
+            ann['area'] = (x2 - x1) * (y2 - y1)
+            ann['id'] = id + 1
+            ann['bbox'] = [x1, y1, x2 - x1, y2 - y1]
+    # print('DONE (t={:0.2f}s)'.format(time.time()- tic))
+
+    res.dataset['annotations'] = anns
+    createIndex(res)
+    return res
+
+
+def evaluate(self):
+    '''
+    Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
+    :return: None
+    '''
+    # tic = time.time()
+    # print('Running per image evaluation...')
+    p = self.params
+    # add backward compatibility if useSegm is specified in params
+    if p.useSegm is not None:
+        p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
+        print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
+    # print('Evaluate annotation type *{}*'.format(p.iouType))
+    p.imgIds = list(np.unique(p.imgIds))
+    if p.useCats:
+        p.catIds = list(np.unique(p.catIds))
+    p.maxDets = sorted(p.maxDets)
+    self.params = p
+
+    self._prepare()
+    # loop through images, area range, max detection number
+    catIds = p.catIds if p.useCats else [-1]
+
+    if p.iouType == 'segm' or p.iouType == 'bbox':
+        computeIoU = self.computeIoU
+    elif p.iouType == 'keypoints':
+        computeIoU = self.computeOks
+    self.ious = {
+        (imgId, catId): computeIoU(imgId, catId)
+        for imgId in p.imgIds
+        for catId in catIds}
+
+    evaluateImg = self.evaluateImg
+    maxDet = p.maxDets[-1]
+    evalImgs = [
+        evaluateImg(imgId, catId, areaRng, maxDet)
+        for catId in catIds
+        for areaRng in p.areaRng
+        for imgId in p.imgIds
+    ]
+    # this is NOT in the pycocotools code, but could be done outside
+    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
+    self._paramsEval = copy.deepcopy(self.params)
+    # toc = time.time()
+    # print('DONE (t={:0.2f}s).'.format(toc-tic))
+    return p.imgIds, evalImgs
+
+#################################################################
+# end of straight copy from pycocotools, just removing the prints
+#################################################################
diff --git a/examples/object_detection_compression/coco_utils.py b/examples/object_detection_compression/coco_utils.py
new file mode 100644
index 0000000..b9bb5e8
--- /dev/null
+++ b/examples/object_detection_compression/coco_utils.py
@@ -0,0 +1,253 @@
+# This code is originally from:
+#   https://github.com/pytorch/vision/tree/v0.4.2/references/detection
+import copy
+import os
+from PIL import Image
+
+import torch
+import torch.utils.data
+import torchvision
+
+from pycocotools import mask as coco_mask
+from pycocotools.coco import COCO
+
+import transforms as T
+
+
+class FilterAndRemapCocoCategories(object):
+    def __init__(self, categories, remap=True):
+        self.categories = categories
+        self.remap = remap
+
+    def __call__(self, image, target):
+        anno = target["annotations"]
+        anno = [obj for obj in anno if obj["category_id"] in self.categories]
+        if not self.remap:
+            target["annotations"] = anno
+            return image, target
+        anno = copy.deepcopy(anno)
+        for obj in anno:
+            obj["category_id"] = self.categories.index(obj["category_id"])
+        target["annotations"] = anno
+        return image, target
+
+
+def convert_coco_poly_to_mask(segmentations, height, width):
+    masks = []
+    for polygons in segmentations:
+        rles = coco_mask.frPyObjects(polygons, height, width)
+        mask = coco_mask.decode(rles)
+        if len(mask.shape) < 3:
+            mask = mask[..., None]
+        mask = torch.as_tensor(mask, dtype=torch.uint8)
+        mask = mask.any(dim=2)
+        masks.append(mask)
+    if masks:
+        masks = torch.stack(masks, dim=0)
+    else:
+        masks = torch.zeros((0, height, width), dtype=torch.uint8)
+    return masks
+
+
+class ConvertCocoPolysToMask(object):
+    def __call__(self, image, target):
+        w, h = image.size
+
+        image_id = target["image_id"]
+        image_id = torch.tensor([image_id])
+
+        anno = target["annotations"]
+
+        anno = [obj for obj in anno if obj['iscrowd'] == 0]
+
+        boxes = [obj["bbox"] for obj in anno]
+        # guard against no boxes via resizing
+        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+        boxes[:, 2:] += boxes[:, :2]
+        boxes[:, 0::2].clamp_(min=0, max=w)
+        boxes[:, 1::2].clamp_(min=0, max=h)
+
+        classes = [obj["category_id"] for obj in anno]
+        classes = torch.tensor(classes, dtype=torch.int64)
+
+        segmentations = [obj["segmentation"] for obj in anno]
+        masks = convert_coco_poly_to_mask(segmentations, h, w)
+
+        keypoints = None
+        if anno and "keypoints" in anno[0]:
+            keypoints = [obj["keypoints"] for obj in anno]
+            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
+            num_keypoints = keypoints.shape[0]
+            if num_keypoints:
+                keypoints = keypoints.view(num_keypoints, -1, 3)
+
+        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+        boxes = boxes[keep]
+        classes = classes[keep]
+        masks = masks[keep]
+        if keypoints is not None:
+            keypoints = keypoints[keep]
+
+        target = {}
+        target["boxes"] = boxes
+        target["labels"] = classes
+        target["masks"] = masks
+        target["image_id"] = image_id
+        if keypoints is not None:
+            target["keypoints"] = keypoints
+
+        # for conversion to coco api
+        area = torch.tensor([obj["area"] for obj in anno])
+        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
+        target["area"] = area
+        target["iscrowd"] = iscrowd
+
+        return image, target
+
+
+def _coco_remove_images_without_annotations(dataset, cat_list=None):
+    def _has_only_empty_bbox(anno):
+        return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
+
+    def _count_visible_keypoints(anno):
+        return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
+
+    min_keypoints_per_image = 10
+
+    def _has_valid_annotation(anno):
+        # if it's empty, there is no annotation
+        if len(anno) == 0:
+            return False
+        # if all boxes have close to zero area, there is no annotation
+        if _has_only_empty_bbox(anno):
+            return False
+        # keypoint tasks have slightly different criteria for considering
+        # if an annotation is valid
+        if "keypoints" not in anno[0]:
+            return True
+        # for keypoint detection tasks, only consider valid images those
+        # containing at least min_keypoints_per_image
+        if _count_visible_keypoints(anno) >= min_keypoints_per_image:
+            return True
+        return False
+
+    assert isinstance(dataset, torchvision.datasets.CocoDetection)
+    ids = []
+    for ds_idx, img_id in enumerate(dataset.ids):
+        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
+        anno = dataset.coco.loadAnns(ann_ids)
+        if cat_list:
+            anno = [obj for obj in anno if obj["category_id"] in cat_list]
+        if _has_valid_annotation(anno):
+            ids.append(ds_idx)
+
+    dataset = torch.utils.data.Subset(dataset, ids)
+    return dataset
+
+
+def convert_to_coco_api(ds):
+    coco_ds = COCO()
+    ann_id = 0
+    dataset = {'images': [], 'categories': [], 'annotations': []}
+    categories = set()
+    for img_idx in range(len(ds)):
+        # find better way to get target
+        # targets = ds.get_annotations(img_idx)
+        img, targets = ds[img_idx]
+        image_id = targets["image_id"].item()
+        img_dict = {}
+        img_dict['id'] = image_id
+        img_dict['height'] = img.shape[-2]
+        img_dict['width'] = img.shape[-1]
+        dataset['images'].append(img_dict)
+        bboxes = targets["boxes"]
+        bboxes[:, 2:] -= bboxes[:, :2]
+        bboxes = bboxes.tolist()
+        labels = targets['labels'].tolist()
+        areas = targets['area'].tolist()
+        iscrowd = targets['iscrowd'].tolist()
+        if 'masks' in targets:
+            masks = targets['masks']
+            # make masks Fortran contiguous for coco_mask
+            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
+        if 'keypoints' in targets:
+            keypoints = targets['keypoints']
+            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
+        num_objs = len(bboxes)
+        for i in range(num_objs):
+            ann = {}
+            ann['image_id'] = image_id
+            ann['bbox'] = bboxes[i]
+            ann['category_id'] = labels[i]
+            categories.add(labels[i])
+            ann['area'] = areas[i]
+            ann['iscrowd'] = iscrowd[i]
+            ann['id'] = ann_id
+            if 'masks' in targets:
+                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
+            if 'keypoints' in targets:
+                ann['keypoints'] = keypoints[i]
+                ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3])
+            dataset['annotations'].append(ann)
+            ann_id += 1
+    dataset['categories'] = [{'id': i} for i in sorted(categories)]
+    coco_ds.dataset = dataset
+    coco_ds.createIndex()
+    return coco_ds
+
+
+def get_coco_api_from_dataset(dataset):
+    for _ in range(10):
+        if isinstance(dataset, torchvision.datasets.CocoDetection):
+            break
+        if isinstance(dataset, torch.utils.data.Subset):
+            dataset = dataset.dataset
+    if isinstance(dataset, torchvision.datasets.CocoDetection):
+        return dataset.coco
+    return convert_to_coco_api(dataset)
+
+
+class CocoDetection(torchvision.datasets.CocoDetection):
+    def __init__(self, img_folder, ann_file, transforms):
+        super(CocoDetection, self).__init__(img_folder, ann_file)
+        self._transforms = transforms
+
+    def __getitem__(self, idx):
+        img, target = super(CocoDetection, self).__getitem__(idx)
+        image_id = self.ids[idx]
+        target = dict(image_id=image_id, annotations=target)
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+        return img, target
+
+
+def get_coco(root, image_set, transforms, mode='instances'):
+    anno_file_template = "{}_{}2017.json"
+    PATHS = {
+        "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
+        "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
+        # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
+    }
+
+    t = [ConvertCocoPolysToMask()]
+
+    if transforms is not None:
+        t.append(transforms)
+    transforms = T.Compose(t)
+
+    img_folder, ann_file = PATHS[image_set]
+    img_folder = os.path.join(root, img_folder)
+    ann_file = os.path.join(root, ann_file)
+
+    dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
+
+    if image_set == "train":
+        dataset = _coco_remove_images_without_annotations(dataset)
+
+    # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
+
+    return dataset
+
+
+def get_coco_kp(root, image_set, transforms):
+    return get_coco(root, image_set, transforms, mode="person_keypoints")
diff --git a/examples/object_detection_compression/compress_detector.py b/examples/object_detection_compression/compress_detector.py
new file mode 100644
index 0000000..4f477c3
--- /dev/null
+++ b/examples/object_detection_compression/compress_detector.py
@@ -0,0 +1,356 @@
+# This code is originally from:
+#   https://github.com/pytorch/vision/tree/v0.4.2/references/detection/train.py
+# It contains code to support compression (distiller)
+r"""PyTorch Detection Training.
+
+To run in a multi-gpu environment, use the distributed launcher::
+
+    python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \
+        compress_detector.py ... --world-size $NGPU
+
+"""
+import datetime
+import os
+import time
+
+import torch
+import torch.utils.data
+from torch import nn
+import torchvision
+import torchvision.models.detection as detection
+from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
+from torchvision.ops.misc import FrozenBatchNorm2d
+import torch.distributed as dist
+
+import distiller
+from distiller.data_loggers import *
+import distiller.apputils as apputils
+import distiller.pruning
+import distiller.models
+from distiller.model_transforms import fold_batch_norms
+
+from coco_utils import get_coco, get_coco_kp
+
+from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
+from engine import train_one_epoch, evaluate
+
+import utils
+import transforms as T
+
+import logging
+logging.getLogger().setLevel(logging.INFO)  # Allow distiller info to be logged.
+
+
+def get_dataset(name, image_set, transform, data_path):
+    paths = {
+        "coco": (data_path, get_coco, 91),
+        "coco_kp": (data_path, get_coco_kp, 2)
+    }
+    p, ds_fn, num_classes = paths[name]
+
+    ds = ds_fn(p, image_set=image_set, transforms=transform)
+    return ds, num_classes
+
+
+def get_transform(train):
+    transforms = []
+    transforms.append(T.ToTensor())
+    if train:
+        transforms.append(T.RandomHorizontalFlip(0.5))
+    return T.Compose(transforms)
+
+
+def patch_fastrcnn(model):
+    """
+    TODO - complete quantization example
+    Partial patch for torchvision's FastRCNN models to allow quantization, by replacing all FrozenBatchNorm2d
+    with regular nn.BatchNorm2d-s.
+    Args:
+        model (GeneralizedRCNN): the model to patch
+    """
+    assert isinstance(model, GeneralizedRCNN)
+
+    def replace_frozen_bn(frozen_bn: FrozenBatchNorm2d):
+        num_features = frozen_bn.weight.shape[0]
+        bn = nn.BatchNorm2d(num_features)
+        eps = bn.eps
+        bn.weight.data = frozen_bn.weight.data
+        bn.bias.data = frozen_bn.bias.data
+        bn.running_mean.data = frozen_bn.running_mean.data
+        bn.running_var.data = frozen_bn.running_var.data
+        return bn.eval()
+
+    for n, m in model.named_modules():
+        if isinstance(m, FrozenBatchNorm2d):
+            distiller.model_setattr(model, n, replace_frozen_bn(m))
+
+
+def main(args):
+    utils.init_distributed_mode(args)
+    print(args)
+
+    device = torch.device(args.device)
+
+    script_dir = os.path.dirname(__file__)
+    module_path = os.path.abspath(os.path.join(script_dir, '..', '..'))
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    if utils.is_main_process():
+        msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir,
+                                             args.verbose)
+
+        # Log various details about the execution environment.  It is sometimes useful
+        # to refer to past experiment executions and this information may be useful.
+        apputils.log_execution_env_state(
+            filter(None, [args.compress, args.qe_stats_file]),  # remove both None and empty strings
+            msglogger.logdir)
+        msglogger.debug("Distiller: %s", distiller.__version__)
+    else:
+        msglogger = logging.getLogger()
+        msglogger.disabled = True
+
+    # Data loading code
+    print("Loading data")
+    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path)
+    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path)
+
+    print("Creating data loaders")
+    if args.distributed:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
+    else:
+        train_sampler = torch.utils.data.RandomSampler(dataset)
+        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+
+    if args.aspect_ratio_group_factor >= 0:
+        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
+        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
+    else:
+        train_batch_sampler = torch.utils.data.BatchSampler(
+            train_sampler, args.batch_size, drop_last=True)
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
+        collate_fn=utils.collate_fn)
+
+    data_loader_test = torch.utils.data.DataLoader(
+        dataset_test, batch_size=1,
+        sampler=test_sampler, num_workers=args.workers,
+        collate_fn=utils.collate_fn)
+
+    print("Creating model")
+    model = detection.__dict__[args.model](num_classes=num_classes,
+                                           pretrained=args.pretrained)
+    patch_fastrcnn(model)
+    model.to(device)
+
+    if args.summary and utils.is_main_process():
+        for summary in args.summary:
+            distiller.model_summary(model, summary, args.dataset)
+        return
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    params = [p for p in model.parameters() if p.requires_grad]
+    optimizer = torch.optim.SGD(
+        params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+
+    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
+
+    compression_scheduler = None
+    if utils.is_main_process():
+        # Create a couple of logging backends.  TensorBoardLogger writes log files in a format
+        # that can be read by Google's Tensor Board.  PythonLogger writes to the Python logger.
+        tflogger = TensorBoardLogger(msglogger.logdir)
+        pylogger = PythonLogger(msglogger)
+
+    if args.compress:
+        # The main use-case for this sample application is CNN compression. Compression
+        # requires a compression schedule configuration file in YAML.
+        compression_scheduler = distiller.file_config(model, optimizer, args.compress, compression_scheduler, None)
+        # Model is re-transferred to GPU in case parameters were added (e.g. PACTQuantizer)
+        model.to(args.device)
+    elif compression_scheduler is None:
+        compression_scheduler = distiller.CompressionScheduler(model)
+
+    if args.qe_calibration:
+        def test_fn(model):
+            return evaluate(model, data_loader_test, device=device)
+        collect_quant_stats(model_without_ddp, test_fn, save_dir=args.output_dir,
+                            modules_to_collect=['backbone', 'rpn', 'roi_heads'])
+        # We skip `.transform` because it is a pre-processing unit.
+        return
+
+    if args.resume:
+        checkpoint = torch.load(args.resume, map_location='cpu')
+        model_without_ddp.load_state_dict(checkpoint['model'])
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        if compression_scheduler and 'compression_scheduler' in checkpoint:
+            compression_scheduler.load_state_dict(checkpoint['compression_scheduler'])
+
+    if args.test_only:
+        evaluate(model, data_loader_test, device=device)
+        return
+    activations_collectors = create_activation_stats_collectors(model, *args.activation_stats)
+    print("Start training")
+    start_time = time.time()
+
+    # if not isinstance(model, nn.DataParallel) and torch.cuda.is_available() \
+    #    and torch.cuda.device_count() > 1:
+    #     msglogger.info("Using %d GPUs on DataParallel." % torch.cuda.device_count())
+    #     model = nn.DataParallel(model)
+
+    for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            train_sampler.set_epoch(epoch)
+            dist.barrier()
+
+        if compression_scheduler:
+            compression_scheduler.on_epoch_begin(epoch)
+
+        with collectors_context(activations_collectors["train"]) as collectors:
+            train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, compression_scheduler)
+            if utils.is_main_process():
+                distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
+                distiller.log_activation_statsitics(epoch, "train", loggers=[tflogger],
+                                                    collector=collectors["sparsity"])
+            if args.masks_sparsity and utils.is_main_process():
+                msglogger.info(distiller.masks_sparsity_tbl_summary(model, compression_scheduler))
+
+        lr_scheduler.step()
+        if args.output_dir:
+            save_dict = {
+                'model': model_without_ddp.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'lr_scheduler': lr_scheduler.state_dict(),
+                'args': args}
+            if compression_scheduler:
+                save_dict['compression_scheduler'] = compression_scheduler.state_dict()
+            utils.save_on_master(save_dict,
+                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
+
+        # evaluate after every epoch
+        evaluate(model, data_loader_test, device=device)
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+
+class missingdict(dict):
+    """This is a little trick to prevent KeyError"""
+    def __missing__(self, key):
+        return None  # note, does *not* set self[key] - we don't want defaultdict's behavior
+
+
+def create_activation_stats_collectors(model, *phases):
+    """Create objects that collect activation statistics.
+
+    This is a utility function that creates two collectors:
+    1. Fine-grained sparsity levels of the activations
+    2. L1-magnitude of each of the activation channels
+
+    Args:
+        model - the model on which we want to collect statistics
+        phases - the statistics collection phases: train, valid, and/or test
+
+    WARNING! Enabling activation statistics collection will significantly slow down training!
+    """
+    genCollectors = lambda: missingdict({
+        "sparsity":      SummaryActivationStatsCollector(model, "sparsity",
+                                                         lambda t: 100 * distiller.utils.sparsity(t)),
+        "l1_channels":   SummaryActivationStatsCollector(model, "l1_channels",
+                                                         distiller.utils.activation_channels_l1),
+        "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels",
+                                                         distiller.utils.activation_channels_apoz),
+        "mean_channels": SummaryActivationStatsCollector(model, "mean_channels",
+                                                         distiller.utils.activation_channels_means),
+        "records":       RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])
+    })
+
+    return {k: (genCollectors() if k in phases else missingdict())
+            for k in ('train', 'valid', 'test')}
+
+
+def add_distiller_compression_args(parser):
+    SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params']
+    distiller_parser = parser.add_argument_group('Distiller-related arguments')
+    distiller_parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES, action='append',
+                        help='print a summary of the model, and exit - options: ' + ' | '.join(SUMMARY_CHOICES))
+    distiller_parser.add_argument('--export-onnx', action='store', nargs='?', type=str, const='model.onnx',
+                                  default=None,
+                                  help='export model to ONNX format')
+    distiller_parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
+                                  help='configuration file for pruning the model '
+                                       '(default is to use hard-coded schedule)')
+    distiller.pruning.greedy_filter_pruning.add_greedy_pruner_args(distiller_parser)
+    distiller_parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
+    distiller_parser.add_argument('--verbose', '-v', action='store_true', help='Emit debug log messages')
+    distiller.quantization.add_post_train_quant_args(distiller_parser)
+    distiller_parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(),
+                                  help='collect activation statistics on phases: train, valid, and/or test'
+                                  ' (WARNING: this slows down training)')
+    distiller_parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False,
+                                  help='print masks sparsity table at end of each epoch')
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(
+        description=__doc__)
+
+    parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset')
+    parser.add_argument('--dataset', default='coco', help='dataset')
+    parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model')
+    parser.add_argument('--device', default='cuda', help='device')
+    parser.add_argument('-b', '--batch-size', default=2, type=int)
+    parser.add_argument('--epochs', default=13, type=int, metavar='N',
+                        help='number of total epochs to run')
+    parser.add_argument('--start-epoch', default=0, type=int, help='starting epoch number')
+    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                        help='number of data loading workers (default: 4)')
+    parser.add_argument('--lr', default=0.02, type=float, help='initial learning rate')
+    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                        help='momentum')
+    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                        metavar='W', help='weight decay (default: 1e-4)',
+                        dest='weight_decay')
+    parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
+    parser.add_argument('--lr-steps', default=[8, 11], nargs='+', type=int, help='epochs at which to decrease lr (MultiStepLR milestones)')
+    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
+    parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
+    parser.add_argument('--output-dir', default='.', help='path where to save')
+    parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--aspect-ratio-group-factor', default=0, type=int)
+    parser.add_argument(
+        "--evaluate",
+        dest="test_only",
+        help="Only test the model",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--pretrained",
+        dest="pretrained",
+        help="Use pre-trained models from the modelzoo",
+        action="store_true",
+    )
+
+    # distributed training parameters
+    parser.add_argument('--world-size', default=1, type=int,
+                        help='number of distributed processes')
+    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+
+    add_distiller_compression_args(parser)
+
+    args = parser.parse_args()
+
+    if args.output_dir:
+        utils.mkdir(args.output_dir)
+
+    main(args)
diff --git a/examples/object_detection_compression/data/download_dataset.sh b/examples/object_detection_compression/data/download_dataset.sh
new file mode 100644
index 0000000..60fdc6d
--- /dev/null
+++ b/examples/object_detection_compression/data/download_dataset.sh
@@ -0,0 +1,9 @@
+echo "Downloading Dataset..."
+wget http://images.cocodataset.org/zips/train2017.zip
+wget http://images.cocodataset.org/zips/val2017.zip
+wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
+
+echo "Extracting Dataset..."
+unzip train2017.zip
+unzip val2017.zip
+unzip annotations_trainval2017.zip
\ No newline at end of file
diff --git a/examples/object_detection_compression/engine.py b/examples/object_detection_compression/engine.py
new file mode 100644
index 0000000..ecc4c52
--- /dev/null
+++ b/examples/object_detection_compression/engine.py
@@ -0,0 +1,125 @@
+# This code is originally from:
+#   https://github.com/pytorch/vision/tree/v0.4.2/references/detection
+# (old commit)
+# It contains code to support compression (distiller)
+import math
+import sys
+import time
+import torch
+
+import torchvision.models.detection.mask_rcnn
+
+from coco_utils import get_coco_api_from_dataset
+from coco_eval import CocoEvaluator
+import utils
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, compression_scheduler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    header = 'Epoch: [{}]'.format(epoch)
+    steps_per_epoch = len(data_loader)
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1. / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
+
+    for train_step, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+        if compression_scheduler:
+            compression_scheduler.on_minibatch_begin(epoch, train_step, steps_per_epoch, optimizer)
+
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+        loss_dict = model(images, targets)
+
+        losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print("Loss is {}, stopping training".format(loss_value))
+            print(loss_dict_reduced)
+            sys.exit(1)
+        if compression_scheduler:
+            losses = compression_scheduler.before_backward_pass(epoch, train_step, steps_per_epoch, losses,
+                                                                optimizer=optimizer)
+
+        optimizer.zero_grad()
+        losses.backward()
+
+        if compression_scheduler:
+            compression_scheduler.before_parameter_optimization(epoch, train_step, steps_per_epoch, optimizer)
+        optimizer.step()
+
+        if compression_scheduler:
+            compression_scheduler.on_minibatch_end(epoch, train_step, steps_per_epoch, optimizer)
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+
+
+def _get_iou_types(model):
+    model_without_ddp = model
+    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        model_without_ddp = model.module
+    iou_types = ["bbox"]
+    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
+        iou_types.append("segm")
+    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
+        iou_types.append("keypoints")
+    return iou_types
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, device):
+    n_threads = torch.get_num_threads()
+    # FIXME remove this and make paste_masks_in_image run on the GPU
+    torch.set_num_threads(1)
+    cpu_device = torch.device("cpu")
+    model.eval()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    header = 'Test:'
+
+    coco = get_coco_api_from_dataset(data_loader.dataset)
+    iou_types = _get_iou_types(model)
+    coco_evaluator = CocoEvaluator(coco, iou_types)
+
+    for image, targets in metric_logger.log_every(data_loader, 100, header):
+        image = list(img.to(device) for img in image)
+        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+        torch.cuda.synchronize()
+        model_time = time.time()
+        outputs = model(image)
+
+        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
+        model_time = time.time() - model_time
+
+        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
+        evaluator_time = time.time()
+        coco_evaluator.update(res)
+        evaluator_time = time.time() - evaluator_time
+        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
+
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print("Averaged stats:", metric_logger)
+    coco_evaluator.synchronize_between_processes()
+
+    # accumulate predictions from all images
+    coco_evaluator.accumulate()
+    coco_evaluator.summarize()
+    torch.set_num_threads(n_threads)
+    return coco_evaluator
diff --git a/examples/object_detection_compression/group_by_aspect_ratio.py b/examples/object_detection_compression/group_by_aspect_ratio.py
new file mode 100644
index 0000000..b82a884
--- /dev/null
+++ b/examples/object_detection_compression/group_by_aspect_ratio.py
@@ -0,0 +1,189 @@
+# This code is originally from:
+#   https://github.com/pytorch/vision/tree/v0.4.2/references/detection
+import bisect
+from collections import defaultdict
+import copy
+import numpy as np
+
+import torch
+import torch.utils.data
+from torch.utils.data.sampler import BatchSampler, Sampler
+from torch.utils.model_zoo import tqdm
+import torchvision
+
+from PIL import Image
+
+
+class GroupedBatchSampler(BatchSampler):
+    """
+    Wraps another sampler to yield a mini-batch of indices.
+    It enforces that the batch only contain elements from the same group.
+    It also tries to provide mini-batches which follow an ordering which is
+    as close as possible to the ordering from the original sampler.
+    Arguments:
+        sampler (Sampler): Base sampler.
+        group_ids (list[int]): If the sampler produces indices in range [0, N),
+            `group_ids` must be a list of `N` ints which contains the group id of each sample.
+            The group ids must be a continuous set of integers starting from
+            0, i.e. they must be in the range [0, num_groups).
+        batch_size (int): Size of mini-batch.
+    """
+    def __init__(self, sampler, group_ids, batch_size):
+        if not isinstance(sampler, Sampler):
+            raise ValueError(
+                "sampler should be an instance of "
+                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
+            )
+        self.sampler = sampler
+        self.group_ids = group_ids
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        buffer_per_group = defaultdict(list)
+        samples_per_group = defaultdict(list)
+
+        num_batches = 0
+        for idx in self.sampler:
+            group_id = self.group_ids[idx]
+            buffer_per_group[group_id].append(idx)
+            samples_per_group[group_id].append(idx)
+            if len(buffer_per_group[group_id]) == self.batch_size:
+                yield buffer_per_group[group_id]
+                num_batches += 1
+                del buffer_per_group[group_id]
+            assert len(buffer_per_group[group_id]) < self.batch_size
+
+        # now we have run out of elements that satisfy
+        # the group criteria, let's return the remaining
+        # elements so that the size of the sampler is
+        # deterministic
+        expected_num_batches = len(self)
+        num_remaining = expected_num_batches - num_batches
+        if num_remaining > 0:
+            # for the remaining batches, take first the buffers with largest number
+            # of elements
+            for group_id, _ in sorted(buffer_per_group.items(),
+                                      key=lambda x: len(x[1]), reverse=True):
+                remaining = self.batch_size - len(buffer_per_group[group_id])
+                buffer_per_group[group_id].extend(
+                    samples_per_group[group_id][:remaining])
+                assert len(buffer_per_group[group_id]) == self.batch_size
+                yield buffer_per_group[group_id]
+                num_remaining -= 1
+                if num_remaining == 0:
+                    break
+        assert num_remaining == 0
+
+    def __len__(self):
+        return len(self.sampler) // self.batch_size
+
+
+def _compute_aspect_ratios_slow(dataset, indices=None):
+    print("Your dataset doesn't support the fast path for "
+          "computing the aspect ratios, so will iterate over "
+          "the full dataset and load every image instead. "
+          "This might take some time...")
+    if indices is None:
+        indices = range(len(dataset))
+
+    class SubsetSampler(Sampler):
+        def __init__(self, indices):
+            self.indices = indices
+
+        def __iter__(self):
+            return iter(self.indices)
+
+        def __len__(self):
+            return len(self.indices)
+
+    sampler = SubsetSampler(indices)
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_size=1, sampler=sampler,
+        num_workers=14,  # you might want to increase it for faster processing
+        collate_fn=lambda x: x[0])
+    aspect_ratios = []
+    with tqdm(total=len(dataset)) as pbar:
+        for _i, (img, _) in enumerate(data_loader):
+            pbar.update(1)
+            height, width = img.shape[-2:]
+            aspect_ratio = float(height) / float(width)
+            aspect_ratios.append(aspect_ratio)
+    return aspect_ratios
+
+
+def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
+    if indices is None:
+        indices = range(len(dataset))
+    aspect_ratios = []
+    for i in indices:
+        height, width = dataset.get_height_and_width(i)
+        aspect_ratio = float(height) / float(width)
+        aspect_ratios.append(aspect_ratio)
+    return aspect_ratios
+
+
+def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
+    if indices is None:
+        indices = range(len(dataset))
+    aspect_ratios = []
+    for i in indices:
+        img_info = dataset.coco.imgs[dataset.ids[i]]
+        aspect_ratio = float(img_info["height"]) / float(img_info["width"])
+        aspect_ratios.append(aspect_ratio)
+    return aspect_ratios
+
+
+def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
+    if indices is None:
+        indices = range(len(dataset))
+    aspect_ratios = []
+    for i in indices:
+        # this doesn't load the data into memory, because PIL loads it lazily
+        width, height = Image.open(dataset.images[i]).size
+        aspect_ratio = float(height) / float(width)
+        aspect_ratios.append(aspect_ratio)
+    return aspect_ratios
+
+
+def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
+    if indices is None:
+        indices = range(len(dataset))
+
+    ds_indices = [dataset.indices[i] for i in indices]
+    return compute_aspect_ratios(dataset.dataset, ds_indices)
+
+
+def compute_aspect_ratios(dataset, indices=None):
+    if hasattr(dataset, "get_height_and_width"):
+        return _compute_aspect_ratios_custom_dataset(dataset, indices)
+
+    if isinstance(dataset, torchvision.datasets.CocoDetection):
+        return _compute_aspect_ratios_coco_dataset(dataset, indices)
+
+    if isinstance(dataset, torchvision.datasets.VOCDetection):
+        return _compute_aspect_ratios_voc_dataset(dataset, indices)
+
+    if isinstance(dataset, torch.utils.data.Subset):
+        return _compute_aspect_ratios_subset_dataset(dataset, indices)
+
+    # slow path
+    return _compute_aspect_ratios_slow(dataset, indices)
+
+
+def _quantize(x, bins):
+    bins = copy.deepcopy(bins)
+    bins = sorted(bins)
+    quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
+    return quantized
+
+
+def create_aspect_ratio_groups(dataset, k=0):
+    aspect_ratios = compute_aspect_ratios(dataset)
+    bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
+    groups = _quantize(aspect_ratios, bins)
+    # count number of elements per group
+    counts = np.unique(groups, return_counts=True)[1]
+    fbins = [0] + bins + [np.inf]
+    print("Using {} as bins for aspect ratio quantization".format(fbins))
+    print("Count of instances per bin: {}".format(counts))
+    return groups
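+
+
+# Usage sketch, mirroring the torchvision detection reference scripts; `dataset`,
+# `train_sampler`, `args` and `utils.collate_fn` are assumed to come from the
+# training script and its utils module:
+#
+#   group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
+#   train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
+#   data_loader = torch.utils.data.DataLoader(
+#       dataset, batch_sampler=train_batch_sampler,
+#       num_workers=args.workers, collate_fn=utils.collate_fn)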
diff --git a/examples/object_detection_compression/logging.conf b/examples/object_detection_compression/logging.conf
new file mode 100644
index 0000000..429419b
--- /dev/null
+++ b/examples/object_detection_compression/logging.conf
@@ -0,0 +1,54 @@
+[formatters]
+keys: simple, time_simple
+
+[handlers]
+keys: console, file
+
+[loggers]
+keys: root, app_cfg, distiller.thinning, apputils.model_summaries
+
+[formatter_simple]
+format: %(message)s
+
+[formatter_time_simple]
+format: %(asctime)s - %(message)s
+
+[handler_console]
+class: StreamHandler
+args: []
+formatter: simple
+
+[handler_file]
+class: FileHandler
+args: ('%(logfilename)s', 'w')
+formatter: time_simple
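+# Note: '%(logfilename)s' is not a literal path - the application substitutes it when it
+# loads this file, e.g. logging.config.fileConfig(path, defaults={'logfilename': ...}).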
+
+[logger_root]
+level: INFO
+propagate: 1
+handlers: console, file
+
+[logger_app_cfg]
+# Use this logger to log the application configuration and execution environment
+level: DEBUG
+qualname: app_cfg
+propagate: 0
+handlers: file
+
+# Example of adding a module-specific logger
+# Do not forget to add distiller.thinning to the list of keys in section [loggers]
+[logger_distiller.thinning]
+level: INFO
+qualname: distiller.thinning
+propagate: 0
+handlers: console, file
+
+# Example of adding a module-specific logger
+# Do not forget to add apputils.model_summaries to the list of keys in section [loggers]
+[logger_apputils.model_summaries]
+level: INFO
+qualname: apputils.model_summaries
+propagate: 0
+handlers: console, file
diff --git a/examples/object_detection_compression/maskrcnn.scheduler_agp.non_parallel.yaml b/examples/object_detection_compression/maskrcnn.scheduler_agp.non_parallel.yaml
new file mode 100644
index 0000000..579582b
--- /dev/null
+++ b/examples/object_detection_compression/maskrcnn.scheduler_agp.non_parallel.yaml
@@ -0,0 +1,210 @@
+#
+# Command line:
+# python compress_detector.py --data-path $DATASET_COCO \
+#     --compress maskrcnn.scheduler_agp.non_parallel.yaml
+#
+#Parameters:
+#+----+-------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#|    | Name                                            | Shape              |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+#|----+-------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+#|  0 | backbone.body.conv1.weight                      | (64, 3, 7, 7)      |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12322 | -0.00050 |    0.07605 |
+#|  1 | backbone.body.layer1.0.conv1.weight             | (64, 64, 1, 1)     |          4096 |           4096 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.07083 | -0.00439 |    0.04032 |
+#|  2 | backbone.body.layer1.0.conv2.weight             | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02915 |  0.00079 |    0.01717 |
+#|  3 | backbone.body.layer1.0.conv3.weight             | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03527 |  0.00042 |    0.02105 |
+#|  4 | backbone.body.layer1.0.downsample.0.weight      | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.05750 | -0.00342 |    0.03227 |
+#|  5 | backbone.body.layer1.1.conv1.weight             | (64, 256, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03062 |  0.00135 |    0.02029 |
+#|  6 | backbone.body.layer1.1.conv2.weight             | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02881 |  0.00005 |    0.01964 |
+#|  7 | backbone.body.layer1.1.conv3.weight             | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03254 | -0.00019 |    0.02029 |
+#|  8 | backbone.body.layer1.2.conv1.weight             | (64, 256, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03008 |  0.00012 |    0.02169 |
+#|  9 | backbone.body.layer1.2.conv2.weight             | (64, 64, 3, 3)     |         36864 |          36864 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03178 | -0.00085 |    0.02378 |
+#| 10 | backbone.body.layer1.2.conv3.weight             | (256, 64, 1, 1)    |         16384 |          16384 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03109 | -0.00255 |    0.01857 |
+#| 11 | backbone.body.layer2.0.conv1.weight             | (128, 256, 1, 1)   |         32768 |          32768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03542 | -0.00152 |    0.02484 |
+#| 12 | backbone.body.layer2.0.conv2.weight             | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02229 | -0.00054 |    0.01661 |
+#| 13 | backbone.body.layer2.0.conv3.weight             | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02808 | -0.00013 |    0.01740 |
+#| 14 | backbone.body.layer2.0.downsample.0.weight      | (512, 256, 1, 1)   |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02311 | -0.00035 |    0.01352 |
+#| 15 | backbone.body.layer2.1.conv1.weight             | (128, 512, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01666 | -0.00007 |    0.01004 |
+#| 16 | backbone.body.layer2.1.conv2.weight             | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01934 | -0.00004 |    0.01239 |
+#| 17 | backbone.body.layer2.1.conv3.weight             | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02206 | -0.00121 |    0.01247 |
+#| 18 | backbone.body.layer2.2.conv1.weight             | (128, 512, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02332 | -0.00078 |    0.01618 |
+#| 19 | backbone.body.layer2.2.conv2.weight             | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02159 | -0.00033 |    0.01536 |
+#| 20 | backbone.body.layer2.2.conv3.weight             | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02617 | -0.00047 |    0.01840 |
+#| 21 | backbone.body.layer2.3.conv1.weight             | (128, 512, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02432 | -0.00098 |    0.01803 |
+#| 22 | backbone.body.layer2.3.conv2.weight             | (128, 128, 3, 3)   |        147456 |         147456 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02241 | -0.00074 |    0.01691 |
+#| 23 | backbone.body.layer2.3.conv3.weight             | (512, 128, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02432 | -0.00120 |    0.01681 |
+#| 24 | backbone.body.layer3.0.conv1.weight             | (256, 512, 1, 1)   |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03077 | -0.00127 |    0.02202 |
+#| 25 | backbone.body.layer3.0.conv2.weight             | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01736 | -0.00041 |    0.01279 |
+#| 26 | backbone.body.layer3.0.conv3.weight             | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02338 | -0.00042 |    0.01660 |
+#| 27 | backbone.body.layer3.0.downsample.0.weight      | (1024, 512, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01615 |  0.00004 |    0.01106 |
+#| 28 | backbone.body.layer3.1.conv1.weight             | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01521 | -0.00055 |    0.01063 |
+#| 29 | backbone.body.layer3.1.conv2.weight             | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01510 | -0.00031 |    0.01102 |
+#| 30 | backbone.body.layer3.1.conv3.weight             | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02036 | -0.00109 |    0.01457 |
+#| 31 | backbone.body.layer3.2.conv1.weight             | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01575 | -0.00052 |    0.01125 |
+#| 32 | backbone.body.layer3.2.conv2.weight             | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01500 | -0.00071 |    0.01127 |
+#| 33 | backbone.body.layer3.2.conv3.weight             | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01908 | -0.00077 |    0.01393 |
+#| 34 | backbone.body.layer3.3.conv1.weight             | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01734 | -0.00077 |    0.01283 |
+#| 35 | backbone.body.layer3.3.conv2.weight             | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01493 | -0.00077 |    0.01142 |
+#| 36 | backbone.body.layer3.3.conv3.weight             | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01821 | -0.00117 |    0.01341 |
+#| 37 | backbone.body.layer3.4.conv1.weight             | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01811 | -0.00101 |    0.01362 |
+#| 38 | backbone.body.layer3.4.conv2.weight             | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01487 | -0.00094 |    0.01141 |
+#| 39 | backbone.body.layer3.4.conv3.weight             | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01815 | -0.00160 |    0.01334 |
+#| 40 | backbone.body.layer3.5.conv1.weight             | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01970 | -0.00096 |    0.01498 |
+#| 41 | backbone.body.layer3.5.conv2.weight             | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01524 | -0.00090 |    0.01170 |
+#| 42 | backbone.body.layer3.5.conv3.weight             | (1024, 256, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01923 | -0.00231 |    0.01438 |
+#| 43 | backbone.body.layer4.0.conv1.weight             | (512, 1024, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02380 | -0.00136 |    0.01837 |
+#| 44 | backbone.body.layer4.0.conv2.weight             | (512, 512, 3, 3)   |       2359296 |        2359296 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01255 | -0.00051 |    0.00981 |
+#| 45 | backbone.body.layer4.0.conv3.weight             | (2048, 512, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01529 | -0.00061 |    0.01172 |
+#| 46 | backbone.body.layer4.0.downsample.0.weight      | (2048, 1024, 1, 1) |       2097152 |        2097152 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00994 | -0.00007 |    0.00750 |
+#| 47 | backbone.body.layer4.1.conv1.weight             | (512, 2048, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01472 | -0.00092 |    0.01143 |
+#| 48 | backbone.body.layer4.1.conv2.weight             | (512, 512, 3, 3)   |       2359296 |        2359296 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01230 | -0.00084 |    0.00971 |
+#| 49 | backbone.body.layer4.1.conv3.weight             | (2048, 512, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01495 | -0.00011 |    0.01149 |
+#| 50 | backbone.body.layer4.2.conv1.weight             | (512, 2048, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01813 | -0.00050 |    0.01413 |
+#| 51 | backbone.body.layer4.2.conv2.weight             | (512, 512, 3, 3)   |       2359296 |        2359296 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01104 | -0.00066 |    0.00874 |
+#| 52 | backbone.body.layer4.2.conv3.weight             | (2048, 512, 1, 1)  |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01417 | -0.00003 |    0.01062 |
+#| 53 | backbone.fpn.inner_blocks.0.weight              | (256, 256, 1, 1)   |         65536 |          65536 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03602 |  0.00015 |    0.03119 |
+#| 54 | backbone.fpn.inner_blocks.1.weight              | (256, 512, 1, 1)   |        131072 |         131072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02552 | -0.00001 |    0.02209 |
+#| 55 | backbone.fpn.inner_blocks.2.weight              | (256, 1024, 1, 1)  |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01804 | -0.00005 |    0.01563 |
+#| 56 | backbone.fpn.inner_blocks.3.weight              | (256, 2048, 1, 1)  |        524288 |         524288 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01276 |  0.00001 |    0.01105 |
+#| 57 | backbone.fpn.layer_blocks.0.weight              | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01202 | -0.00001 |    0.01041 |
+#| 58 | backbone.fpn.layer_blocks.1.weight              | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01204 | -0.00000 |    0.01043 |
+#| 59 | backbone.fpn.layer_blocks.2.weight              | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01204 | -0.00001 |    0.01043 |
+#| 60 | backbone.fpn.layer_blocks.3.weight              | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01202 |  0.00001 |    0.01041 |
+#| 61 | rpn.head.conv.weight                            | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00999 | -0.00001 |    0.00797 |
+#| 62 | rpn.head.cls_logits.weight                      | (3, 256, 1, 1)     |           768 |            768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00978 |  0.00015 |    0.00772 |
+#| 63 | rpn.head.bbox_pred.weight                       | (12, 256, 1, 1)    |          3072 |           3072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.00982 | -0.00034 |    0.00786 |
+#| 64 | roi_heads.box_head.fc6.weight                   | (1024, 12544)      |      12845056 |       12845053 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00002 | 0.00515 | -0.00000 |    0.00446 |
+#| 65 | roi_heads.box_head.fc7.weight                   | (1024, 1024)       |       1048576 |        1048576 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01805 | -0.00003 |    0.01563 |
+#| 66 | roi_heads.box_predictor.cls_score.weight        | (91, 1024)         |         93184 |          93184 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01806 |  0.00009 |    0.01564 |
+#| 67 | roi_heads.box_predictor.bbox_pred.weight        | (364, 1024)        |        372736 |         372736 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01805 |  0.00001 |    0.01563 |
+#| 68 | roi_heads.mask_head.mask_fcn1.weight            | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02944 |  0.00004 |    0.02350 |
+#| 69 | roi_heads.mask_head.mask_fcn2.weight            | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02947 | -0.00004 |    0.02351 |
+#| 70 | roi_heads.mask_head.mask_fcn3.weight            | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02949 |  0.00004 |    0.02352 |
+#| 71 | roi_heads.mask_head.mask_fcn4.weight            | (256, 256, 3, 3)   |        589824 |         589824 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.02949 |  0.00000 |    0.02354 |
+#| 72 | roi_heads.mask_predictor.conv5_mask.weight      | (256, 256, 2, 2)   |        262144 |         262144 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04420 | -0.00003 |    0.03524 |
+#| 73 | roi_heads.mask_predictor.mask_fcn_logits.weight | (91, 256, 1, 1)    |         23296 |          23296 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.14817 |  0.00125 |    0.11834 |
+#| 74 | Total sparsity:                                 | -                  |      44395200 |       44395197 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00001 | 0.00000 |  0.00000 |    0.00000 |
+#+----+-------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+
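+# Scheduling note (a sketch of the intended behavior, assuming Distiller's
+# AutomatedGradualPruner follows the gradual pruning schedule of Zhu & Gupta, 2017):
+# each pruner ramps a tensor's sparsity from initial_sparsity (s_i) to final_sparsity (s_f)
+# between starting_epoch (t_0) and ending_epoch (t_end), roughly as
+#   s_t = s_f + (s_i - s_f) * (1 - (t - t_0) / (t_end - t_0))^3
+# with the masks re-computed every `frequency` epochs.
+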
+version: 1
+
+pruners:
+
+  fc_pruner:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.01
+    final_sparsity: 0.85
+    weights: [
+      roi_heads.box_head.fc6.weight,
+      roi_heads.box_head.fc7.weight
+    ]
+
+  agp_pruner_75:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.01
+    final_sparsity: 0.75
+    weights: [
+      backbone.body.layer1.0.conv1.weight,
+      backbone.body.layer1.0.conv2.weight,
+      backbone.body.layer1.0.conv3.weight,
+      backbone.body.layer1.0.downsample.0.weight,
+      backbone.body.layer1.1.conv1.weight,
+      backbone.body.layer1.1.conv2.weight,
+      backbone.body.layer1.1.conv3.weight,
+      backbone.body.layer1.2.conv1.weight,
+      backbone.body.layer1.2.conv2.weight,
+      backbone.body.layer1.2.conv3.weight,
+      backbone.body.layer2.0.conv1.weight,
+      backbone.body.layer2.0.conv2.weight,
+      backbone.body.layer2.0.conv3.weight,
+    ]
+
+  agp_pruner_85:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.01
+    final_sparsity: 0.85
+    weights: [
+      backbone.body.layer2.0.downsample.0.weight,
+      backbone.body.layer2.1.conv1.weight,
+      backbone.body.layer2.1.conv2.weight,
+      backbone.body.layer2.1.conv3.weight,
+      backbone.body.layer2.2.conv1.weight,
+      backbone.body.layer2.2.conv2.weight,
+      backbone.body.layer2.2.conv3.weight,
+      backbone.body.layer2.3.conv1.weight,
+      backbone.body.layer2.3.conv2.weight,
+      backbone.body.layer2.3.conv3.weight,
+      backbone.body.layer3.0.conv1.weight,
+      backbone.body.layer3.0.conv2.weight,
+      backbone.body.layer3.0.conv3.weight,
+      backbone.body.layer3.0.downsample.0.weight,
+      backbone.body.layer3.1.conv1.weight,
+      backbone.body.layer3.1.conv2.weight,
+      backbone.body.layer3.1.conv3.weight,
+      backbone.body.layer3.2.conv1.weight,
+      backbone.body.layer3.2.conv2.weight,
+      backbone.body.layer3.2.conv3.weight,
+      backbone.body.layer3.3.conv1.weight,
+      backbone.body.layer3.3.conv2.weight,
+      backbone.body.layer3.3.conv3.weight,
+      backbone.body.layer3.4.conv1.weight,
+      backbone.body.layer3.4.conv2.weight,
+      backbone.body.layer3.4.conv3.weight,
+      backbone.body.layer3.5.conv1.weight,
+      backbone.body.layer3.5.conv2.weight,
+      backbone.body.layer3.5.conv3.weight,
+      backbone.body.layer4.2.conv3.weight,
+      backbone.fpn.inner_blocks.0.weight,
+      backbone.fpn.inner_blocks.1.weight,
+      backbone.fpn.inner_blocks.2.weight,
+      backbone.fpn.inner_blocks.3.weight,
+      backbone.fpn.layer_blocks.0.weight,
+      backbone.fpn.layer_blocks.1.weight,
+      backbone.fpn.layer_blocks.2.weight,
+      backbone.fpn.layer_blocks.3.weight,
+      rpn.head.conv.weight,
+      roi_heads.mask_head.mask_fcn1.weight,
+      roi_heads.mask_head.mask_fcn2.weight,
+      roi_heads.mask_head.mask_fcn3.weight,
+      roi_heads.mask_head.mask_fcn4.weight,
+      roi_heads.mask_predictor.conv5_mask.weight,
+    ]
+
+  agp_pruner_90:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.01
+    final_sparsity: 0.90
+    weights: [
+      backbone.body.layer4.0.conv1.weight,
+      backbone.body.layer4.0.conv2.weight,
+      backbone.body.layer4.0.conv3.weight,
+      backbone.body.layer4.0.downsample.0.weight,
+      backbone.body.layer4.1.conv1.weight,
+      backbone.body.layer4.1.conv2.weight,
+      backbone.body.layer4.1.conv3.weight,
+      backbone.body.layer4.2.conv1.weight,
+      backbone.body.layer4.2.conv2.weight,
+    ]
+
+
+policies:
+  - pruner:
+      instance_name : agp_pruner_75
+    starting_epoch: 0
+    ending_epoch: 45
+    frequency: 1
+
+  - pruner:
+      instance_name : agp_pruner_85
+    starting_epoch: 0
+    ending_epoch: 45
+    frequency: 1
+
+  - pruner:
+      instance_name : fc_pruner
+    starting_epoch: 0
+    ending_epoch: 45
+    frequency: 3
+
+  - pruner:
+      instance_name : agp_pruner_90
+    starting_epoch: 0
+    ending_epoch: 45
+    frequency: 1
diff --git a/examples/object_detection_compression/maskrcnn.scheduler_agp.yaml b/examples/object_detection_compression/maskrcnn.scheduler_agp.yaml
new file mode 100644
index 0000000..fda56c6
--- /dev/null
+++ b/examples/object_detection_compression/maskrcnn.scheduler_agp.yaml
@@ -0,0 +1,267 @@
+#
+# Command line:
+# python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env compress_detector.py --data-path $DATASET_COCO \
+#     --compress maskrcnn.scheduler_agp.yaml --world-size $NGPU
+#
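+# Note: unlike maskrcnn.scheduler_agp.non_parallel.yaml, the parameter names below carry a
+# leading "module." prefix, because the distributed run wraps the model in
+# (Distributed)DataParallel before the compression scheduler is built from this file.
+#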
+#Parameters:
+#+----+--------------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#|    | Name                                                   | Shape              |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
+#|----+--------------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
+#|  0 | module.backbone.body.conv1.weight                      | (64, 3, 7, 7)      |          9408 |           9408 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.12322 | -0.00050 |    0.07605 |
+#|  1 | module.backbone.body.layer1.0.conv1.weight             | (64, 64, 1, 1)     |          4096 |           1024 |    0.00000 |    0.00000 |  3.12500 | 75.00000 |  7.81250 |   75.00000 | 0.06810 | -0.00616 |    0.02771 |
+#|  2 | module.backbone.body.layer1.0.conv2.weight             | (64, 64, 3, 3)     |         36864 |           9216 |    0.00000 |    0.00000 |  7.81250 | 32.93457 |  6.25000 |   75.00000 | 0.02777 |  0.00070 |    0.01135 |
+#|  3 | module.backbone.body.layer1.0.conv3.weight             | (256, 64, 1, 1)    |         16384 |           4096 |    0.00000 |    0.00000 |  6.25000 | 75.00000 | 12.10938 |   75.00000 | 0.03366 |  0.00035 |    0.01453 |
+#|  4 | module.backbone.body.layer1.0.downsample.0.weight      | (256, 64, 1, 1)    |         16384 |           4096 |    0.00000 |    0.00000 |  1.56250 | 75.00000 | 13.28125 |   75.00000 | 0.05548 | -0.00382 |    0.02256 |
+#|  5 | module.backbone.body.layer1.1.conv1.weight             | (64, 256, 1, 1)    |         16384 |           4096 |    0.00000 |    0.00000 | 11.71875 | 75.00000 |  6.25000 |   75.00000 | 0.02841 |  0.00126 |    0.01292 |
+#|  6 | module.backbone.body.layer1.1.conv2.weight             | (64, 64, 3, 3)     |         36864 |           9216 |    0.00000 |    0.00000 |  6.25000 | 26.46484 |  0.00000 |   75.00000 | 0.02650 |  0.00022 |    0.01178 |
+#|  7 | module.backbone.body.layer1.1.conv3.weight             | (256, 64, 1, 1)    |         16384 |           4096 |    0.00000 |    0.00000 |  0.00000 | 75.00000 |  3.51562 |   75.00000 | 0.03090 |  0.00000 |    0.01370 |
+#|  8 | module.backbone.body.layer1.2.conv1.weight             | (64, 256, 1, 1)    |         16384 |           4096 |    0.00000 |    0.00000 |  7.03125 | 75.00000 |  0.00000 |   75.00000 | 0.02725 |  0.00012 |    0.01270 |
+#|  9 | module.backbone.body.layer1.2.conv2.weight             | (64, 64, 3, 3)     |         36864 |           9216 |    0.00000 |    0.00000 |  0.00000 | 21.19141 |  0.00000 |   75.00000 | 0.02828 | -0.00049 |    0.01323 |
+#| 10 | module.backbone.body.layer1.2.conv3.weight             | (256, 64, 1, 1)    |         16384 |           4096 |    0.00000 |    0.00000 |  0.00000 | 75.00000 |  1.17188 |   75.00000 | 0.02999 | -0.00228 |    0.01325 |
+#| 11 | module.backbone.body.layer2.0.conv1.weight             | (128, 256, 1, 1)   |         32768 |           8192 |    0.00000 |    0.00000 |  3.90625 | 75.00000 |  0.00000 |   75.00000 | 0.04454 |  0.00213 |    0.01983 |
+#| 12 | module.backbone.body.layer2.0.conv2.weight             | (128, 128, 3, 3)   |        147456 |          36864 |    0.00000 |    0.00000 |  0.00000 | 28.43018 |  0.00000 |   75.00000 | 0.02019 |  0.00093 |    0.00941 |
+#| 13 | module.backbone.body.layer2.0.conv3.weight             | (512, 128, 1, 1)   |         65536 |          16384 |    0.00000 |    0.00000 |  0.00000 | 75.00000 | 28.90625 |   75.00000 | 0.03100 |  0.00207 |    0.01292 |
+#| 14 | module.backbone.body.layer2.0.downsample.0.weight      | (512, 256, 1, 1)   |        131072 |          19661 |    0.00000 |    0.00000 |  0.00000 | 84.99985 | 15.62500 |   84.99985 | 0.01999 |  0.00119 |    0.00595 |
+#| 15 | module.backbone.body.layer2.1.conv1.weight             | (128, 512, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 | 17.57812 | 84.99908 |  0.00000 |   84.99908 | 0.01914 |  0.00124 |    0.00620 |
+#| 16 | module.backbone.body.layer2.1.conv2.weight             | (128, 128, 3, 3)   |        147456 |          22119 |    0.00000 |    0.00000 |  0.00000 | 53.41187 |  0.00000 |   84.99959 | 0.01481 |  0.00072 |    0.00531 |
+#| 17 | module.backbone.body.layer2.1.conv3.weight             | (512, 128, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  0.00000 | 84.99908 | 38.28125 |   84.99908 | 0.02156 |  0.00029 |    0.00685 |
+#| 18 | module.backbone.body.layer2.2.conv1.weight             | (128, 512, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  4.29688 | 84.99908 |  0.00000 |   84.99908 | 0.01961 |  0.00100 |    0.00688 |
+#| 19 | module.backbone.body.layer2.2.conv2.weight             | (128, 128, 3, 3)   |        147456 |          22119 |    0.00000 |    0.00000 |  0.00000 | 41.84570 |  0.00000 |   84.99959 | 0.01564 |  0.00037 |    0.00562 |
+#| 20 | module.backbone.body.layer2.2.conv3.weight             | (512, 128, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  0.00000 | 84.99908 |  6.44531 |   84.99908 | 0.02155 |  0.00039 |    0.00747 |
+#| 21 | module.backbone.body.layer2.3.conv1.weight             | (128, 512, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  4.10156 | 84.99908 |  0.00000 |   84.99908 | 0.02108 |  0.00067 |    0.00752 |
+#| 22 | module.backbone.body.layer2.3.conv2.weight             | (128, 128, 3, 3)   |        147456 |          22119 |    0.00000 |    0.00000 |  0.00000 | 36.71265 |  0.00000 |   84.99959 | 0.01695 |  0.00009 |    0.00623 |
+#| 23 | module.backbone.body.layer2.3.conv3.weight             | (512, 128, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  0.00000 | 84.99908 | 24.80469 |   84.99908 | 0.02311 | -0.00029 |    0.00796 |
+#| 24 | module.backbone.body.layer3.0.conv1.weight             | (256, 512, 1, 1)   |        131072 |          19661 |    0.00000 |    0.00000 |  0.19531 | 84.99985 |  0.00000 |   84.99985 | 0.02292 |  0.00031 |    0.00801 |
+#| 25 | module.backbone.body.layer3.0.conv2.weight             | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 47.98279 |  0.00000 |   84.99993 | 0.01155 |  0.00032 |    0.00421 |
+#| 26 | module.backbone.body.layer3.0.conv3.weight             | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  6.34766 |   84.99985 | 0.01704 |  0.00087 |    0.00601 |
+#| 27 | module.backbone.body.layer3.0.downsample.0.weight      | (1024, 512, 1, 1)  |        524288 |          78644 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  4.39453 |   84.99985 | 0.01072 |  0.00059 |    0.00370 |
+#| 28 | module.backbone.body.layer3.1.conv1.weight             | (256, 1024, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  9.08203 | 84.99985 |  0.00000 |   84.99985 | 0.01141 |  0.00051 |    0.00392 |
+#| 29 | module.backbone.body.layer3.1.conv2.weight             | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 46.10596 |  0.00000 |   84.99993 | 0.00989 |  0.00000 |    0.00357 |
+#| 30 | module.backbone.body.layer3.1.conv3.weight             | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  1.36719 |   84.99985 | 0.01376 | -0.00064 |    0.00485 |
+#| 31 | module.backbone.body.layer3.2.conv1.weight             | (256, 1024, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  2.24609 | 84.99985 |  0.00000 |   84.99985 | 0.01204 |  0.00033 |    0.00413 |
+#| 32 | module.backbone.body.layer3.2.conv2.weight             | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 37.43591 |  0.00000 |   84.99993 | 0.01000 | -0.00007 |    0.00365 |
+#| 33 | module.backbone.body.layer3.2.conv3.weight             | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  1.85547 |   84.99985 | 0.01354 | -0.00039 |    0.00477 |
+#| 34 | module.backbone.body.layer3.3.conv1.weight             | (256, 1024, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.97656 | 84.99985 |  0.00000 |   84.99985 | 0.01287 |  0.00016 |    0.00452 |
+#| 35 | module.backbone.body.layer3.3.conv2.weight             | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 36.71875 |  0.00000 |   84.99993 | 0.01004 |  0.00000 |    0.00368 |
+#| 36 | module.backbone.body.layer3.3.conv3.weight             | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  6.64062 |   84.99985 | 0.01315 | -0.00024 |    0.00462 |
+#| 37 | module.backbone.body.layer3.4.conv1.weight             | (256, 1024, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.29297 | 84.99985 |  0.00000 |   84.99985 | 0.01321 |  0.00005 |    0.00468 |
+#| 38 | module.backbone.body.layer3.4.conv2.weight             | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 38.03711 |  0.00000 |   84.99993 | 0.01016 | -0.00003 |    0.00370 |
+#| 39 | module.backbone.body.layer3.4.conv3.weight             | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  8.00781 |   84.99985 | 0.01311 | -0.00017 |    0.00457 |
+#| 40 | module.backbone.body.layer3.5.conv1.weight             | (256, 1024, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.29297 | 84.99985 |  0.00000 |   84.99985 | 0.01289 |  0.00021 |    0.00455 |
+#| 41 | module.backbone.body.layer3.5.conv2.weight             | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 40.31677 |  0.00000 |   84.99993 | 0.00959 |  0.00013 |    0.00349 |
+#| 42 | module.backbone.body.layer3.5.conv3.weight             | (1024, 256, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  5.76172 |   84.99985 | 0.01286 | -0.00098 |    0.00458 |
+#| 43 | module.backbone.body.layer4.0.conv1.weight             | (512, 1024, 1, 1)  |        524288 |          52429 |    0.00000 |    0.00000 |  0.00000 | 89.99996 |  0.00000 |   89.99996 | 0.01248 |  0.00061 |    0.00369 |
+#| 44 | module.backbone.body.layer4.0.conv2.weight             | (512, 512, 3, 3)   |       2359296 |         235930 |    0.00000 |    0.00000 |  0.00000 | 57.85141 |  0.00000 |   89.99998 | 0.00621 |  0.00005 |    0.00189 |
+#| 45 | module.backbone.body.layer4.0.conv3.weight             | (2048, 512, 1, 1)  |       1048576 |         104858 |    0.00000 |    0.00000 |  0.00000 | 89.99996 |  0.29297 |   89.99996 | 0.00865 |  0.00026 |    0.00259 |
+#| 46 | module.backbone.body.layer4.0.downsample.0.weight      | (2048, 1024, 1, 1) |       2097152 |         209716 |    0.00000 |    0.00000 |  0.00000 | 89.99996 |  0.04883 |   89.99996 | 0.00611 | -0.00007 |    0.00184 |
+#| 47 | module.backbone.body.layer4.1.conv1.weight             | (512, 2048, 1, 1)  |       1048576 |         104858 |    0.00000 |    0.00000 |  0.24414 | 89.99996 |  0.00000 |   89.99996 | 0.00817 |  0.00041 |    0.00238 |
+#| 48 | module.backbone.body.layer4.1.conv2.weight             | (512, 512, 3, 3)   |       2359296 |         235930 |    0.00000 |    0.00000 |  0.00000 | 57.28302 |  0.00000 |   89.99998 | 0.00649 |  0.00001 |    0.00196 |
+#| 49 | module.backbone.body.layer4.1.conv3.weight             | (2048, 512, 1, 1)  |       1048576 |         104858 |    0.00000 |    0.00000 |  0.00000 | 89.99996 |  0.24414 |   89.99996 | 0.00867 | -0.00013 |    0.00259 |
+#| 50 | module.backbone.body.layer4.2.conv1.weight             | (512, 2048, 1, 1)  |       1048576 |         104858 |    0.00000 |    0.00000 |  0.04883 | 89.99996 |  0.00000 |   89.99996 | 0.00876 |  0.00056 |    0.00259 |
+#| 51 | module.backbone.body.layer4.2.conv2.weight             | (512, 512, 3, 3)   |       2359296 |         235930 |    0.00000 |    0.00000 |  0.00000 | 59.87663 |  0.00000 |   89.99998 | 0.00634 |  0.00001 |    0.00189 |
+#| 52 | module.backbone.body.layer4.2.conv3.weight             | (2048, 512, 1, 1)  |       1048576 |         157287 |    0.00000 |    0.00000 |  0.00000 | 84.99994 |  0.04883 |   84.99994 | 0.00883 | -0.00022 |    0.00311 |
+#| 53 | module.backbone.fpn.inner_blocks.0.weight              | (256, 256, 1, 1)   |         65536 |           9831 |    0.00000 |    0.00000 |  9.37500 | 84.99908 |  0.00000 |   84.99908 | 0.00920 |  0.00016 |    0.00340 |
+#| 54 | module.backbone.fpn.inner_blocks.1.weight              | (256, 512, 1, 1)   |        131072 |          19661 |    0.00000 |    0.00000 |  0.58594 | 84.99985 |  0.00000 |   84.99985 | 0.01087 |  0.00008 |    0.00403 |
+#| 55 | module.backbone.fpn.inner_blocks.2.weight              | (256, 1024, 1, 1)  |        262144 |          39322 |    0.00000 |    0.00000 |  0.09766 | 84.99985 |  0.00000 |   84.99985 | 0.00976 | -0.00003 |    0.00358 |
+#| 56 | module.backbone.fpn.inner_blocks.3.weight              | (256, 2048, 1, 1)  |        524288 |          78644 |    0.00000 |    0.00000 |  0.00000 | 84.99985 |  0.00000 |   84.99985 | 0.01260 | -0.00003 |    0.00470 |
+#| 57 | module.backbone.fpn.layer_blocks.0.weight              | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 43.46161 |  0.00000 |   84.99993 | 0.00841 | -0.00003 |    0.00318 |
+#| 58 | module.backbone.fpn.layer_blocks.1.weight              | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 50.99182 |  0.00000 |   84.99993 | 0.00608 | -0.00001 |    0.00230 |
+#| 59 | module.backbone.fpn.layer_blocks.2.weight              | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 50.06256 |  0.00000 |   84.99993 | 0.00655 | -0.00001 |    0.00248 |
+#| 60 | module.backbone.fpn.layer_blocks.3.weight              | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 49.36371 |  0.00000 |   84.99993 | 0.00635 | -0.00003 |    0.00240 |
+#| 61 | module.rpn.head.conv.weight                            | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 41.99219 |  0.00000 |   84.99993 | 0.00580 |  0.00014 |    0.00215 |
+#| 62 | module.rpn.head.cls_logits.weight                      | (3, 256, 1, 1)     |           768 |            768 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.15008 | -0.01621 |    0.11047 |
+#| 63 | module.rpn.head.bbox_pred.weight                       | (12, 256, 1, 1)    |          3072 |           3072 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03009 | -0.00090 |    0.01457 |
+#| 64 | module.roi_heads.box_head.fc6.weight                   | (1024, 12544)      |      12845056 |        1926759 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   85.00000 | 0.00312 |  0.00001 |    0.00115 |
+#| 65 | module.roi_heads.box_head.fc7.weight                   | (1024, 1024)       |       1048576 |         157287 |    1.66016 |    0.00000 |  0.00000 |  1.66016 |  0.00000 |   84.99994 | 0.01011 | -0.00015 |    0.00357 |
+#| 66 | module.roi_heads.box_predictor.cls_score.weight        | (91, 1024)         |         93184 |          93184 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.03252 | -0.00000 |    0.02113 |
+#| 67 | module.roi_heads.box_predictor.bbox_pred.weight        | (364, 1024)        |        372736 |         372736 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.01165 |  0.00011 |    0.00589 |
+#| 68 | module.roi_heads.mask_head.mask_fcn1.weight            | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 40.37628 |  0.00000 |   84.99993 | 0.00778 | -0.00003 |    0.00290 |
+#| 69 | module.roi_heads.mask_head.mask_fcn2.weight            | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 45.38727 |  0.00000 |   84.99993 | 0.00810 | -0.00004 |    0.00294 |
+#| 70 | module.roi_heads.mask_head.mask_fcn3.weight            | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.00000 | 50.18463 |  1.95312 |   84.99993 | 0.00791 | -0.00003 |    0.00282 |
+#| 71 | module.roi_heads.mask_head.mask_fcn4.weight            | (256, 256, 3, 3)   |        589824 |          88474 |    0.00000 |    0.00000 |  0.78125 | 61.83777 | 19.92188 |   84.99993 | 0.00762 |  0.00054 |    0.00259 |
+#| 72 | module.roi_heads.mask_predictor.conv5_mask.weight      | (256, 256, 2, 2)   |        262144 |          39322 |    0.00000 |    0.00000 |  7.81250 | 76.00861 | 19.53125 |   84.99985 | 0.01107 |  0.00141 |    0.00371 |
+#| 73 | module.roi_heads.mask_predictor.mask_fcn_logits.weight | (91, 256, 1, 1)    |         23296 |          23296 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |    0.00000 | 0.04107 | -0.00201 |    0.02535 |
+#| 74 | Total sparsity:                                        | -                  |      44395200 |        6437593 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   85.49935 | 0.00000 |  0.00000 |    0.00000 |
+#+----+--------------------------------------------------------+--------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
+#Total sparsity: 85.50
+#Results:
+# Baseline:
+  #IoU metric: bbox
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.379
+  # Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.592
+  # Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.410
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.215
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.414
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.495
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.312
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.494
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.518
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.321
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.559
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.660
+  #IoU metric: segm
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.346
+  # Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.561
+  # Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.367
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.156
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.373
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.509
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.294
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.454
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.474
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.269
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.515
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631
+
+#Post Pruning:
+  #IoU metric: bbox
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.352
+  # Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.558
+  # Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.376
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.188
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.389
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.460
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.302
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.481
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.504
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.300
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.546
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.646
+  #IoU metric: segm
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.320
+  # Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.525
+  # Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.338
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.134
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.346
+  # Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.482
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.284
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.439
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.459
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.250
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.499
+  # Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.618
+
+version: 1
+
+pruners:
+
+  fc_pruner:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.01
+    final_sparsity: 0.85
+    weights: [
+      module.roi_heads.box_head.fc6.weight,
+      module.roi_heads.box_head.fc7.weight
+    ]
+
+  agp_pruner_75:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.01
+    final_sparsity: 0.75
+    weights: [
+      module.backbone.body.layer1.0.conv1.weight,
+      module.backbone.body.layer1.0.conv2.weight,
+      module.backbone.body.layer1.0.conv3.weight,
+      module.backbone.body.layer1.0.downsample.0.weight,
+      module.backbone.body.layer1.1.conv1.weight,
+      module.backbone.body.layer1.1.conv2.weight,
+      module.backbone.body.layer1.1.conv3.weight,
+      module.backbone.body.layer1.2.conv1.weight,
+      module.backbone.body.layer1.2.conv2.weight,
+      module.backbone.body.layer1.2.conv3.weight,
+      module.backbone.body.layer2.0.conv1.weight,
+      module.backbone.body.layer2.0.conv2.weight,
+      module.backbone.body.layer2.0.conv3.weight,
+    ]
+
+  agp_pruner_85:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.01
+    final_sparsity: 0.85
+    weights: [
+      module.backbone.body.layer2.0.downsample.0.weight,
+      module.backbone.body.layer2.1.conv1.weight,
+      module.backbone.body.layer2.1.conv2.weight,
+      module.backbone.body.layer2.1.conv3.weight,
+      module.backbone.body.layer2.2.conv1.weight,
+      module.backbone.body.layer2.2.conv2.weight,
+      module.backbone.body.layer2.2.conv3.weight,
+      module.backbone.body.layer2.3.conv1.weight,
+      module.backbone.body.layer2.3.conv2.weight,
+      module.backbone.body.layer2.3.conv3.weight,
+      module.backbone.body.layer3.0.conv1.weight,
+      module.backbone.body.layer3.0.conv2.weight,
+      module.backbone.body.layer3.0.conv3.weight,
+      module.backbone.body.layer3.0.downsample.0.weight,
+      module.backbone.body.layer3.1.conv1.weight,
+      module.backbone.body.layer3.1.conv2.weight,
+      module.backbone.body.layer3.1.conv3.weight,
+      module.backbone.body.layer3.2.conv1.weight,
+      module.backbone.body.layer3.2.conv2.weight,
+      module.backbone.body.layer3.2.conv3.weight,
+      module.backbone.body.layer3.3.conv1.weight,
+      module.backbone.body.layer3.3.conv2.weight,
+      module.backbone.body.layer3.3.conv3.weight,
+      module.backbone.body.layer3.4.conv1.weight,
+      module.backbone.body.layer3.4.conv2.weight,
+      module.backbone.body.layer3.4.conv3.weight,
+      module.backbone.body.layer3.5.conv1.weight,
+      module.backbone.body.layer3.5.conv2.weight,
+      module.backbone.body.layer3.5.conv3.weight,
+      module.backbone.body.layer4.2.conv3.weight,
+      module.backbone.fpn.inner_blocks.0.weight,
+      module.backbone.fpn.inner_blocks.1.weight,
+      module.backbone.fpn.inner_blocks.2.weight,
+      module.backbone.fpn.inner_blocks.3.weight,
+      module.backbone.fpn.layer_blocks.0.weight,
+      module.backbone.fpn.layer_blocks.1.weight,
+      module.backbone.fpn.layer_blocks.2.weight,
+      module.backbone.fpn.layer_blocks.3.weight,
+      module.rpn.head.conv.weight,
+      module.roi_heads.mask_head.mask_fcn1.weight,
+      module.roi_heads.mask_head.mask_fcn2.weight,
+      module.roi_heads.mask_head.mask_fcn3.weight,
+      module.roi_heads.mask_head.mask_fcn4.weight,
+      module.roi_heads.mask_predictor.conv5_mask.weight,
+    ]
+
+  agp_pruner_90:
+    class: AutomatedGradualPruner
+    initial_sparsity : 0.01
+    final_sparsity: 0.90
+    weights: [
+      module.backbone.body.layer4.0.conv1.weight,
+      module.backbone.body.layer4.0.conv2.weight,
+      module.backbone.body.layer4.0.conv3.weight,
+      module.backbone.body.layer4.0.downsample.0.weight,
+      module.backbone.body.layer4.1.conv1.weight,
+      module.backbone.body.layer4.1.conv2.weight,
+      module.backbone.body.layer4.1.conv3.weight,
+      module.backbone.body.layer4.2.conv1.weight,
+      module.backbone.body.layer4.2.conv2.weight,
+    ]
+
+
+policies:
+  - pruner:
+      instance_name : agp_pruner_75
+    starting_epoch: 0
+    ending_epoch: 45
+    frequency: 1
+
+  - pruner:
+      instance_name : agp_pruner_85
+    starting_epoch: 0
+    ending_epoch: 45
+    frequency: 1
+
+  - pruner:
+      instance_name : fc_pruner
+    starting_epoch: 0
+    ending_epoch: 45
+    frequency: 3
+
+  - pruner:
+      instance_name : agp_pruner_90
+    starting_epoch: 0
+    ending_epoch: 45
+    frequency: 1
diff --git a/examples/object_detection_compression/requirements.txt b/examples/object_detection_compression/requirements.txt
new file mode 100644
index 0000000..d4e6f64
--- /dev/null
+++ b/examples/object_detection_compression/requirements.txt
@@ -0,0 +1,2 @@
+Cython
+pycocotools
diff --git a/examples/object_detection_compression/transforms.py b/examples/object_detection_compression/transforms.py
new file mode 100644
index 0000000..90a36b3
--- /dev/null
+++ b/examples/object_detection_compression/transforms.py
@@ -0,0 +1,52 @@
+# This code is originally from:
+#   https://github.com/pytorch/vision/tree/v0.4.2/references/detection
+import random
+import torch
+
+from torchvision.transforms import functional as F
+
+
+def _flip_coco_person_keypoints(kps, width):
+    flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
+    flipped_data = kps[:, flip_inds]
+    flipped_data[..., 0] = width - flipped_data[..., 0]
+    # Maintain COCO convention that if visibility == 0, then x, y = 0
+    inds = flipped_data[..., 2] == 0
+    flipped_data[inds] = 0
+    return flipped_data
+
+
+class Compose(object):
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, image, target):
+        for t in self.transforms:
+            image, target = t(image, target)
+        return image, target
+
+
+class RandomHorizontalFlip(object):
+    def __init__(self, prob):
+        self.prob = prob
+
+    def __call__(self, image, target):
+        if random.random() < self.prob:
+            height, width = image.shape[-2:]
+            image = image.flip(-1)
+            bbox = target["boxes"]
+            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
+            target["boxes"] = bbox
+            if "masks" in target:
+                target["masks"] = target["masks"].flip(-1)
+            if "keypoints" in target:
+                keypoints = target["keypoints"]
+                keypoints = _flip_coco_person_keypoints(keypoints, width)
+                target["keypoints"] = keypoints
+        return image, target
+
+
+class ToTensor(object):
+    def __call__(self, image, target):
+        image = F.to_tensor(image)
+        return image, target
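+
+
+# Illustrative helper (a sketch, not part of the upstream torchvision reference
+# scripts): shows how the transforms above are typically composed for training
+# vs. evaluation data loaders.
+def make_detection_transform(train):
+    transforms = [ToTensor()]
+    if train:
+        # Random horizontal flip is the only augmentation used for training here.
+        transforms.append(RandomHorizontalFlip(0.5))
+    return Compose(transforms)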
diff --git a/examples/object_detection_compression/utils.py b/examples/object_detection_compression/utils.py
new file mode 100644
index 0000000..ce33966
--- /dev/null
+++ b/examples/object_detection_compression/utils.py
@@ -0,0 +1,328 @@
+# This code is originally from:
+#   https://github.com/pytorch/vision/tree/v0.4.2/references/detection
+from __future__ import print_function
+
+from collections import defaultdict, deque
+import datetime
+import pickle
+import time
+
+import torch
+import torch.distributed as dist
+
+import errno
+import os
+
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value)
+
+
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialize the data to a byte Tensor
+    buffer = pickle.dumps(data)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to("cuda")
+
+    # obtain Tensor size of each rank
+    local_size = torch.tensor([tensor.numel()], device="cuda")
+    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receive tensors from all ranks; pad the local tensor to max_size because
+    # torch.distributed's all_gather does not support gathering tensors of
+    # different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+    if local_size != max_size:
+        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that all processes
+    have the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.all_reduce(values)
+        if average:
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
+
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {}".format(name, str(meter))
+            )
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ''
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt='{avg:.4f}')
+        data_time = SmoothedValue(fmt='{avg:.4f}')
+        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+        if torch.cuda.is_available():
+            log_msg = self.delimiter.join([
+                header,
+                '[{0' + space_fmt + '}/{1}]',
+                'eta: {eta}',
+                '{meters}',
+                'time: {time}',
+                'data: {data}',
+                'max mem: {memory:.0f}'
+            ])
+        else:
+            log_msg = self.delimiter.join([
+                header,
+                '[{0' + space_fmt + '}/{1}]',
+                'eta: {eta}',
+                '{meters}',
+                'time: {time}',
+                'data: {data}'
+            ])
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                        memory=torch.cuda.max_memory_allocated() / MB))
+                else:
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time)))
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print('{} Total time: {} ({:.4f} s / it)'.format(
+            header, total_time_str, total_time / len(iterable)))
+
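+
+# Typical use of MetricLogger, mirroring engine.py in this example (a sketch,
+# not a verbatim excerpt):
+#
+#   metric_logger = MetricLogger(delimiter="  ")
+#   for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+#       ...  # forward/backward pass
+#       metric_logger.update(loss=losses_reduced, lr=optimizer.param_groups[0]["lr"])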
+
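+# Detection targets vary in size from image to image, so the DataLoader's
+# collate_fn keeps each batch as a tuple of per-image lists instead of stacking
+# the samples into tensors.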
+def collate_fn(batch):
+    return tuple(zip(*batch))
+
+
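+# Returns a LambdaLR whose multiplier ramps linearly from `warmup_factor` up to 1
+# over the first `warmup_iters` iterations and stays at 1 afterwards; the training
+# loop in engine.py applies it during the first epoch only.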
+def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
+
+    def f(x):
+        if x >= warmup_iters:
+            return 1
+        alpha = float(x) / warmup_iters
+        return warmup_factor * (1 - alpha) + alpha
+
+    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
+
+
+def mkdir(path):
+    try:
+        os.makedirs(path)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not running in the master process.
+    """
+    import builtins as __builtin__
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print('Not using distributed mode')
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}'.format(
+        args.rank, args.dist_url), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
-- 
GitLab