Commit a39bb904 authored by Neta Zmora

Drop filter + other things

parent 63bea5b8
@@ -42,6 +42,74 @@ CIFAR10_MODEL_NAMES = sorted(name for name in cifar10_models.__dict__
 ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))
 
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+
+class Conv2dWithMask(nn.Conv2d):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True):
+        super(Conv2dWithMask, self).__init__(
+            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
+            stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
+        self.test_mask = None
+        self.p_mask = 1.0
+        self.frequency = 16
+
+    def forward(self, input):
+        if self.training:
+            self.frequency -= 1
+            if self.frequency == 0:
+                # Sample a binary keep/drop decision for each filter (output channel)
+                sample = np.random.binomial(n=1, p=self.p_mask, size=self.out_channels)
+                param = self.weight
+                l1norm = param.detach().view(param.size(0), -1).norm(p=1, dim=1)
+                # Broadcast the per-filter mask over the full weight shape
+                mask = torch.tensor(sample)
+                mask = mask.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
+                mask = mask.view(self.weight.shape).to(param.device)
+                mask = mask.type(param.type())
+                masked_weights = self.weight * mask
+                masked_l1norm = masked_weights.detach().view(param.size(0), -1).norm(p=1, dim=1)
+                # Rescale the surviving filters by the fraction of L1 mass they retain
+                pruning_factor = (masked_l1norm.sum() / l1norm.sum()).item()
+                pruning_factor = max(0.2, pruning_factor)
+                weight = masked_weights / pruning_factor
+                self.frequency = 16
+            else:
+                weight = self.weight
+        else:
+            weight = self.weight
+        return F.conv2d(input, weight, self.bias, self.stride,
+                        self.padding, self.dilation, self.groups)
+
+
+def replace_conv2d(container):
+    """Recursively replace all nn.Conv2d layers in the container's model with Conv2dWithMask"""
+    for name, module in container.named_children():
+        if isinstance(module, nn.Conv2d):
+            print("replacing: ", name)
+            new_module = Conv2dWithMask(in_channels=module.in_channels,
+                                        out_channels=module.out_channels,
+                                        kernel_size=module.kernel_size, padding=module.padding,
+                                        stride=module.stride, bias=module.bias)
+            setattr(container, name, new_module)
+        replace_conv2d(module)
+
+
 def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     """Create a pytorch model based on the model architecture and dataset
@@ -89,4 +157,5 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     else:
         device = 'cpu'
+    #replace_conv2d(model)
     return model.to(device)
@@ -199,6 +199,9 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
                        param_name,
                        distiller.sparsity(mask),
                        fraction_to_prune)
+        # Compensate for dropping filters
+        #param.data /= float(distiller.sparsity(mask))
         return binary_map
 
     @staticmethod
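The disabled line above is the same compensation idea as Conv2dWithMask's pruning_factor: after structures are zeroed, rescaling the surviving weights keeps the layer's expected output magnitude roughly unchanged. Note that distiller.sparsity() reports the fraction of zeros, so the division would presumably need the kept fraction (1 - sparsity) to act as compensation; the commit leaves it commented out.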
@@ -90,6 +90,34 @@ LayerDescLen = len(LayerDesc._fields)
 ALMOST_ONE = 0.9999
 
+
+class CSVFile(object):
+    def __init__(self, fname, headers):
+        """Create the CSV file and write the column names"""
+        with open(fname, 'w') as f:
+            writer = csv.writer(f)
+            writer.writerow(headers)
+        self.fname = fname
+
+    def add_record(self, fields):
+        # We close the file each time to flush on every write, and protect against data-loss on crashes
+        with open(self.fname, 'a') as f:
+            writer = csv.writer(f)
+            writer.writerow(fields)
+
+
+class AMCStatsFile(CSVFile):
+    def __init__(self, fname):
+        headers = ['episode', 'top1', 'reward', 'total_macs', 'normalized_macs',
+                   'normalized_nnz', 'ckpt_name', 'action_history', 'agent_action_history']
+        super().__init__(fname, headers)
+
+
+class FineTuneStatsFile(CSVFile):
+    def __init__(self, fname):
+        headers = ['episode', 'ft_top1_list']
+        super().__init__(fname, headers)
+
+
 def is_using_continuous_action_space(agent):
     return agent in ("DDPG", "ClippedPPO-continuous", "Random-policy")
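A quick sketch of what these helpers produce; the file names follow the diff, while the field values below are invented for illustration:

    stats_file = AMCStatsFile('amc.csv')
    stats_file.add_record([0, 91.3, 0.74, 40800000, 51.2, 62.5, 'ckpt_episode_0',
                           [0.1, 0.5], [0.2, 0.6]])

    ft_stats_file = FineTuneStatsFile('ft_top1.csv')
    ft_stats_file.add_record([0, [[90.1, 99.6], [91.3, 99.7]]])  # per-interval [top1, top5]

Because add_record() reopens and closes the file on every call, each row is flushed immediately; a crash mid-search loses at most the record being written.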
@@ -437,11 +465,14 @@ class NetworkWrapper(object):
         optimizer = torch.optim.SGD(self.model.parameters(), lr=opt_cfg['lr'],
                                     momentum=opt_cfg['momentum'], weight_decay=opt_cfg['weight_decay'])
         compression_scheduler = self.create_scheduler()
+        acc_list = []
         for _ in range(num_epochs):
             # Fine-tune the model
-            self.services.train_fn(model=self.model, compression_scheduler=compression_scheduler,
-                                   optimizer=optimizer, epoch=episode)
+            accuracies = self.services.train_fn(model=self.model, compression_scheduler=compression_scheduler,
+                                                optimizer=optimizer, epoch=episode)
+            acc_list.extend(accuracies)
         del compression_scheduler
+        return acc_list
 
 
 class DistillerWrapperEnvironment(gym.Env):
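NetworkWrapper.train() now propagates accuracies upward: the application's train() function (modified at the bottom of this diff) returns a list of [top1, top5] pairs sampled over the epoch, and the wrapper concatenates them across the fine-tuning epochs of an episode so that compute_reward() can log the whole fine-tuning trajectory to ft_top1.csv.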
@@ -476,7 +507,9 @@ class DistillerWrapperEnvironment(gym.Env):
         #self.observation_space = spaces.Box(0, float("inf"), shape=(self.STATE_EMBEDDING_LEN+self.num_layers(),))
         self.observation_space = spaces.Box(0, float("inf"), shape=(self.STATE_EMBEDDING_LEN+1,))
         #self.observation_space = spaces.Box(0, float("inf"), shape=(LayerDescLen * self.num_layers(), ))
-        self.create_network_record_file()
+        #self.create_network_record_file()
+        self.stats_file = AMCStatsFile(os.path.join(msglogger.logdir, 'amc.csv'))
+        self.ft_stats_file = FineTuneStatsFile(os.path.join(msglogger.logdir, 'ft_top1.csv'))
 
     def reset(self, init_only=False):
         """Reset the environment.
@@ -598,8 +631,9 @@ class DistillerWrapperEnvironment(gym.Env):
             reward, top1, total_macs, total_nnz = self.compute_reward()
             normalized_macs = total_macs / self.dense_model_macs * 100
             normalized_nnz = total_nnz / self.dense_model_size * 100
-            self.record_network_details(top1, reward, total_macs, normalized_macs,
-                                        normalized_nnz, self.action_history, self.agent_action_history)
+            self.finalize_episode(top1, reward, total_macs, normalized_macs,
+                                  normalized_nnz, self.action_history, self.agent_action_history)
             self.episode += 1
         else:
             observation = self.get_obs()
@@ -715,7 +749,9 @@ class DistillerWrapperEnvironment(gym.Env):
             # What a hack!
             total_nnz *= compression
 
-            self.net_wrapper.train(self.amc_cfg.num_ft_epochs, self.episode)
+            accuracies = self.net_wrapper.train(self.amc_cfg.num_ft_epochs, self.episode)
+            self.ft_stats_file.add_record([self.episode, accuracies])
+
             top1, top5, vloss = self.net_wrapper.validate()
         reward = self.amc_cfg.reward_fn(self, top1, top5, vloss, total_macs)
@@ -737,16 +773,8 @@ class DistillerWrapperEnvironment(gym.Env):
                                         log_freq=1, loggers=[self.tflogger, self.pylogger])
         return reward, top1, total_macs, total_nnz
 
-    def create_network_record_file(self):
-        """Create the CSV file and write the column names"""
-        fields = ['episode', 'top1', 'reward', 'total_macs', 'normalized_macs',
-                  'normalized_nnz', 'ckpt_name', 'action_history', 'agent_action_history']
-        with open(os.path.join(msglogger.logdir, 'amc.csv'), 'w') as f:
-            writer = csv.writer(f)
-            writer.writerow(fields)
-
-    def record_network_details(self, top1, reward, total_macs, normalized_macs,
-                               normalized_nnz, action_history, agent_action_history):
+    def finalize_episode(self, top1, reward, total_macs, normalized_macs,
+                         normalized_nnz, action_history, agent_action_history):
         """Write the details of one network to a CSV file and create a checkpoint file"""
         if reward > self.best_reward:
             self.best_reward = reward
@@ -757,9 +785,7 @@ class DistillerWrapperEnvironment(gym.Env):
         fields = [self.episode, top1, reward, total_macs, normalized_macs,
                   normalized_nnz, ckpt_name, action_history, agent_action_history]
-        with open(os.path.join(msglogger.logdir, 'amc.csv'), 'a') as f:
-            writer = csv.writer(f)
-            writer.writerow(fields)
+        self.stats_file.add_record(fields)
 
     def save_checkpoint(self, is_best=False):
         """Save the learned-model checkpoint"""
@@ -856,7 +882,7 @@ def sample_networks(net_wrapper, services):
             sparsity_level = min(max(0, sparsity_level), ALMOST_ONE)
             net_wrapper.remove_structures(layer_id,
                                           fraction_to_prune=sparsity_level,
-                                          prune_what="filters")
+                                          prune_what="channels")
 
             net_wrapper.train(1)
             top1, top5, vloss = net_wrapper.validate()
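This one-word change alters which structures the random network sampler removes: "filters" prunes whole output filters of a layer's weight tensor, while "channels" prunes input channels, the structures fed by the preceding layer's filters.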
@@ -2,7 +2,7 @@ lr_schedulers:
   training_lr:
     class: StepLR  # ReduceLROnPlateau
     step_size: 10
-    gamma: 0.3
+    gamma: 0.2
     #mode: max
     #patience: 5
     #factor: 0.1
@@ -149,13 +149,13 @@
 lr_schedulers:
   training_lr:
-    class: StepLR
-    step_size: 45
-    gamma: 0.10
+    class: MultiStepLR
+    milestones: [60, 120, 160]
+    gamma: 0.20
 
 policies:
   - lr_scheduler:
       instance_name: training_lr
-    starting_epoch: 45
+    starting_epoch: 0
     ending_epoch: 200
     frequency: 1
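The second schedule change replaces a fixed-period decay that only began at epoch 45 with milestone-based decay active from epoch 0. A rough plain-PyTorch equivalent of the new YAML, where optimizer and train_one_epoch() are assumed stand-ins for objects built elsewhere:

    from torch.optim.lr_scheduler import MultiStepLR

    scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    for epoch in range(200):
        train_one_epoch()    # hypothetical per-epoch training call
        scheduler.step()     # lr is multiplied by 0.2 at epochs 60, 120 and 160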
(Source diff for one file could not be displayed: it is too large.)
@@ -257,6 +257,18 @@ def main():
 
         # Train for one epoch
         with collectors_context(activations_collectors["train"]) as collectors:
+            # if epoch > 15:
+            #     for name, module in model.named_modules():
+            #         if isinstance(module, nn.Conv2d):
+            #             module.p_mask = max(0.6, module.p_mask - 0.005)
+            #             #module.p_mask = max(0.5, module.p_mask - 0.02)
+            #             msglogger.info("setting filter drop probability to %.2f", module.p_mask)
+
             train(train_loader, model, criterion, optimizer, epoch, compression_scheduler,
                   loggers=[tflogger, pylogger], args=args)
             distiller.log_weights_sparsity(model, epoch, loggers=[tflogger, pylogger])
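If this disabled block were enabled, p_mask would decay by 0.005 per epoch once epoch > 15, so from its initial 1.0 it would reach the 0.6 floor after (1.0 - 0.6) / 0.005 = 80 epochs, around epoch 96; the inner comment sketches a steeper alternative (0.02 per epoch toward a 0.5 floor). Until something lowers p_mask, Conv2dWithMask samples an all-ones mask and drops nothing.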
@@ -327,8 +339,8 @@ def train(train_loader, model, criterion, optimizer, epoch,
 
     # Switch to train mode
     model.train()
+    acc_stats = []
     end = time.time()
 
     for train_step, (inputs, target) in enumerate(train_loader):
         # Measure data loading time
         data_time.add(time.time() - end)
@@ -345,12 +357,13 @@ def train(train_loader, model, criterion, optimizer, epoch,
         if not args.earlyexit_lossweights:
             loss = criterion(output, target)
-            # Measure accuracy and record loss
+            # Measure accuracy
             classerr.add(output.data, target)
+            acc_stats.append([classerr.value(1), classerr.value(5)])
         else:
             # Measure accuracy and record loss
             loss = earlyexit_loss(output, target, criterion, args)
 
+        # Record loss
         losses[OBJECTIVE_LOSS_KEY].add(loss.item())
 
         if compression_scheduler:
@@ -406,6 +419,7 @@ def train(train_loader, model, criterion, optimizer, epoch,
                                       steps_per_epoch, args.print_freq,
                                       loggers)
         end = time.time()
+    return acc_stats
 
 
 def validate(val_loader, model, criterion, loggers, args, epoch=-1):