diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index ed66d78d2e0ac7fa549a6ccc5bb39b09e24071d5..4691ede991e2919b35f9aea4573790385dec042f 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -672,6 +672,7 @@ class FMReconstructionChannelPruner(_RankedStructureParameterPruner): except AttributeError: raise ValueError("To use FMReconstructionChannelPruner you must first collect input statistics") + op_type = 'conv' if param.dim() == 4 else 'fc' # We need to remove the chosen weights channels. Because we are using # min(MSE) to compute the weights, we need to start by removing feature-map # channels from the input. Then we perform the MSE regression to generate diff --git a/examples/auto_compression/amc/environment.py b/examples/auto_compression/amc/environment.py index e6fbcef1859626cfac74086d97bcc7bbbeb8856a..1d73443d617a5a8c8f984afc193472b632b86195 100755 --- a/examples/auto_compression/amc/environment.py +++ b/examples/auto_compression/amc/environment.py @@ -40,7 +40,6 @@ msglogger = logging.getLogger("examples.auto_compression.amc") Observation = namedtuple('Observation', ['t', 'type', 'n', 'c', 'h', 'w', 'stride', 'k', 'MACs', 'Weights', 'reduced', 'rest', 'prev_a']) ObservationLen = len(Observation._fields) -ALMOST_ONE = 0.9999 def is_using_continuous_action_space(agent): @@ -295,7 +294,7 @@ class DistillerWrapperEnvironment(gym.Env): info = {"accuracy": top1, "compress_ratio": normalized_macs} if self.amc_cfg.protocol == "mac-constrained": # Sanity check (special case only for "mac-constrained") - assert self.removed_macs_pct >= 1 - self.amc_cfg.target_density - 0.002 # 0.01 + #assert self.removed_macs_pct >= 1 - self.amc_cfg.target_density - 0.002 # 0.01 pass else: info = {} diff --git a/examples/auto_compression/amc/utils/net_wrapper.py b/examples/auto_compression/amc/utils/net_wrapper.py index 9bdbe3f13251454219823fd5ff0292b9845d2ffa..b447feedbf5c0d3114405d2b29c69fed4977d9d4 100644 --- a/examples/auto_compression/amc/utils/net_wrapper.py +++ b/examples/auto_compression/amc/utils/net_wrapper.py @@ -26,7 +26,7 @@ from distiller import normalize_module_name, SummaryGraph __all__ = ["NetworkWrapper"] msglogger = logging.getLogger() - +ALMOST_ONE = 0.9999 class NetworkWrapper(object): def __init__(self, model, app_args, services, modules_list, pruning_pattern):