Skip to content
Snippets Groups Projects
Unverified Commit 958a52f6 authored by Guy Jacob's avatar Guy Jacob Committed by GitHub
Browse files

Raise error if scheduler dict contains invalid class arguments (#204)

* Until now this would be ignored, yielding unexpected results, such as
  what we saw in issue #191
* Also removed some redundant try-except clauses
parent 5800f35e
No related branches found
No related tags found
No related merge requests found
......@@ -157,13 +157,13 @@ def config_component_from_file_by_class(model, filename, class_name, **extra_arg
config_dict = distiller.utils.yaml_ordered_load(stream)
config_dict.pop('policies', None)
for section_name, components in config_dict.items():
for component_name, component_args in components.items():
if component_args['class'] == class_name:
for component_name, user_args in components.items():
if user_args['class'] == class_name:
msglogger.info(
'Found component of class {0}: Name: {1} ; Section: {2}'.format(class_name, component_name,
section_name))
component_args.update(extra_args)
return build_component(model, component_name, component_args)
user_args.update(extra_args)
return build_component(model, component_name, user_args)
raise ValueError(
'Component of class {0} does not exist in configuration file {1}'.format(class_name, filename))
except yaml.YAMLError:
......@@ -171,55 +171,63 @@ def config_component_from_file_by_class(model, filename, class_name, **extra_arg
raise
def __factory(container_type, model, sched_dict, **kwargs):
def __factory(container_type, model, sched_dict, **extra_args):
container = {}
if container_type in sched_dict:
try:
for name, cfg_kwargs in sched_dict[container_type].items():
try:
cfg_kwargs.update(kwargs)
# Instantiate pruners using the 'class' argument
instance = build_component(model, name, cfg_kwargs)
container[name] = instance
except NameError as error:
print("\nFatal error while parsing [section:%s] [item:%s]" % (container_type, name))
raise
except Exception as exception:
print("\nFatal error while parsing [section:%s] [item:%s]" % (container_type, name))
print("Exception: %s %s" % (type(exception), exception))
raise
except Exception as exception:
print("\nFatal while creating %s" % container_type)
print("Exception: %s %s" % (type(exception), exception))
raise
for name, user_args in sched_dict[container_type].items():
try:
instance = build_component(model, name, user_args, **extra_args)
container[name] = instance
except Exception as exception:
print("\nFatal error while parsing [section: %s] [item: %s]" % (container_type, name))
print("Exception: %s %s" % (type(exception), exception))
raise
return container
def build_component(model, name, cfg_kwargs):
cfg_kwargs['model'] = model
cfg_kwargs['name'] = name
class_ = globals()[cfg_kwargs['class']]
instance = class_(**__filter_kwargs(cfg_kwargs, class_.__init__))
def build_component(model, name, user_args, **extra_args):
# Instantiate component using the 'class' argument
class_name = user_args.pop('class')
try:
class_ = globals()[class_name]
except KeyError as ex:
raise ValueError("Class named '{0}' does not exist".format(class_name)) from ex
# First we check that the user defined dict itself does not contain invalid args
valid_args, invalid_args = __filter_kwargs(user_args, class_.__init__)
if invalid_args:
raise ValueError(
'{0} does not accept the following arguments: {1}'.format(class_name, list(invalid_args.keys())))
# Now we add some "hard-coded" args, which some classes may accept and some may not
# So then we filter again, this time ignoring any invalid args
valid_args.update(extra_args)
valid_args['model'] = model
valid_args['name'] = name
final_valid_args, _ = __filter_kwargs(valid_args, class_.__init__)
instance = class_(**final_valid_args)
return instance
def __filter_kwargs(dict_to_filter, function_to_call):
"""Utility to remove extra keyword arguments
"""Utility to check which arguments in the passed dictionary exist in a function's signature
This function will remove any unwanted kwargs and pass the rest of the kwargs
to the called function. This is needed because we want to call some existing
constructor functions, using the YAML dictionary, which contains extra parameters.
The function returns two dicts, one with just the valid args from the input and one with the invalid args.
The caller can then decide to ignore the existence of invalid args, depending on context.
"""
sig = inspect.signature(function_to_call)
filter_keys = [param.name for param in sig.parameters.values() if (param.kind == param.POSITIONAL_OR_KEYWORD)]
filtered_dict = {}
valid_args = {}
invalid_args = {}
for key in dict_to_filter:
if key in filter_keys:
filtered_dict[key] = dict_to_filter[key]
return filtered_dict
valid_args[key] = dict_to_filter[key]
else:
invalid_args[key] = dict_to_filter[key]
return valid_args, invalid_args
def __policy_params(policy_def, type):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment