diff --git a/distiller/config.py b/distiller/config.py index 4fad434f347477b32394e93a5e4540384e5327c4..c0fb982efc0b9469a187f3f665b3380b0db7a149 100755 --- a/distiller/config.py +++ b/distiller/config.py @@ -157,13 +157,13 @@ def config_component_from_file_by_class(model, filename, class_name, **extra_arg config_dict = distiller.utils.yaml_ordered_load(stream) config_dict.pop('policies', None) for section_name, components in config_dict.items(): - for component_name, component_args in components.items(): - if component_args['class'] == class_name: + for component_name, user_args in components.items(): + if user_args['class'] == class_name: msglogger.info( 'Found component of class {0}: Name: {1} ; Section: {2}'.format(class_name, component_name, section_name)) - component_args.update(extra_args) - return build_component(model, component_name, component_args) + user_args.update(extra_args) + return build_component(model, component_name, user_args) raise ValueError( 'Component of class {0} does not exist in configuration file {1}'.format(class_name, filename)) except yaml.YAMLError: @@ -171,55 +171,63 @@ def config_component_from_file_by_class(model, filename, class_name, **extra_arg raise -def __factory(container_type, model, sched_dict, **kwargs): +def __factory(container_type, model, sched_dict, **extra_args): container = {} if container_type in sched_dict: - try: - for name, cfg_kwargs in sched_dict[container_type].items(): - try: - cfg_kwargs.update(kwargs) - # Instantiate pruners using the 'class' argument - instance = build_component(model, name, cfg_kwargs) - container[name] = instance - except NameError as error: - print("\nFatal error while parsing [section:%s] [item:%s]" % (container_type, name)) - raise - except Exception as exception: - print("\nFatal error while parsing [section:%s] [item:%s]" % (container_type, name)) - print("Exception: %s %s" % (type(exception), exception)) - raise - except Exception as exception: - print("\nFatal while creating %s" % container_type) - print("Exception: %s %s" % (type(exception), exception)) - raise + for name, user_args in sched_dict[container_type].items(): + try: + instance = build_component(model, name, user_args, **extra_args) + container[name] = instance + except Exception as exception: + print("\nFatal error while parsing [section: %s] [item: %s]" % (container_type, name)) + print("Exception: %s %s" % (type(exception), exception)) + raise return container -def build_component(model, name, cfg_kwargs): - cfg_kwargs['model'] = model - cfg_kwargs['name'] = name - class_ = globals()[cfg_kwargs['class']] - instance = class_(**__filter_kwargs(cfg_kwargs, class_.__init__)) +def build_component(model, name, user_args, **extra_args): + # Instantiate component using the 'class' argument + class_name = user_args.pop('class') + try: + class_ = globals()[class_name] + except KeyError as ex: + raise ValueError("Class named '{0}' does not exist".format(class_name)) from ex + + # First we check that the user defined dict itself does not contain invalid args + valid_args, invalid_args = __filter_kwargs(user_args, class_.__init__) + if invalid_args: + raise ValueError( + '{0} does not accept the following arguments: {1}'.format(class_name, list(invalid_args.keys()))) + + # Now we add some "hard-coded" args, which some classes may accept and some may not + # So then we filter again, this time ignoring any invalid args + valid_args.update(extra_args) + valid_args['model'] = model + valid_args['name'] = name + final_valid_args, _ = __filter_kwargs(valid_args, class_.__init__) + instance = class_(**final_valid_args) return instance def __filter_kwargs(dict_to_filter, function_to_call): - """Utility to remove extra keyword arguments + """Utility to check which arguments in the passed dictionary exist in a function's signature - This function will remove any unwanted kwargs and pass the rest of the kwargs - to the called function. This is needed because we want to call some existing - constructor functions, using the YAML dictionary, which contains extra parameters. + The function returns two dicts, one with just the valid args from the input and one with the invalid args. + The caller can then decide to ignore the existence of invalid args, depending on context. """ sig = inspect.signature(function_to_call) filter_keys = [param.name for param in sig.parameters.values() if (param.kind == param.POSITIONAL_OR_KEYWORD)] - filtered_dict = {} + valid_args = {} + invalid_args = {} for key in dict_to_filter: if key in filter_keys: - filtered_dict[key] = dict_to_filter[key] - return filtered_dict + valid_args[key] = dict_to_filter[key] + else: + invalid_args[key] = dict_to_filter[key] + return valid_args, invalid_args def __policy_params(policy_def, type):