%% Cell type:markdown id: tags:
# Automated Gradual Pruning Schedule
Michael Zhu and Suyog Gupta, ["To prune, or not to prune: exploring the efficacy of pruning for model compression"](https://arxiv.org/pdf/1710.01878), 2017 NIPS Workshop on Machine Learning of Phones and other Consumer Devices<br>
<br>
After completing sensitivity analysis, decide on your pruning schedule.
## Table of Contents
1. [Implementation of the gradual sparsity function](#Implementation-of-the-gradual-sparsity-function)
2. [Visualize pruning schedule](#Visualize-pruning-schedule)
3. [References](#References)
%% Cell type:code id: tags:
``` python
import matplotlib.pyplot as plt
import torch
from ipywidgets import widgets, interact
```
%% Cell type:markdown id: tags:
## Implementation of the gradual sparsity function
The function ```sparsity_target``` implements the gradual sparsity schedule from [[1]](#zhu-gupta):<br><br>
<b><i>"We introduce a new automated gradual pruning algorithm in which the sparsity is increased from an initial sparsity value $s_i$ (usually 0) to a final sparsity value $s_f$ over a span of $n$ pruning steps, starting at training step $t_0$ and with pruning frequency $\Delta t$."</i></b><br>
<br>
<div id="eq:zhu_gupta_schedule"></div>
<center>
$\large
\begin{align}
s_t = s_f + (s_i - s_f) \left(1 - \frac{t - t_0}{n\Delta t}\right)^3
\end{align}
\ \ \text{for} \ \ t \in \{t_0, t_0 + \Delta t, ..., t_0 + n\Delta t\}
$
</center>
<br>
Pruning happens once at the beginning of each epoch, until the duration of the pruning (the number of epochs to prune) is exceeded. After pruning ends, the training continues without pruning, but the pruned weights are kept at zero.
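For example (illustrative numbers only), with $s_i = 0$, $s_f = 80$, $t_0 = 0$ and $n\Delta t = 28$, halfway through the schedule ($t = 14$) the target sparsity is $80 + (0 - 80)\left(1 - \tfrac{14}{28}\right)^3 = 80 - 80 \cdot 0.125 = 70$ percent: most of the pruning is front-loaded into the early epochs, and the rate tapers off as the target is approached.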
%% Cell type:code id: tags:
``` python
def sparsity_target(starting_epoch, ending_epoch, initial_sparsity, final_sparsity, current_epoch):
    # An AGP schedule only increases sparsity: before pruning starts, and for
    # a misconfigured (decreasing) schedule, hold the initial sparsity level.
    # Adding current_epoch * 0 keeps the result differentiable w.r.t.
    # current_epoch when it is passed in as a tensor.
    if final_sparsity < initial_sparsity or current_epoch < starting_epoch:
        return current_epoch * 0 + initial_sparsity
    span = ending_epoch - starting_epoch
    target_sparsity = (final_sparsity +
                       (initial_sparsity - final_sparsity) *
                       (1.0 - ((current_epoch - starting_epoch) / span))**3)
    return target_sparsity
```
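%% Cell type:markdown id: tags:
As a quick sanity check (a minimal sketch; the schedule parameters below are arbitrary), evaluating the function at a few epochs reproduces the fast-then-slow shape of the schedule:
%% Cell type:code id: tags:
``` python
# Evaluate a 0% -> 80% schedule spanning 28 epochs at a few sample epochs
for epoch in [0, 7, 14, 21, 28]:
    s = sparsity_target(starting_epoch=0, ending_epoch=28,
                        initial_sparsity=0, final_sparsity=80,
                        current_epoch=epoch)
    print(f'epoch {epoch:2d}: target sparsity = {s:.2f}%')
```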
%% Cell type:markdown id: tags:
## Visualize pruning schedule
When using the Automated Gradual Pruning (AGP) schedule, you may want to visualize how the pruning schedule will look as a function of the epoch number. This is called the *sparsity function*. The widget below will help you do this (it also traces the derivative of the sparsity; see the note after the list).<br>
There are four knobs you can use to change the schedule:
- ```duration```: this is the number of epochs over which to use the AGP schedule ($n\Delta t$).
- ```initial_sparsity```: $s_i$
- ```final_sparsity```: $s_f$
- ```frequency```: this is the pruning frequency ($\Delta t$).
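The plot obtains the derivative of the sparsity via autograd (```backward()``` on the schedule's output); for reference, differentiating the schedule analytically gives<br>
<center>
$
\frac{ds_t}{dt} = \frac{3\,(s_f - s_i)}{n\Delta t} \left(1 - \frac{t - t_0}{n\Delta t}\right)^2
$
</center>
which is largest at $t = t_0$ and decays to zero at $t = t_0 + n\Delta t$: pruning is most aggressive early, when redundant weights are plentiful.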
%% Cell type:code id: tags:
``` python
def draw_pruning(duration, initial_sparsity, final_sparsity, frequency):
    epochs = []
    sparsity_levels = []
    # The derivative of the sparsity (i.e. the sparsity rate of change)
    d_sparsity = []
    frequency = 1 if frequency == '' else int(frequency)
    sparsity = initial_sparsity  # level reported between/after pruning steps
    for epoch in range(0, 40):
        epochs.append(epoch)
        if epoch < duration and epoch % frequency == 0:
            # Track the epoch as a tensor so that autograd can compute ds/dt
            current_epoch = torch.tensor(float(epoch), requires_grad=True)
            sparsity = sparsity_target(
                starting_epoch=0,
                ending_epoch=duration,
                initial_sparsity=initial_sparsity,
                final_sparsity=final_sparsity,
                current_epoch=current_epoch)
            sparsity.backward()
            d_sparsity.append(current_epoch.grad.item())
            sparsity = sparsity.item()
        else:
            # Between pruning steps (and after pruning ends) the level is flat
            d_sparsity.append(0)
        sparsity_levels.append(sparsity)
    plt.plot(epochs, sparsity_levels, label='sparsity target')
    plt.plot(epochs, d_sparsity, label='d(sparsity)/d(epoch)')
    plt.legend()
    plt.ylabel('sparsity (%)')
    plt.xlabel('epoch')
    plt.title('Pruning Rate')
    plt.ylim(0, 100)
    plt.draw()

duration_widget = widgets.IntSlider(min=0, max=100, step=1, value=28)
si_widget = widgets.IntSlider(min=0, max=100, step=1, value=0)
interact(draw_pruning,
         duration=duration_widget,
         initial_sparsity=si_widget,
         final_sparsity=(0, 100, 1),
         frequency='2');
```
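%% Cell type:markdown id: tags:
If the interactive widget is unavailable, the same curve can be drawn statically (a minimal sketch reusing ```draw_pruning``` above, with illustrative parameter values):
%% Cell type:code id: tags:
``` python
# Static plot: 0% -> 80% sparsity over 28 epochs, pruning every 2 epochs
draw_pruning(duration=28, initial_sparsity=0, final_sparsity=80, frequency='2')
plt.show()
```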
%% Cell type:markdown id: tags:
<div id="toc"></div>
## References
1. <div id="zhu-gupta"></div> **Michael Zhu and Suyog Gupta**.
[*To prune, or not to prune: exploring the efficacy of pruning for model compression*](https://arxiv.org/pdf/1710.01878),
NIPS Workshop on Machine Learning of Phones and other Consumer Devices, 2017.