ReLU Analysis
In this notebook we take two off-the-shelf neural networks (VGG16 and ResNet-18) and analyse the role the ReLU plays in them. VGG16 is famous for being a simple yet effective neural network. It consists of stacks of \(3\times 3\) convolutional layers, each stack followed by a max-pooling downsampling step. Since then, most networks have used mainly \(3\times 3\) convolutions (and even \(1 \times 1\) convolutions).
Most of my work, which uses the \(\mathrm{DT}\mathbb{C}\mathrm{WT}\) as a building block, involves filters with large spatial support. Since the current trend is in the opposite direction, we need to determine whether \(3\times 3\) convolutions are used simply because they are easier to learn, or because it is genuinely better to have small convolutions interspersed with nonlinearities.
What does the VGG paper itself say about this choice of layer size? Its discussion section addresses the issue as follows:
Rather than using relatively large receptive fields in the first conv layers (e.g. \(11\times 11\) with stride 4 in Krizhevsky et al., 2012, or \(7\times 7\) with stride 2 in Zeiler & Fergus, 2013 or Sermanet et al., 2014), we use very small \(3 \times 3\) receptive fields throughout the whole net, which are convolved with the input at every pixel (with stride 1)…
So what have we gained by using, for instance, a stack of three \(3\times 3\) conv. layers instead of a single \(7\times 7\) layer? First, we incorporate three non-linear rectification layers instead of a single one, which makes the decision function more discriminative. Second, we decrease the number of parameters: assuming that both the input and the output of a three-layer \(3\times 3\) convolution stack has \(C\) channels, the stack is parametrised by \(3(3^2 C^2) = 27C^2\) weights; at the same time, a single \(7 \times 7\) conv. layer would require \(7^2C^2 = 49C^2\) parameters, i.e. \(81\%\) more. This can be seen as imposing a regularisation on the \(7 \times 7\) conv. filters, forcing them to have a decomposition through the \(3\times 3\) filters (with non-linearity injected in between).
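The arithmetic in the quote can be checked directly in PyTorch. This is an illustrative sketch, not code from the VGG paper: it assumes \(C = 64\) and bias-free layers so that only the \(C^2\) weight terms are counted, and it also confirms that three stacked \(3\times 3\) convolutions see the same \(7\times 7\) input region as a single \(7\times 7\) layer.

```python
import torch
import torch.nn as nn

def conv_weight_count(module):
    """Count Conv2d weights only (no biases), matching the paper's C^2 bookkeeping."""
    return sum(m.weight.numel() for m in module.modules() if isinstance(m, nn.Conv2d))

C = 64  # assumed channel count; the 27:49 ratio holds for any C

# Three 3x3 layers, each followed by a ReLU
layers = []
for _ in range(3):
    layers += [nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False), nn.ReLU()]
stack = nn.Sequential(*layers)
single = nn.Conv2d(C, C, kernel_size=7, padding=3, bias=False)

n_stack = conv_weight_count(stack)    # 3 * 3**2 * C**2 = 27 C^2
n_single = conv_weight_count(single)  # 7**2 * C**2 = 49 C^2
extra = n_single / n_stack - 1        # about 0.81, i.e. 81% more parameters

# Receptive field: send a single-pixel impulse through three 1-channel, all-ones
# 3x3 convolutions and measure the side length of the nonzero output region.
probe = nn.Sequential(*[nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
                        for _ in range(3)])
with torch.no_grad():
    for conv in probe:
        conv.weight.fill_(1.0)
    impulse = torch.zeros(1, 1, 15, 15)
    impulse[0, 0, 7, 7] = 1.0
    response = probe(impulse)[0, 0]
rf_side = (response != 0).any(dim=1).sum().item()  # 7, same as one 7x7 layer
```

With the ReLUs in place, the three-layer stack is no longer equivalent to any single linear \(7\times 7\) filter, which is the first of the two gains the authors list.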
Code Imports
In [1]:
%matplotlib notebook
# Import plotting functions
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import plotters
# Import data crunching libraries
import numpy as np
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torchvision.models as models
from torchvision.models import vgg, resnet
# py3nvml.grab_gpus sets the CUDA_VISIBLE_DEVICES environment variable so that
# only one free GPU is visible (useful on multi-GPU machines). Available on PyPI.
import py3nvml
py3nvml.grab_gpus(1)
# Import data loading functions
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# misc others
import os
import time
from collections import OrderedDict
import types
# Base directory of my ImageNet copy. Only the validation set is needed for
# this experiment; any folder of images in torchvision's ImageFolder layout
# (one subfolder per class) will do.
basedir = '/scratch/share/ImageNet2017/Data/CLS-LOC/val'
Create a PyTorch data loader using torchvision
In [2]:
batch_size = 16
num_workers = 1
# Standard ImageNet preprocessing: resize, center crop, then normalize with the
# per-channel ImageNet mean and std that torchvision's pretrained models expect
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
val_dataset = datasets.ImageFolder(
    basedir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False,
    num_workers=num_workers, pin_memory=True)
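VGG
Load VGG16 using torchvision
We need to slightly modify torchvision's vgg16 model for this experiment: we give every layer a readable name and replace the in-place ReLUs in the feature extractor with out-of-place ones, so that forward hooks can see the ReLU inputs before rectification.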
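Why the in-place ReLUs have to go: forward hooks run after the module's forward, so with nn.ReLU(inplace=True) the input tensor has already been overwritten by the time the hook sees it, and no negative values would ever be recorded. A minimal sketch on a toy tensor (record_min_input is a hypothetical helper):

```python
import torch
import torch.nn as nn

seen = {}

def record_min_input(module, inp, out):
    # Forward hooks run *after* forward(), so inp[0] reflects any in-place change
    seen['min'] = inp[0].min().item()

x = torch.tensor([-1.0, 2.0, -3.0])

relu_inplace = nn.ReLU(inplace=True)
handle = relu_inplace.register_forward_hook(record_min_input)
relu_inplace(x.clone())          # the clone is overwritten with max(x, 0)
handle.remove()
min_seen_inplace = seen['min']   # 0.0: the negative inputs are already gone

relu = nn.ReLU()
handle = relu.register_forward_hook(record_min_input)
relu(x.clone())
handle.remove()
min_seen_plain = seen['min']     # -3.0: the hook sees the true pre-activation
```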
In [3]:
vgg16_base = vgg.vgg16(pretrained=True)

def prettify_vgg16():
    # Rebuild vgg16 as a nested nn.Sequential whose layers have readable names.
    # The conv and linear modules are shared with vgg16_base (not copied).
    class Vectorize(nn.Module):
        # Flatten (N, C, H, W) feature maps into (N, C*H*W) vectors
        def forward(self, x):
            return x.view(x.size(0), -1)

    # Set the feature and classifier names
    fnames = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'max1',
              'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'max2',
              'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'max3',
              'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'max4',
              'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'max5']
    cnames = ['fc1', 'relu1', 'drop1', 'fc2', 'relu2', 'drop2', 'fc3']
    # Create (name, module) pairs for the feature extractor and classifier.
    # We don't want the default in-place ReLUs in the features, so replace them.
    feature_tups = [(fnames[i], nn.ReLU() if isinstance(f, nn.ReLU) else f)
                    for i, f in enumerate(vgg16_base.features)]
    classifier_tups = [('reshape', Vectorize())] + \
        [(cnames[i], c) for i, c in enumerate(vgg16_base.classifier)]
    # Combine into a sequential net
    vgg16 = nn.Sequential(OrderedDict([
        ('features', nn.Sequential(OrderedDict(feature_tups))),
        ('classifier', nn.Sequential(OrderedDict(classifier_tups)))
    ]))
    return vgg16

# Print the new vgg with names
vgg16 = prettify_vgg16()
vgg16.cuda()
vgg16
Out[3]:
Sequential(
(features): Sequential(
(conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1_1): ReLU()
(conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1_2): ReLU()
(max1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu2_1): ReLU()
(conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu2_2): ReLU()
(max2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3_1): ReLU()
(conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3_2): ReLU()
(conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3_3): ReLU()
(max3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv4_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu4_1): ReLU()
(conv4_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu4_2): ReLU()
(conv4_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu4_3): ReLU()
(max4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv5_1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu5_1): ReLU()
(conv5_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu5_2): ReLU()
(conv5_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu5_3): ReLU()
(max5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(reshape): Vectorize()
(fc1): Linear(in_features=25088, out_features=4096, bias=True)
(relu1): ReLU(inplace)
(drop1): Dropout(p=0.5)
(fc2): Linear(in_features=4096, out_features=4096, bias=True)
(relu2): ReLU(inplace)
(drop2): Dropout(p=0.5)
(fc3): Linear(in_features=4096, out_features=1000, bias=True)
)
)
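The printout shows that the classifier ReLUs are still ReLU(inplace), because prettify_vgg16 only replaces the ones in the feature extractor; this is harmless here since hooks are attached only to the features. If named layers are not needed, a simpler alternative is to flip the inplace flag on every nn.ReLU of an existing model. A minimal sketch, using a hypothetical helper name disable_inplace_relus; note that it does not help when one ReLU module is reused several times in a single forward pass, as in torchvision's ResNet blocks.

```python
import torch.nn as nn

def disable_inplace_relus(model):
    """Make every nn.ReLU in `model` out-of-place; return how many were changed."""
    changed = 0
    for m in model.modules():
        if isinstance(m, nn.ReLU) and m.inplace:
            m.inplace = False  # ReLU.forward reads this flag on every call
            changed += 1
    return changed

toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(inplace=True),
                    nn.Linear(8, 4), nn.ReLU(inplace=True),
                    nn.Linear(4, 2), nn.ReLU())
n_changed = disable_inplace_relus(toy)  # 2: the last ReLU was already out-of-place
```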
Validate on the validation set
This step is optional, but it is nice to confirm that we have loaded the correct network. It takes some time. The validate function defined below is needed later, but you can skip the cell after it, which runs the whole validation set through the network.
In [4]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    # After the transpose, pred has shape (maxk, batch): row j is each sample's j-th best class
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # correct[:k] is not contiguous after the transpose, so use reshape, not view
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def validate(val_loader, model, criterion, print_freq=500, max_iters=-1):
    batches = max_iters if max_iters >= 0 else len(val_loader)
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode (no dropout) and don't build the autograd graph
    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            # compute output
            output = model(input)
            loss = criterion(output, target)
            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))
            if i % print_freq == 0:
                l = len(str(len(val_loader)))
                print('Test: [{0:0{2}}/{1}]\t'
                      'Time {time:.1f} min\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, batches, l, time=(time.time()-end)/60,
                          loss=losses, top1=top1, top5=top5))
                end = time.time()
            # Stops after batch index max_iters, i.e. max_iters + 1 batches in total
            if i == max_iters:
                break
    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    return top1.avg
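To see what accuracy() computes, here is an equivalent, self-contained sketch (topk_accuracy is a hypothetical name) that keeps the predictions row-major instead of transposing them, run on a toy batch of four samples and three classes:

```python
import torch

def topk_accuracy(output, target, topk=(1,)):
    """Percentage of samples whose target is among the k highest-scoring classes."""
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)  # (batch, maxk)
    correct = pred.eq(target.view(-1, 1))  # compare target with each top-k column
    return [correct[:, :k].any(dim=1).float().mean().item() * 100.0 for k in topk]

# 4 samples, 3 classes
logits = torch.tensor([[0.9, 0.05, 0.05],
                       [0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2],
                       [0.2, 0.3, 0.5]])
target = torch.tensor([0, 2, 1, 0])
top1, top2 = topk_accuracy(logits, target, topk=(1, 2))
```

Samples 1 and 2 have their target as the second-best class, so they count towards top-2 but not top-1, giving 25% top-1 and 75% top-2.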
In [5]:
validate(val_loader, vgg16, criterion = torch.nn.CrossEntropyLoss().cuda())
Test: [0000/3125] Time 0.0 min Loss 1.1812 (1.1812) Prec@1 87.500 (87.500) Prec@5 87.500 (87.500)
Test: [0500/3125] Time 1.3 min Loss 0.1959 (0.7829) Prec@1 93.750 (80.140) Prec@5 100.000 (94.237)
Test: [1000/3125] Time 1.3 min Loss 0.2300 (0.8404) Prec@1 93.750 (77.785) Prec@5 100.000 (94.524)
Test: [1500/3125] Time 1.3 min Loss 1.1483 (0.9238) Prec@1 68.750 (76.428) Prec@5 93.750 (93.359)
Test: [2000/3125] Time 1.2 min Loss 0.4954 (1.0410) Prec@1 81.250 (73.769) Prec@5 100.000 (91.826)
Test: [2500/3125] Time 1.2 min Loss 0.2583 (1.1114) Prec@1 93.750 (72.451) Prec@5 100.000 (90.744)
Test: [3000/3125] Time 1.2 min Loss 0.9678 (1.1425) Prec@1 81.250 (71.693) Prec@5 93.750 (90.409)
* Prec@1 71.592 Prec@5 90.382
Out[5]:
71.592
Look at the distribution of activations as they enter the ReLUs
For each ReLU, let us look at the distribution of its input activations. We do this by registering a forward hook on each ReLU: as data passes through the network, the hook is called with the ReLU's input and output, and it bins the input values into a histogram.
In [6]:
def get_counts(self, input, output, relu_num=0):
    """ Hook that adds a histogram of the ReLU's input to the global counts """
    # input is a tuple of the module's positional inputs and output is a Tensor.
    # Forward hooks are called after the module's forward, so with an in-place
    # ReLU input[0] would already have been overwritten by the output.
    global bins, counts
    x = input[0].detach().cpu().numpy()
    # If the bins haven't been set already (all zeros), let numpy choose them
    if np.abs(bins[relu_num]).sum() == 0:
        counts[relu_num], bins[relu_num] = np.histogram(x, bins=N_bins)
    else:
        c, _ = np.histogram(x, bins=bins[relu_num])
        counts[relu_num] += c
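get_counts relies on the fact that, for fixed bin edges, the histograms of successive batches can simply be summed. The sketch below checks this on synthetic Gaussian data (an assumption standing in for real activations) and also shows that values outside the edges are silently dropped, which is why the hand-picked ranges matter. It uses its own variable names so as not to clobber the notebook's global bins and counts.

```python
import numpy as np

rng = np.random.default_rng(0)
edges = np.linspace(-3, 3, 61)  # fixed, hand-picked bin edges (assumed range)
acc = np.zeros(len(edges) - 1, dtype=np.int64)

# Synthetic stand-in for the pre-activations of five batches
chunks = [rng.normal(scale=1.5, size=1000) for _ in range(5)]
for chunk in chunks:
    c, _ = np.histogram(chunk, bins=edges)
    acc += c

# One histogram over all the data gives the same counts
everything = np.concatenate(chunks)
merged, _ = np.histogram(everything, bins=edges)

# Values below edges[0] or above edges[-1] are not counted anywhere
n_dropped = everything.size - acc.sum()
```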
Create arrays to store the histogram counts for all 13 ReLUs in the VGG16 feature extractor, and register a forward hook on each of them.
In [7]:
# Create a new network with hooks attached to it
vgg16_hooks = prettify_vgg16()
vgg16_hooks.cuda()
# Create np arrays to store the binned activations
N_relus = 13
N_bins = 200
counts = np.zeros((N_relus, N_bins), dtype=np.int64)
# Set the histogram bin edges. Choosing good bins is not easy; the ranges
# below were picked by experience. Values outside them are not counted.
bins = np.stack((np.linspace(-15, 15, N_bins+1),
                 np.linspace(-15, 15, N_bins+1),
                 np.linspace(-30, 30, N_bins+1),
                 np.linspace(-30, 30, N_bins+1),
                 np.linspace(-50, 50, N_bins+1),
                 np.linspace(-50, 50, N_bins+1),
                 np.linspace(-50, 50, N_bins+1),
                 np.linspace(-70, 40, N_bins+1),
                 np.linspace(-70, 40, N_bins+1),
                 np.linspace(-70, 40, N_bins+1),
                 np.linspace(-50, 50, N_bins+1),
                 np.linspace(-50, 50, N_bins+1),
                 np.linspace(-50, 50, N_bins+1),
                 ), axis=0)
# Add the hooks to the layers
hooks = [None,] * N_relus
# The hooks are written out one by one rather than as a loop of lambdas:
# Python closures bind late, so each lambda would read the loop index only when
# called, after the loop had finished, and every hook would get the last index.
hooks[0] = vgg16_hooks.features.relu1_1.register_forward_hook(lambda x,y,z: get_counts(x,y,z,0))
hooks[1] = vgg16_hooks.features.relu1_2.register_forward_hook(lambda x,y,z: get_counts(x,y,z,1))
hooks[2] = vgg16_hooks.features.relu2_1.register_forward_hook(lambda x,y,z: get_counts(x,y,z,2))
hooks[3] = vgg16_hooks.features.relu2_2.register_forward_hook(lambda x,y,z: get_counts(x,y,z,3))
hooks[4] = vgg16_hooks.features.relu3_1.register_forward_hook(lambda x,y,z: get_counts(x,y,z,4))
hooks[5] = vgg16_hooks.features.relu3_2.register_forward_hook(lambda x,y,z: get_counts(x,y,z,5))
hooks[6] = vgg16_hooks.features.relu3_3.register_forward_hook(lambda x,y,z: get_counts(x,y,z,6))
hooks[7] = vgg16_hooks.features.relu4_1.register_forward_hook(lambda x,y,z: get_counts(x,y,z,7))
hooks[8] = vgg16_hooks.features.relu4_2.register_forward_hook(lambda x,y,z: get_counts(x,y,z,8))
hooks[9] = vgg16_hooks.features.relu4_3.register_forward_hook(lambda x,y,z: get_counts(x,y,z,9))
hooks[10] = vgg16_hooks.features.relu5_1.register_forward_hook(lambda x,y,z: get_counts(x,y,z,10))
hooks[11] = vgg16_hooks.features.relu5_2.register_forward_hook(lambda x,y,z: get_counts(x,y,z,11))
hooks[12] = vgg16_hooks.features.relu5_3.register_forward_hook(lambda x,y,z: get_counts(x,y,z,12))
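The late-binding problem does not actually rule out a loop: binding the index eagerly, with functools.partial or a default argument such as `lambda m, x, y, n=i: ...`, fixes it. A toy demonstration, where record is a hypothetical stand-in for get_counts:

```python
from functools import partial

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(),
                      nn.Linear(4, 4), nn.ReLU(),
                      nn.Linear(4, 4), nn.ReLU())
relus = [m for m in model if isinstance(m, nn.ReLU)]

def record(module, inp, out, relu_num, log):
    # Hypothetical stand-in for get_counts: just log which index it was given
    log.append(relu_num)

# Late binding: each lambda looks up i when it is called, after the loop is done
late_log = []
handles = [r.register_forward_hook(lambda m, x, y: record(m, x, y, i, late_log))
           for i, r in enumerate(relus)]
with torch.no_grad():
    model(torch.randn(2, 4))
for h in handles:
    h.remove()

# partial captures the current value of i immediately
bound_log = []
handles = [r.register_forward_hook(partial(record, relu_num=i, log=bound_log))
           for i, r in enumerate(relus)]
with torch.no_grad():
    model(torch.randn(2, 4))
for h in handles:
    h.remove()
```

Applied to this notebook, the VGG16 hooks could be registered as `[getattr(vgg16_hooks.features, n).register_forward_hook(partial(get_counts, relu_num=k)) for k, n in enumerate(relu_names)]`, where relu_names is a hypothetical list of the 13 ReLU layer names.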
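Histogramming every activation is costly for more than a few thousand images, so we only run a few batches of the validation set through the network (21 batches of 16 images, since validate stops after batch index max_iters=20) to collect the histogram counts.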
In [8]:
validate(val_loader, vgg16_hooks, torch.nn.CrossEntropyLoss().cuda(),
         print_freq=5, max_iters=20)
vgg_counts = counts
np.save('vgg_relu_bincounts', vgg_counts)
Test: [0000/20] Time 0.3 min Loss 1.1812 (1.1812) Prec@1 87.500 (87.500) Prec@5 87.500 (87.500)
Test: [0005/20] Time 1.4 min Loss 0.4964 (0.5429) Prec@1 93.750 (91.667) Prec@5 93.750 (94.792)
Test: [0010/20] Time 1.4 min Loss 1.0755 (0.5784) Prec@1 75.000 (88.068) Prec@5 87.500 (95.455)
Test: [0015/20] Time 1.4 min Loss 0.5989 (0.5876) Prec@1 87.500 (86.719) Prec@5 100.000 (96.484)
Test: [0020/20] Time 1.4 min Loss 0.9810 (0.6854) Prec@1 68.750 (83.929) Prec@5 100.000 (95.833)
* Prec@1 83.929 Prec@5 95.833
Plot the resulting histograms (normalized to unit area), with the percentage of inputs below zero in each title.
In [9]:
names = ['conv1_1', 'conv1_2',
         'conv2_1', 'conv2_2',
         'conv3_1', 'conv3_2', 'conv3_3',
         'conv4_1', 'conv4_2', 'conv4_3',
         'conv5_1', 'conv5_2', 'conv5_3']
fig = plt.figure(figsize=(10,14))
gs = gridspec.GridSpec(5, 3)
gs.update(left=0.05, right=0.95, wspace=0.2, hspace=0.22, top=0.98, bottom=0.1)
for i, n in enumerate(names):
    # Place conv<block>_<layer> at grid position (block-1, layer-1)
    idx = n.split('conv')[1]
    row, col = idx.split('_')
    row, col = int(row)-1, int(col)-1
    ax = plt.subplot(gs[row,col])
    # x holds the left edge of each bin
    x = bins[i,:-1]
    y = counts[i,:].astype('float64')
    # Normalize to unit area
    bw = x[1] - x[0]
    y = y/(y.sum() * bw)
    ax.bar(x, y, width=bw, align='edge')
    count = np.sum(y)
    lesszero = np.sum(y[x<0])
    ax.set_title('{} - {:.1f}% below zero'.format(n, 100*lesszero/count), fontsize=10)
This is very interesting: the ReLU zeroes a large fraction of its inputs at every layer, and the fraction grows in the deeper layers, where it rejects as much as 90% of the activations.
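The percentages in the titles are read off the histograms, so values outside the hand-picked bin range are ignored, and because x holds left bin edges, a bin that straddles zero (as happens with the asymmetric -70 to 40 range) is counted entirely as below zero. A sketch of a hook that counts negative inputs exactly instead; NegativeFraction is a hypothetical name, and the ReLU it is attached to must not be in-place:

```python
import torch
import torch.nn as nn

class NegativeFraction:
    """Forward hook that counts how many inputs to a module are strictly below zero."""
    def __init__(self):
        self.negative = 0
        self.total = 0

    def __call__(self, module, inp, out):
        x = inp[0].detach()
        self.negative += (x < 0).sum().item()
        self.total += x.numel()

    @property
    def fraction(self):
        return self.negative / max(self.total, 1)

relu = nn.ReLU()  # out-of-place, otherwise inp[0] is already rectified
meter = NegativeFraction()
handle = relu.register_forward_hook(meter)
relu(torch.tensor([-1.0, -2.0, 3.0, 4.0]))  # 2 of 4 negative
relu(torch.tensor([-0.5, 0.0, 1.0, 2.0]))   # 1 of 4 negative (0.0 is not below zero)
handle.remove()
```

Accumulating counts across calls gives the fraction over all batches (here 3 of 8 inputs), not an average of per-batch fractions.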
ResNet
In [10]:
def prettify_resnet():
    # ResNet already has nice layer names, but we need to replace the in-place
    # ReLUs again. Annoyingly, torchvision's BasicBlock applies the same ReLU
    # module twice, so we also override each block's forward to use a second one.
    resnet18 = resnet.resnet18(pretrained=True)

    def forward(self, x):
        # Same computation as BasicBlock.forward, but with a separate relu2
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu2(out)
        return out

    # Change the relus and the forward functions. Only the stem and layer1-3
    # are modified; layer4 keeps its original shared, in-place ReLUs and is
    # not analysed.
    n_relus = 1
    resnet18.relu = nn.ReLU()
    for l in [resnet18.layer1, resnet18.layer2, resnet18.layer3]:
        for blk in l:
            n_relus += 2
            blk.relu = nn.ReLU()
            blk.relu2 = nn.ReLU()
            blk.forward = types.MethodType(forward, blk)
    print('There are {} relus'.format(n_relus))
    return resnet18
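Why the separate relu2 is needed: a hook registered on a module fires every time that module is called, so a hook on a ReLU that a block applies twice would mix pre-activations from two different points of the block into one histogram. A toy illustration; TinyBlock is hypothetical and only mimics the reuse pattern of torchvision's BasicBlock:

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Hypothetical residual block that applies one ReLU module twice."""
    def __init__(self, c=4):
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, 3, padding=1)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.conv2(out) + x
        return self.relu(out)  # same module object, so its hooks fire again

calls = []
blk = TinyBlock()
handle = blk.relu.register_forward_hook(lambda m, inp, out: calls.append(inp[0].shape))
with torch.no_grad():
    blk(torch.randn(1, 4, 8, 8))
handle.remove()
```

After one forward pass, calls holds two entries, one for each place the shared ReLU is applied.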
In [11]:
res18 = prettify_resnet()
validate(val_loader, res18.cuda(), criterion = torch.nn.CrossEntropyLoss().cuda(),
max_iters=500, print_freq=100)
There are 13 relus
Test: [0000/500] Time 0.0 min Loss 1.2258 (1.2258) Prec@1 81.250 (81.250) Prec@5 93.750 (93.750)
Test: [0100/500] Time 0.3 min Loss 2.3805 (0.6449) Prec@1 37.500 (84.035) Prec@5 81.250 (95.111)
Test: [0200/500] Time 0.3 min Loss 1.9602 (0.9610) Prec@1 43.750 (75.715) Prec@5 81.250 (92.289)
Test: [0300/500] Time 0.2 min Loss 0.1376 (0.9172) Prec@1 87.500 (76.495) Prec@5 100.000 (92.650)
Test: [0400/500] Time 0.3 min Loss 0.4400 (0.9384) Prec@1 87.500 (76.278) Prec@5 100.000 (92.269)
Test: [0500/500] Time 0.2 min Loss 0.4888 (0.8715) Prec@1 81.250 (77.807) Prec@5 93.750 (92.889)
* Prec@1 77.807 Prec@5 92.889
Out[11]:
77.80688622754491
In [12]:
# Create a new network with hooks attached to it
res18_hooks = prettify_resnet()
res18_hooks.cuda()
# Create np arrays to store the binned activations
N_relus = 13
N_bins = 200
counts = np.zeros((N_relus, N_bins), dtype=np.int64)
# Set the histogram bin edges. Choosing good bins is not easy; this range was
# picked by experience and is used for every ReLU. Values outside it are not counted.
bins = np.repeat(np.expand_dims(np.linspace(-3, 3, N_bins+1), axis=0), N_relus, axis=0)
# Add the hooks to the layers
hooks = [None,] * N_relus
# Written out one by one, as for VGG16: a loop of lambdas would bind the loop
# index late, so every hook would get the last index.
hooks[0] = res18_hooks.relu.register_forward_hook(lambda x,y,z: get_counts(x,y,z,0))
hooks[1] = res18_hooks.layer1[0].relu.register_forward_hook(lambda x,y,z: get_counts(x,y,z,1))
hooks[2] = res18_hooks.layer1[0].relu2.register_forward_hook(lambda x,y,z: get_counts(x,y,z,2))
hooks[3] = res18_hooks.layer1[1].relu.register_forward_hook(lambda x,y,z: get_counts(x,y,z,3))
hooks[4] = res18_hooks.layer1[1].relu2.register_forward_hook(lambda x,y,z: get_counts(x,y,z,4))
hooks[5] = res18_hooks.layer2[0].relu.register_forward_hook(lambda x,y,z: get_counts(x,y,z,5))
hooks[6] = res18_hooks.layer2[0].relu2.register_forward_hook(lambda x,y,z: get_counts(x,y,z,6))
hooks[7] = res18_hooks.layer2[1].relu.register_forward_hook(lambda x,y,z: get_counts(x,y,z,7))
hooks[8] = res18_hooks.layer2[1].relu2.register_forward_hook(lambda x,y,z: get_counts(x,y,z,8))
hooks[9] = res18_hooks.layer3[0].relu.register_forward_hook(lambda x,y,z: get_counts(x,y,z,9))
hooks[10] = res18_hooks.layer3[0].relu2.register_forward_hook(lambda x,y,z: get_counts(x,y,z,10))
hooks[11] = res18_hooks.layer3[1].relu.register_forward_hook(lambda x,y,z: get_counts(x,y,z,11))
hooks[12] = res18_hooks.layer3[1].relu2.register_forward_hook(lambda x,y,z: get_counts(x,y,z,12))
There are 13 relus
In [13]:
validate(val_loader, res18_hooks, torch.nn.CrossEntropyLoss().cuda(),
         print_freq=5, max_iters=20)
resnet_counts = counts
np.save('resnet_relu_bincounts', resnet_counts)
Test: [0000/20] Time 0.1 min Loss 1.2258 (1.2258) Prec@1 81.250 (81.250) Prec@5 93.750 (93.750)
Test: [0005/20] Time 0.2 min Loss 0.3355 (0.5991) Prec@1 81.250 (84.375) Prec@5 100.000 (95.833)
Test: [0010/20] Time 0.2 min Loss 1.1851 (0.6938) Prec@1 68.750 (80.682) Prec@5 87.500 (95.455)
Test: [0015/20] Time 0.2 min Loss 0.4824 (0.6744) Prec@1 93.750 (80.078) Prec@5 100.000 (96.094)
Test: [0020/20] Time 0.2 min Loss 0.8000 (0.7893) Prec@1 68.750 (77.679) Prec@5 100.000 (94.643)
* Prec@1 77.679 Prec@5 94.643
In [14]:
names = ['relu1',
         'layer1_0_r1', 'layer1_0_r2',
         'layer1_1_r1', 'layer1_1_r2',
         'layer2_0_r1', 'layer2_0_r2',
         'layer2_1_r1', 'layer2_1_r2',
         'layer3_0_r1', 'layer3_0_r2',
         'layer3_1_r1', 'layer3_1_r2']
fig = plt.figure(figsize=(10,11))
gs = gridspec.GridSpec(4, 4)
gs.update(left=0.05, right=0.95, wspace=0.2, hspace=0.22, top=0.98, bottom=0.1)
for i, n in enumerate(names):
    # The stem ReLU goes in the top-left cell; the 12 block ReLUs fill rows 1-3
    if i == 0:
        row, col = 0, 0
    else:
        row, col = divmod(i - 1, 4)
        row += 1
    ax = plt.subplot(gs[row,col])
    # x holds the left edge of each bin
    x = bins[i,:-1]
    y = counts[i,:].astype('float64')
    # Normalize to unit area
    bw = x[1] - x[0]
    y = y/(y.sum() * bw)
    ax.bar(x, y, width=bw, align='edge')
    count = np.sum(y)
    lesszero = np.sum(y[x<0])
    ax.set_title('{} - {:.1f}% below zero'.format(n, 100*lesszero/count), fontsize=10)
Again this is quite interesting, but it tells a different story from VGG16. For ResNet-18, the sparsifying effect grows much less as we go deeper (at least up to layer3, the deepest stage analysed here). The fraction of inputs below zero is still above 50%, but not as large as the 80-90% we saw with VGG16.