Skip to content
Snippets Groups Projects
Commit 6b21a264 authored by Neta Zmora's avatar Neta Zmora
Browse files

environment.py: enforce processing of known networks only

parent ffab00ef
No related branches found
No related tags found
No related merge requests found
...@@ -329,6 +329,7 @@ class DistillerWrapperEnvironment(gym.Env): ...@@ -329,6 +329,7 @@ class DistillerWrapperEnvironment(gym.Env):
distiller.assign_layer_fq_names(model) distiller.assign_layer_fq_names(model)
modules_list = [mod.distiller_name for mod in model.modules() if type(mod)==torch.nn.Conv2d] modules_list = [mod.distiller_name for mod in model.modules() if type(mod)==torch.nn.Conv2d]
msglogger.warning("Using the following layers: %s" % ", ".join(modules_list)) msglogger.warning("Using the following layers: %s" % ", ".join(modules_list))
raise ValueError("The config file does not specify the modules to compress for %s" % app_args.arch)
self.net_wrapper = NetworkWrapper(model, app_args, services, modules_list, amc_cfg.pruning_pattern) self.net_wrapper = NetworkWrapper(model, app_args, services, modules_list, amc_cfg.pruning_pattern)
self.original_model_macs, self.original_model_size = self.net_wrapper.get_resources_requirements() self.original_model_macs, self.original_model_size = self.net_wrapper.get_resources_requirements()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment