Skip to content
Snippets Groups Projects
Commit 1a64c13b authored by Neta Zmora's avatar Neta Zmora
Browse files

Fix issue #396

Add to the notebook a missing function.
parent 61dfc828
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
## Compare experiment executions
This notebook let's you qickly compare the training progress of your experiments.
This notebook lets you quickly compare the training progress of your experiments, from within a notebook. For other use-cases we advise to use TensorBoard which is equipped with many more features than these few lines of code.
You will need to have the tfevents files (these are TensorBoard formatted log files that Distiller creates).
%% Cell type:code id: tags:
``` python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def get_tags_list(path_to_events_file):
tags = [v.tag for e in tf.train.summary_iterator(path_to_events_file) for v in e.summary.value]
return set(tags)
def get_performance_data(path_to_events_file, tag):
"""Extract the performance history of data named 'tag'
Based on sample code from TF:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/summary/summary_iterator.py
"""
data = []
steps = []
for e in tf.train.summary_iterator(path_to_events_file):
for v in e.summary.value:
if v.tag == tag:
data.append(v.simple_value)
steps.append(e.step)
return steps, data
def add_experiment(axs, tags, results, label):
for i, tag in enumerate(tags):
steps, prec1_ssl = get_performance_data(results, tag)
axs[i//2, i%2].plot(steps, prec1_ssl, label=label)
axs[i//2][i%2].set_title(tag)
axs[i//2][i%2].legend()
```
%% Cell type:code id: tags:
``` python
# Here insert your own tfevents files to compare
# WARNING: these files do not exist in the repositroy (too large) and will give you an error
experiment_files = [('events.out.tfevents.1523290172.one-machine', 'experiment 1'),
('events.out.tfevents.1520430112.one-machine', 'experiment 2')]
# Choose which performance indicators you wish to graph
tags = ['Peformance/Validation/Top1', 'Peformance/Validation/Loss',
'sparsity/weights/total', 'Peformance/Training/Reg Loss']
f, axs = plt.subplots(2, 2, figsize=(20,20))
f.suptitle('Performance')
print(get_tags_list(experiment_files[0][0]))
for experiment in experiment_files:
add_experiment(axs, tags, experiment[0], label=experiment[1])
plt.tight_layout()
plt.show()
```
%% Cell type:code id: tags:
``` python
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment