diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 26bedf5a49d33474ea27a5500796ccf6d95b1261..c35c31b516639f5c439b07bd3ddedbe028de98f8 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -79,11 +79,7 @@ from distiller.data_loggers import TensorBoardLogger, PythonLogger, ActivationSp import distiller.quantization as quantization from models import ALL_MODEL_NAMES, create_model - msglogger = None -log_filename = '' - - parser = argparse.ArgumentParser(description='Distiller image classification model compression') parser.add_argument('data', metavar='DIR', help='path to dataset') parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', @@ -134,22 +130,26 @@ parser.add_argument('--gpus', metavar='DEV_ID', default=None, parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name') -def main(): - - args = parser.parse_args() - +def config_logger(experiment_name): # The Distiller library writes logs to the Python logger, so we configure it. - global msglogger timestr = time.strftime("%Y.%m.%d-%H%M%S") - filename = timestr if args.name is None else args.name + '___' + timestr + filename = timestr if experiment_name is None else experiment_name + '___' + timestr logdir = './logs' + '/' + filename if not os.path.exists(logdir): os.makedirs(logdir) log_filename = os.path.join(logdir, filename + '.log') logging.config.fileConfig(os.path.join(script_dir, 'logging.conf'), defaults={'logfilename': log_filename}) msglogger = logging.getLogger() - + msglogger.logdir = logdir + msglogger.log_filename = log_filename msglogger.info('Log file for this run: ' + os.path.realpath(log_filename)) + return msglogger + + +def main(): + global msglogger + args = parser.parse_args() + msglogger = config_logger(args.name) # Log various details about the execution environment. It is sometimes useful # to refer to past experiment executions and this information may be useful. @@ -202,7 +202,7 @@ def main(): compression_scheduler = None # Create a couple of logging backends. TensorBoardLogger writes log files in a format # that can be read by Google's Tensor Board. PythonLogger writes to the Python logger. - tflogger = TensorBoardLogger(logdir) + tflogger = TensorBoardLogger(msglogger.logdir) pylogger = PythonLogger(msglogger) # We can optionally resume from a checkpoint @@ -506,4 +506,4 @@ if __name__ == '__main__': finally: if msglogger is not None: msglogger.info('') - msglogger.info('Log file for this run: ' + os.path.realpath(log_filename)) + msglogger.info('Log file for this run: ' + os.path.realpath(msglogger.log_filename))