Convolutional Filter Factorization¶
In this notebook, I want to see if we can factorize a spatial filter with large support into successive convolutions of filters with smaller spatial support. The inspiration for this comes from the fact that I often have to work with large spatial kernels, and there may be some speedup available by using successive smaller convolutions.
For example, many 5x5 kernels can be written as two successive 3x3 convolutions, using 18 multiply/adds per pixel instead of 25. Of course, some cannot be represented exactly by two 3x3 convolutions, but perhaps we are happy to pay a small error term.
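As a quick check of the support claim, the full convolution of any two 3x3 kernels has 5x5 support, and applying them successively costs 9 + 9 = 18 multiply/adds per output pixel. A small numpy/scipy sketch (the random kernels are just placeholders):

import numpy as np
from scipy.signal import convolve2d

# Two arbitrary 3x3 kernels: their full convolution has (3+3-1)x(3+3-1) = 5x5 support
k1 = np.random.randn(3, 3)
k2 = np.random.randn(3, 3)
k_composite = convolve2d(k1, k2, mode='full')
print(k_composite.shape)                    # (5, 5)
print(k1.size + k2.size, k_composite.size)  # 18 multiply/adds vs 25 per output pixel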
Further, when viewing nvprof profiling outputs, it seems that 3x3 convolutions are handled by Winograd kernels rather than FFT/IFFT low-level kernels. Given the prevalence of 3x3 kernels in the machine learning literature, I wouldn't be surprised if GPU libraries are being specifically optimized for them.
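Whether two 3x3s are actually faster than one 5x5 in practice depends on the backend (Winograd, FFT, direct) and the tensor sizes, so it is worth timing on your own setup. A rough CPU-only sketch using the stdlib timeit (for GPU timings you would also need to synchronize the device, and results will vary):

import timeit
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 512, 512)
w5 = torch.randn(1, 1, 5, 5)
w3a = torch.randn(1, 1, 3, 3)
w3b = torch.randn(1, 1, 3, 3)

# Time one 5x5 convolution against two successive 3x3 convolutions
t5 = timeit.timeit(lambda: F.conv2d(x, w5, padding=2), number=200)
t33 = timeit.timeit(lambda: F.conv2d(F.conv2d(x, w3a, padding=1), w3b, padding=1),
                    number=200)
print('1 x 5x5: {:.3f}s    2 x 3x3: {:.3f}s'.format(t5, t33))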
In [12]:
# Import my plotting libraries. plotters is available at
# https://github.com/fbcotter/plotters
%matplotlib notebook
import matplotlib as mpl
mpl.rcParams['image.resample'] = False
mpl.rcParams['image.interpolation'] = 'none'
mpl.rcParams['image.cmap'] = 'gray'
import matplotlib.pyplot as plt
import plotters
# Import my numeric libraries. Pytorch wavelets is a library that allows
# the dtcwt kernels (the particular ones I want to speed up) to be done in
# pytorch (and optionally, with a GPU).
import numpy as np
import torch
import torch.nn.functional as F
import dtcwt
from pytorch_wavelets import DTCWTForward, DTCWTInverse
from pytorch_wavelets.dtcwt.lowlevel import q2c
dev = torch.device('cpu')
Get the target filters¶
This next bit of code is largely specific to my task; any filters would do here. For this example, I'm looking at the synthesis filters of a single-level inverse Dual-Tree Complex Wavelet Transform (DTCWT). They are 7x7 but have little energy in the outermost pixels.
In [2]:
ifm = DTCWTInverse(J=1, biort='near_sym_a', qshift='qshift_a')
def yb_to_yh(yb):
    # Converts the 3 LH, HH, HL bands into the 6 oriented complex bands
    deg15, deg165 = q2c(yb[:,0])
    deg45, deg135 = q2c(yb[:,1])
    deg75, deg105 = q2c(yb[:,2])
    yh = torch.stack((deg15, deg45, deg75, deg105, deg135, deg165), dim=2)
    return yh
# Create an output array to store the results
xb = np.zeros((4,16,16))
yl = torch.zeros(1,1,16,16)
yb = torch.zeros(1,3,1,16,16)
# Put impulses in the ll, lh, hh and hl bands
yl[0,0,8,8] = 1
x = ifm((yl, [yb_to_yh(yb)]))
xb[0] = x.detach().cpu().numpy()[0,0]
yl[0,0,8,8] = 0
for b in range(3):
    yb[0,b,0,8,8] = 1
    x = ifm((yl, [yb_to_yh(yb)]))
    xb[b+1] = x.detach().cpu().numpy()[0,0]
    yb[0,b] = 0
# Plot the resulting outputs
fig, ax = plt.subplots(1,4, figsize=(9,3),
                       subplot_kw={'xticks':[], 'yticks':[]},
                       gridspec_kw={'top':0.95, 'bottom':0.05, 'left':0.05,
                                    'right':0.95, 'wspace':0.05})
for i in range(4):
    ax[i].imshow(xb[i])
fig.suptitle('ll, lh, hh, hl impulse responses - our 7x7 target filters');
Aside 1: Convex vs Non Convex¶
This is a non-convex optimization problem. Consider the simple case where we want to represent a 1x1 kernel as a convolution of two 1x1 kernels (I know this is silly, but it makes the maths nicer). Then our output is
\[ f(x, y) = xy \]
To check for convexity, we need the Hessian matrix to be positive semi-definite. A pure cross term like this cannot be convex, but for simplicity's sake, let us write out the Hessian matrix anyway:
\[ H = \begin{bmatrix} \frac{\partial^2 f}{\partial x^2} & \frac{\partial^2 f}{\partial x \partial y} \\ \frac{\partial^2 f}{\partial y \partial x} & \frac{\partial^2 f}{\partial y^2} \end{bmatrix} = \begin{bmatrix} 0 & 1 \\ 1 & 0 \end{bmatrix} \]
Positive semi-definite (PSD) means that \(x^THx \geq 0\) for any \(x\). In the full filter problem our Hessian is larger than 2 by 2, but there is a nice trick in that every principal submatrix of a PSD matrix must also be PSD (you can set the corresponding parts of \(x\) to 0 to ignore the other areas). In the 2 by 2 case above, setting \(x=[-1,1]^T\) gives \(x^THx = -2 < 0\), so the Hessian is not PSD and the problem is not convex.
Another nice bit of info for PSD matrices is that if a matrix is diagonally dominant, i.e.
\[ |h_{i,i}| \geq \sum_{j \neq i} |h_{i,j}| \quad \text{for all } i, \]
and \(h_{i,i} \geq 0\), then it is PSD. This condition is however sufficient and not necessary. It is also why we sometimes add a small multiple of the identity matrix to a covariance matrix, to make sure it satisfies the PSD conditions.
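Both of these facts are easy to check numerically for our 2x2 Hessian (a small numpy sketch):

import numpy as np

# Hessian of f(x, y) = xy
H = np.array([[0., 1.],
              [1., 0.]])
print(np.linalg.eigvalsh(H))      # [-1.  1.] -> a negative eigenvalue, so not PSD

# The quadratic form at x = [-1, 1]^T is negative, as claimed above
v = np.array([-1., 1.])
print(v @ H @ v)                  # -2.0

# Adding a small multiple of the identity shifts every eigenvalue up,
# which is why the trick works for covariance matrices
print(np.linalg.eigvalsh(H + 1.5 * np.eye(2)))   # [0.5  2.5] -> PSD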
Aside 2: check how to do alternating projections¶
Non-convex optimization is not trivial. A simple way of tackling it is alternating minimization (which I'll loosely call alternating projections): minimize w.r.t. \(x\) keeping \(y\) constant, then minimize w.r.t. \(y\) keeping \(x\) constant. Let us return to the \(f=xy\) problem, and try to use pytorch to alternate the minimizations. This is an underdetermined set of equations, so let us also add a penalty on the norms of \(x\) and \(y\), i.e. minimize
\[ (xy - \alpha)^2 + \lambda \left( |x| + |y| \right) \]
As an example, let us try to find two numbers \(x\) and \(y\) that multiply to give 20, while minimizing their \(l_1\) norm.
In [3]:
# Randomly initialize x and y to be a gaussian with std 10
x = torch.tensor(10*np.random.randn(1), dtype=torch.float32, requires_grad=True)
y = torch.tensor(10*np.random.randn(1), dtype=torch.float32, requires_grad=True)
# Set the target (α), the Lagrange multiplier (λ) and the learning rate
α = 20
λ = 0.1
lr = 1e-2
N = 20 # Number of steps
# Create two optimizers, one for x and one for y
optimx = torch.optim.SGD([x], lr=lr, momentum=0)
optimy = torch.optim.SGD([y], lr=lr, momentum=0)
for i in range(N):
    # Minimize wrt x
    optimx.zero_grad()
    optimy.zero_grad()
    z = x*y
    l = (z - α)**2 + λ * (torch.abs(x) + torch.abs(y))
    print("Step {:2}a: Loss = {loss:09.4f}\tOut = {out:.2f} ({x:.2f}, {y:.2f})".format(
        i, loss=l.item(), out=z.item(), x=x.item(), y=y.item()))
    l.backward()
    optimx.step()
    # Minimize wrt y
    optimx.zero_grad()
    optimy.zero_grad()
    z = x*y
    l = (z - α)**2 + λ * (torch.abs(x) + torch.abs(y))
    print("Step {:2}b: Loss = {loss:09.4f}\tOut = {out:.2f} ({x:.2f}, {y:.2f})".format(
        i, loss=l.item(), out=z.item(), x=x.item(), y=y.item()))
    l.backward()
    optimy.step()
Step 0a: Loss = 0404.3896 Out = -0.10 (0.04, -2.68)
Step 0b: Loss = 0296.7857 Out = 2.78 (-1.04, -2.68)
Step 1a: Loss = 0284.1869 Out = 3.15 (-1.04, -3.03)
Step 1b: Loss = 0189.4539 Out = 6.25 (-2.06, -3.03)
Step 2a: Loss = 0158.8276 Out = 7.42 (-2.06, -3.60)
Step 2b: Loss = 0087.5617 Out = 10.68 (-2.97, -3.60)
Step 3a: Loss = 0059.7771 Out = 12.31 (-2.97, -4.15)
Step 3b: Loss = 0026.1703 Out = 14.96 (-3.60, -4.15)
Step 4a: Loss = 0014.7589 Out = 16.27 (-3.60, -4.51)
Step 4b: Loss = 0005.7598 Out = 17.78 (-3.94, -4.51)
Step 5a: Loss = 0003.2123 Out = 18.47 (-3.94, -4.69)
Step 5b: Loss = 0001.6230 Out = 19.14 (-4.08, -4.69)
Step 6a: Loss = 0001.2203 Out = 19.42 (-4.08, -4.76)
Step 6b: Loss = 0000.9932 Out = 19.68 (-4.14, -4.76)
Step 7a: Loss = 0000.9386 Out = 19.78 (-4.14, -4.78)
Step 7b: Loss = 0000.9088 Out = 19.88 (-4.16, -4.78)
Step 8a: Loss = 0000.9019 Out = 19.92 (-4.16, -4.79)
Step 8b: Loss = 0000.8980 Out = 19.95 (-4.16, -4.79)
Step 9a: Loss = 0000.8972 Out = 19.96 (-4.16, -4.80)
Step 9b: Loss = 0000.8967 Out = 19.98 (-4.17, -4.80)
Step 10a: Loss = 0000.8966 Out = 19.98 (-4.17, -4.80)
Step 10b: Loss = 0000.8965 Out = 19.98 (-4.17, -4.80)
Step 11a: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 11b: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 12a: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 12b: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 13a: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 13b: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 14a: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 14b: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 15a: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 15b: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 16a: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 16b: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 17a: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 17b: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 18a: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 18b: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 19a: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
Step 19b: Loss = 0000.8965 Out = 19.99 (-4.17, -4.80)
This seems to work nicely. Notice that as we step from \(n\)a to \(n\)b, only one of the two parameters changes.
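As a sanity check on that final value: along the constraint \(xy = 20\), the \(l_1\) penalty is smallest when \(|x| = |y| = \sqrt{20}\), so the alternating scheme has landed close to (but slightly above) the best symmetric solution:

import numpy as np
# Loss at the symmetric point x = y = -sqrt(20), where xy = 20 exactly
t = np.sqrt(20)
print(0.1 * 2 * t)    # ~0.894, versus the 0.8965 reached by alternating minimization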
Now onwards to the filter factorization!
Define filters as convolution of 2 3x3¶
Let us try to match the lh, hl and hh wavelets. To make our filters admissible wavelets, we also need to impose the restriction that \(\sum_{i,j} w_{i,j} = 0\).
In [4]:
# Define our two 3x3 kernels
w1 = torch.randn(1,1,3,3, requires_grad=True)
w2 = torch.randn(1,1,3,3, requires_grad=True)
# Define the Lagrange multipliers for MSE, minimum-energy coefficients and admissibility
λ = 100
λ1 = .1
λ2 = 1
lr = 5e-2
N = 20
# Define our target - we'll start with lh
y = torch.tensor(xb[1], dtype=torch.float32).reshape(1,1,16,16)
# Make it zero mean and unit std
y = (y-y.mean())/y.std()
# Define our two optimizers
optim1 = torch.optim.SGD([w1], lr=lr, momentum=0)
optim2 = torch.optim.SGD([w2], lr=lr, momentum=0)
x = torch.zeros(1,1,16,16)
x[0,0,8,8] = 1
lbest = np.inf
w1_best = w1.data.numpy()
w2_best = w2.data.numpy()
for i in range(N):
    # Minimize wrt w1
    optim1.zero_grad()
    optim2.zero_grad()
    y1 = F.conv2d(x, w1, padding=1)
    y2 = F.conv2d(y1, w2, padding=1)
    lmse = λ * torch.mean((y - y2)**2)
    lreg = λ1 * (torch.mean(w1**2) + torch.mean(w2**2))
    ladmiss = λ2 * torch.abs(torch.sum(y2))
    l = lmse + lreg + ladmiss
    print("Step {:2}a: Loss = {loss:09.4f}\tMSE: {mse:07.3f}\t "
          "Reg: {reg:07.3f}\tAdmiss: {adm:07.3f})".format(
              i, loss=l.item(), mse=lmse.item(), reg=lreg.item(),
              adm=ladmiss.item()))
    l.backward()
    optim1.step()
    if l.item() < lbest:
        w1_best = w1.data.numpy()
        w2_best = w2.data.numpy()
        lbest = l.item()
    # Minimize wrt w2
    optim1.zero_grad()
    optim2.zero_grad()
    y1 = F.conv2d(x, w1, padding=1)
    y2 = F.conv2d(y1, w2, padding=1)
    lmse = λ * torch.mean((y - y2)**2)
    lreg = λ1 * (torch.mean(w1**2) + torch.mean(w2**2))
    ladmiss = λ2 * torch.abs(torch.sum(y2))
    l = lmse + lreg + ladmiss
    print("Step {:2}b: Loss = {loss:09.4f}\tMSE: {mse:07.3f}\t "
          "Reg: {reg:07.3f}\tAdmiss: {adm:07.3f})".format(
              i, loss=l.item(), mse=lmse.item(), reg=lreg.item(),
              adm=ladmiss.item()))
    l.backward()
    optim2.step()
    if l.item() < lbest:
        w1_best = w1.data.numpy()
        w2_best = w2.data.numpy()
        lbest = l.item()
print('-' * 50)
print('Best loss was {}'.format(lbest))
Step 0a: Loss = 0104.7647 MSE: 103.354 Reg: 000.168 Admiss: 001.243)
Step 0b: Loss = 0043.8679 MSE: 043.400 Reg: 000.180 Admiss: 000.288)
Step 1a: Loss = 0022.3645 MSE: 021.798 Reg: 000.226 Admiss: 000.341)
Step 1b: Loss = 0009.5339 MSE: 008.788 Reg: 000.252 Admiss: 000.494)
Step 2a: Loss = 0006.9480 MSE: 006.337 Reg: 000.260 Admiss: 000.351)
Step 2b: Loss = 0003.7069 MSE: 003.242 Reg: 000.263 Admiss: 000.201)
Step 3a: Loss = 0002.5059 MSE: 002.113 Reg: 000.263 Admiss: 000.129)
Step 3b: Loss = 0001.6599 MSE: 001.254 Reg: 000.265 Admiss: 000.141)
Step 4a: Loss = 0001.2352 MSE: 000.890 Reg: 000.266 Admiss: 000.079)
Step 4b: Loss = 0000.9242 MSE: 000.644 Reg: 000.267 Admiss: 000.014)
Step 5a: Loss = 0000.7898 MSE: 000.515 Reg: 000.265 Admiss: 000.009)
Step 5b: Loss = 0000.7030 MSE: 000.419 Reg: 000.265 Admiss: 000.019)
Step 6a: Loss = 0000.6378 MSE: 000.362 Reg: 000.265 Admiss: 000.010)
Step 6b: Loss = 0000.5727 MSE: 000.305 Reg: 000.265 Admiss: 000.003)
Step 7a: Loss = 0000.5402 MSE: 000.274 Reg: 000.264 Admiss: 000.002)
Step 7b: Loss = 0000.5080 MSE: 000.243 Reg: 000.264 Admiss: 000.001)
Step 8a: Loss = 0000.4896 MSE: 000.225 Reg: 000.264 Admiss: 000.000)
Step 8b: Loss = 0000.4690 MSE: 000.204 Reg: 000.264 Admiss: 000.001)
Step 9a: Loss = 0000.4556 MSE: 000.192 Reg: 000.264 Admiss: 000.000)
Step 9b: Loss = 0000.4416 MSE: 000.178 Reg: 000.264 Admiss: 000.000)
Step 10a: Loss = 0000.4330 MSE: 000.169 Reg: 000.263 Admiss: 000.000)
Step 10b: Loss = 0000.4227 MSE: 000.159 Reg: 000.263 Admiss: 000.000)
Step 11a: Loss = 0000.4164 MSE: 000.153 Reg: 000.263 Admiss: 000.000)
Step 11b: Loss = 0000.4086 MSE: 000.145 Reg: 000.263 Admiss: 000.000)
Step 12a: Loss = 0000.4035 MSE: 000.140 Reg: 000.263 Admiss: 000.000)
Step 12b: Loss = 0000.3980 MSE: 000.135 Reg: 000.263 Admiss: 000.000)
Step 13a: Loss = 0000.3939 MSE: 000.131 Reg: 000.263 Admiss: 000.000)
Step 13b: Loss = 0000.3890 MSE: 000.126 Reg: 000.263 Admiss: 000.000)
Step 14a: Loss = 0000.3857 MSE: 000.123 Reg: 000.263 Admiss: 000.000)
Step 14b: Loss = 0000.3818 MSE: 000.119 Reg: 000.263 Admiss: 000.000)
Step 15a: Loss = 0000.3789 MSE: 000.116 Reg: 000.263 Admiss: 000.000)
Step 15b: Loss = 0000.3762 MSE: 000.113 Reg: 000.263 Admiss: 000.001)
Step 16a: Loss = 0000.3738 MSE: 000.111 Reg: 000.262 Admiss: 000.001)
Step 16b: Loss = 0000.3707 MSE: 000.108 Reg: 000.262 Admiss: 000.000)
Step 17a: Loss = 0000.3686 MSE: 000.106 Reg: 000.262 Admiss: 000.000)
Step 17b: Loss = 0000.3667 MSE: 000.104 Reg: 000.262 Admiss: 000.000)
Step 18a: Loss = 0000.3649 MSE: 000.102 Reg: 000.262 Admiss: 000.000)
Step 18b: Loss = 0000.3632 MSE: 000.100 Reg: 000.262 Admiss: 000.001)
Step 19a: Loss = 0000.3617 MSE: 000.099 Reg: 000.262 Admiss: 000.001)
Step 19b: Loss = 0000.3596 MSE: 000.097 Reg: 000.262 Admiss: 000.000)
--------------------------------------------------
Best loss was 0.3595961928367615
Plot the result
In [5]:
y1 = F.conv2d(x, torch.tensor(w1_best, dtype=torch.float32), padding=1)
y2 = F.conv2d(y1, torch.tensor(w2_best, dtype=torch.float32), padding=1)
y_hat = y2.detach().cpu().numpy()[0,0]
y_target = y.detach().cpu().numpy()[0,0]
fig, ax = plt.subplots(1,3, figsize=(9,4),
subplot_kw={'xticks':[], 'yticks':[]})
ax[0].imshow(y_target, vmin=-5, vmax=5)
ax[0].set_title('Target filter (7x7)')
ax[1].imshow(y_hat, vmin=-5, vmax=5)
ax[1].set_title('Learned factorization\nwith 2 3x3 kernels')
ax[2].imshow(y_target-y_hat, vmin=-.5, vmax=0.5)
im = ax[2].set_title('Error\n(scaled up by factor of 10)')
fig.suptitle('MSE of final solution: {:.2e}'.format(
np.mean((y_target - y_hat)**2)));
Run with different initial conditions¶
This is good, but we want to run it with several different random starts and keep the best result over these runs.
In [6]:
def get_loss(x, y, net, λ, λ1, λ2):
    y_out = net(x)
    lmse = λ*torch.mean((y - y_out)**2)
    lreg = λ1 * net.reg_loss()
    ladmiss = λ2 * torch.sum(y_out)**2
    l = lmse + lreg + ladmiss
    return l

class MyNet(torch.nn.Module):
    def __init__(self, k=3, n=2):
        super().__init__()
        self.w = torch.nn.ParameterList([
            torch.nn.Parameter(torch.randn(1,1,k,k)) for _ in range(n)])
        self.n = n
        self.k = k

    def reg_loss(self):
        loss = torch.tensor(0., dtype=torch.float)
        for i in range(self.n):
            loss += torch.mean(self.w[i]**2)
        return loss

    def accept_weights(self, ws):
        self.w = torch.nn.ParameterList([
            torch.nn.Parameter(
                torch.tensor(w, dtype=torch.float32), requires_grad=False)
            for w in ws
        ])
        self.n = len(ws)
        self.k = ws[0].shape[-1]

    def forward(self, X):
        for i in range(self.n):
            X = F.conv2d(X, self.w[i], padding=(self.k//2))
        return X

def find_ws(x, target, k=3, n=2, runs=50, iters=50, λ=5000, λ1=0.1, λ2=1, lr=5e-3):
    """ Matches a target filter by convolving n smaller filters

    Args:
        x: the impulse input
        target: target filter
        k: kernel size of smaller filter
        n: number of successive smaller filters to use
        runs: number of random starts
        iters: number of steps per random start
        λ: MSE lagrange multiplier
        λ1: Regularization lagrange multiplier
        λ2: Admissibility lagrange multiplier
    """
    if np.abs(target.sum()) > 1e-6 and λ2 > 0:
        print('Warning: the admissibility gain is non zero but the target is '
              'itself not admissible')
    bestw = [np.zeros((1,1,k,k), dtype='float32') for _ in range(n)]
    y = torch.tensor(target, dtype=torch.float32, device=dev)
    if len(y.shape) == 2:
        y = torch.unsqueeze(torch.unsqueeze(y, dim=0), dim=0)
    # Make the target 0 mean and unit std
    y = (y-y.mean())/y.std()
    x = torch.tensor(x, dtype=torch.float32, device=dev)
    if len(x.shape) == 2:
        x = torch.unsqueeze(torch.unsqueeze(x, dim=0), dim=0)
    lbest = np.inf
    for N in range(runs):
        try:
            net = MyNet(k, n)
            optim = [torch.optim.SGD([net.w[i]], lr=lr, momentum=0) for i in range(n)]
            lbest_run = np.inf
            bestw_run = [np.zeros((1,1,k,k), dtype='float32') for _ in range(n)]
            for i in range(iters):
                for j in range(n):
                    [o.zero_grad() for o in optim]
                    l = get_loss(x, y, net, λ, λ1, λ2)
                    l.backward(retain_graph=True)
                    optim[j].step()
                    #if torch.isnan(net.w[j]).sum() > 0:
                    #    raise RuntimeError
                    if l.item() < lbest_run:
                        for p in range(n):
                            bestw_run[p][:] = net.w[p].detach().data.cpu().numpy()[:]
                        lbest_run = l.item()
            if lbest_run < lbest:
                for p in range(n):
                    bestw[p][:] = bestw_run[p][:]
                lbest = lbest_run
        except RuntimeError:
            # Nan error raised
            print('Nan was raised in iter {}. Continuing...'.format(N))
            continue
    return lbest, bestw
In [7]:
impulse16 = torch.zeros((1,1,16,16))
impulse16[0,0,8,8] = 1
lbest, bestw = find_ws(x=impulse16,
target=xb[1],
k=3,
n=2,
iters=100,
λ=100, λ1=0.1, λ2=1, lr=5e-2)
In [8]:
# Get the output by passing impulse to network
net = MyNet()
net.accept_weights(bestw)
y_hat = net(impulse16).detach().cpu().numpy()[0,0]
y_target = torch.tensor(xb[1], dtype=torch.float32, device=dev)
y_target = ((y_target-y_target.mean())/y_target.std()).cpu().numpy()
# Plot
fig, ax = plt.subplots(1,3, figsize=(9,4),
subplot_kw={'xticks':[], 'yticks':[]})
ax[0].imshow(y_target, vmin=-5, vmax=5)
ax[1].imshow(y_hat, vmin=-5, vmax=5)
ax[2].imshow(y_target-y_hat, vmin=-.5, vmax=0.5)
fig.suptitle('MSE of final solution: {:.2e}'.format(
np.mean((y_target - y_hat)**2)));
Compare to parameterization as 1 5x5¶
We know this should be pretty much perfect, since the 7x7 target has little energy in its outer ring of pixels. So how does the above compare to a single 5x5?
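Before running it, we can quantify that claim about the target's energy. A rough check on the normalized target from above, taking the peak of the response as the centre (the exact alignment of the 7x7 support on the 16x16 grid is only approximate here):

t = y_target                           # normalized 16x16 lh impulse response from above
ci, cj = np.unravel_index(np.argmax(np.abs(t)), t.shape)
total = np.sum(t**2)
e5 = np.sum(t[ci-2:ci+3, cj-2:cj+3]**2)   # energy in a 5x5 window around the peak
e7 = np.sum(t[ci-3:ci+4, cj-3:cj+4]**2)   # energy in a 7x7 window around the peak
print('5x5 window: {:.1%}, 7x7 window: {:.1%} of total energy'.format(
    e5/total, e7/total))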
In [9]:
lbest, bestw_1_5x5_lh = find_ws(impulse16, xb[1], k=5, n=1, iters=100)
# Get the output by passing impulse to network
x = torch.zeros(1,1,16,16)
x[0,0,8,8] = 1
for w in bestw_1_5x5_lh:
    x = F.conv2d(x, torch.tensor(w), padding=w.shape[-1]//2)
y_hat = x.detach().cpu().numpy()[0,0]
fig, ax = plt.subplots(1,3, figsize=(9,4),
subplot_kw={'xticks':[], 'yticks':[]})
ax[0].imshow(y_target, vmin=-5, vmax=5)
ax[1].imshow(y_hat, vmin=-5, vmax=5)
ax[2].imshow(y_target-y_hat, vmin=-.5, vmax=0.5)
fig.suptitle('MSE of final solution: {:.2e}'.format(
np.mean((y_target - y_hat)**2)));
Compare to 3 3x3 kernels¶
In [10]:
lbest, bestw_3_3x3_lh = find_ws(impulse16, xb[1], k=3, n=3, iters=100, λ=500)
# Get the output by passing impulse to network
net = MyNet()
net.accept_weights(bestw_3_3x3_lh)
y_hat = net(impulse16).detach().cpu().numpy()[0,0]
y_target = torch.tensor(xb[1], dtype=torch.float32, device=dev)
y_target = ((y_target-y_target.mean())/y_target.std()).cpu().numpy()
# Plot
fig, ax = plt.subplots(1,3, figsize=(9,4),
subplot_kw={'xticks':[], 'yticks':[]})
ax[0].imshow(y_target, vmin=-5, vmax=5)
ax[1].imshow(y_hat, vmin=-5, vmax=5)
ax[2].imshow(y_target-y_hat, vmin=-.5, vmax=0.5)
fig.suptitle('MSE of final solution: {:.2e}'.format(
np.mean((y_target - y_hat)**2)));
Try a different filter¶
Finally, let us also try the ll filter. Again we'll go back to the design with 2 3x3 filters instead of 1 5x5 or 3 3x3s. Since the ll filter is lowpass, its taps do not sum to zero, so it is not admissible and we drop the admissibility penalty by setting λ2=0.
In [11]:
lbest, bestw_2_3x3_ll = find_ws(impulse16, xb[0], k=3, n=2, iters=100, λ=1000, λ2=0)
# Get the output by passing impulse to network
net = MyNet()
net.accept_weights(bestw_2_3x3_ll)
y_hat = net(impulse16).detach().cpu().numpy()[0,0]
y_target = torch.tensor(xb[0], dtype=torch.float32, device=dev)
y_target = ((y_target-y_target.mean())/y_target.std()).cpu().numpy()
# Plot
fig, ax = plt.subplots(1,3, figsize=(9,4),
subplot_kw={'xticks':[], 'yticks':[]})
ax[0].imshow(y_target, vmin=-5, vmax=5)
ax[1].imshow(y_hat, vmin=-5, vmax=5)
ax[2].imshow(y_target-y_hat, vmin=-.5, vmax=0.5)
fig.suptitle('MSE of final solution: {:.2e}'.format(
np.mean((y_target - y_hat)**2)));
Conclusion¶
We have shown that it is possible to factorize some larger filters as convolutions of smaller filters. Note that we haven't discussed much of the hyperparameter search done to find the relative weights of the MSE, \(l_2\) regularization and admissibility terms, as well as the learning rate. These are not necessarily easy to find, and some tuning may be needed if repeating this on your own filters.
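As a starting point for that tuning, something like a coarse grid search over the loss weights can be wrapped around the find_ws function defined above, comparing settings on the reconstruction MSE rather than on the weighted loss (the grid values below are only illustrative):

import itertools

# Normalized lh target, as in the cells above
yt = torch.tensor(xb[1], dtype=torch.float32)
yt = ((yt - yt.mean()) / yt.std()).numpy()

results = {}
for λ_, λ1_ in itertools.product([100, 500, 1000], [0.01, 0.1, 1.0]):
    _, ws = find_ws(impulse16, xb[1], k=3, n=2, runs=10, iters=100,
                    λ=λ_, λ1=λ1_, λ2=1, lr=5e-2)
    net = MyNet()
    net.accept_weights(ws)
    y_hat = net(impulse16).detach().numpy()[0, 0]
    # Compare settings on plain MSE so that different λ values are comparable
    results[(λ_, λ1_)] = np.mean((yt - y_hat)**2)

best = min(results, key=results.get)
print('Best (λ, λ1): {}  MSE: {:.2e}'.format(best, results[best]))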