Lesson 4: Under the Hood: Training a Digit Classifier
28-09-2020
This notebook covers some of the practical material discussed in lesson 4 of the fastai 2020 course, namely a few different ways of training a digit classifier using the MNIST data set. The lesson 4 video is an extension of the lesson 3 video. There is a lot to cover...
In the last notebook we looked at some simple examples of using SGD to optimise a model. In this notebook we will apply those concepts to the MNIST problem from scratch; later, we will refactor the code using PyTorch and fastai modules.
# imports and things we need from previous notebooks
from fastai.vision.all import *
# data
path = untar_data(URLs.MNIST_SAMPLE)
threes = (path/'train'/'3').ls().sorted()
sevens = (path/'train'/'7').ls().sorted()
seven_tensors = [tensor(Image.open(o)) for o in sevens]
three_tensors = [tensor(Image.open(o)) for o in threes]
stacked_sevens = torch.stack(seven_tensors).float()/255
stacked_threes = torch.stack(three_tensors).float()/255
valid_3_tens = torch.stack([tensor(Image.open(o))
for o in (path/'valid'/'3').ls()])
valid_3_tens = valid_3_tens.float()/255
valid_7_tens = torch.stack([tensor(Image.open(o))
for o in (path/'valid'/'7').ls()])
valid_7_tens = valid_7_tens.float()/255
def plot_function(f, tx=None, ty=None, title=None, min=-2, max=2, figsize=(6,4)):
    x = torch.linspace(min, max, 100) # 100 points (the old default for steps)
    fig,ax = plt.subplots(figsize=figsize)
    ax.plot(x, f(x))
    if tx is not None: ax.set_xlabel(tx)
    if ty is not None: ax.set_ylabel(ty)
    if title is not None: ax.set_title(title)
MNIST Loss function
Our x values will be the pixels. We need to reshape the data using view: we concatenate our x's into a single tensor, then change them from a list of matrices (a rank-3 tensor) to a list of vectors (a rank-2 tensor), so that each image becomes a single vector of 28*28 = 784 pixel values that we can later multiply by one weight per pixel.
view will return a new tensor with the same data as the original tensor but with a different shape that we define.
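As a quick illustration of what view does (a toy tensor of my own, not the MNIST data):
t = torch.ones(2, 3, 4) # a rank-3 tensor: 2 'images' of 3x4 pixels
t.view(-1, 3*4).shape # torch.Size([2, 12]); -1 means 'work this dimension out for me'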
# concat the 3s and 7s, then reshape into a matrix
# so that each row is one image, with all of its pixels flattened into a single vector
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)
# label the data
# 3 == 1
# 7 == 0
# we need this to be a matrix
# unsqueeze will do this for us
train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)
# check the shape
train_x.shape,train_y.shape
# a Dataset in PyTorch needs to return a tuple of (x, y) for each item
# zip will help us with this
dset = list(zip(train_x,train_y))
# take a look at the first thing
x,y = dset[0]
x.shape, y
(torch.Size([784]), tensor([1]))
This matches what we would expect.
# repeat for validation
valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x,valid_y))
Now we have training and validation data sets.
1. Randomly initialise weights for each pixel
- use torch.randn to create a tensor of randomly initialised weights
def init_params(size, var=1.0):
    return (torch.randn(size)*var).requires_grad_()
weights = init_params((28*28,1))
weights.shape
We need to add a bias term because just using weights*pixels will not be flexible enough: our function would always be equal to zero when the pixels are equal to zero.
bias = init_params(1)
y = w*x + b is the formula for a line, where w are the weights and b is the bias. In neural network jargon, the weights and bias will be our parameters.
This linear equation is one of the two fundamental equations of any neural network. The other is an activation function that we will see shortly.
Let's use this to calculate a prediction for one image. weights.T will transpose the weights; this is done to make sure the rows and columns match up for our multiplication.
(train_x[0]*weights.T).sum() + bias
Now we need to do this for all the images. A for loop would be too slow. In PyTorch we can perform matrix multiplication using the @ operator or by using torch.matmul().
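As a quick sanity check (my own, reusing weights, bias and train_x from above), all three forms give the same number as the multiply-and-sum we just did:
(train_x[0]*weights.T).sum() + bias # elementwise multiply then sum
train_x[0]@weights + bias # matrix multiplication with @
torch.matmul(train_x[0], weights) + bias # same thing via torch.matmul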
# define a linear function that will
# multiply the input by the weights, then add a bias term
def linear1(xb): return xb@weights + bias
preds = linear1(train_x)
preds
Notice the first result is the same as the one we just calculated above. This confirms our function is working, and we can also see that the operation is performed for every image in train_x.
Checking accuracy
- if a prediction is above the threshold, i.e. greater than 0, we call it a 3; otherwise we call it a 7
- so we check whether each prediction is greater than our threshold of 0, then compare the results against the labels
- this will return True when a row is correctly predicted
- we can convert these booleans to floats using .float() and then take their mean to check the overall accuracy of our randomly initialised model
threshold = 0.0
accuracy = (preds > threshold).float() == train_y
accuracy
accuracy.float().mean().item()
Let's change one of the weights by a small amount to see how accuracy is affected.
with torch.no_grad(): weights[0] += 1.0001 # increase the weight a little (no_grad so PyTorch allows the in-place change)
preds = linear1(train_x)
accuracy2 = ((preds > threshold).float() == train_y).float().mean().item()
accuracy2
The accuracy is exactly the same as before. We have a problem: the gradient of accuracy with respect to the weights is 0 almost everywhere, because changing a single weight by a small amount will usually not flip any prediction.
And because the gradient is 0, our step will be 0, which means our weights will be unchanged.
So accuracy is not a very good loss function: a small change in our weights does not result in a small change in accuracy, so we get zero gradients.
We need a loss function that does not have a zero gradient; it needs to be more sensitive to small changes, so that a slightly better prediction gives a slightly lower loss.
In other words, when the predictions are close to the targets the loss needs to be small, and when they are far away it needs to be big.
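A small check (my own, not from the lesson) makes the problem concrete: the thresholding and comparison steps used for accuracy are not differentiable at all, so PyTorch cannot give us a useful gradient through them.
acc = ((preds > 0.0).float() == train_y).float().mean()
acc.requires_grad # False: the comparison ops cut the computation graph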
So let's create a new function to address this issue.
# MNIST loss
def mnist_loss(preds, targets):
    return torch.where(targets==1., 1.-preds, preds).mean()
# test case
t = torch.tensor([1,0,1]) # targets
p = torch.tensor([0.9, 0.4, 0.2]) # predictions
# this is the same as mnist_loss but before the mean
torch.where(t==1, 1-p, p)
torch.where is like a list comprehension for tensors: for each element it picks 1-p where the target is 1, and p where it is 0.
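To make that concrete, here is (roughly) the list comprehension it corresponds to, using the t and p defined above:
# pick 1-p where the target is 1, and p where it is 0
torch.stack([1-p[i] if t[i]==1 else p[i] for i in range(len(t))]) # tensor([0.1000, 0.4000, 0.8000])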
This function returns a lower loss when predictions are more accurate and a higher loss when they are not.
But for this to work, we need our predictions to be between 0 and 1; otherwise the loss can come out negative (as we can see below), which makes no sense for a loss.
p2 = torch.tensor([1.2, -1, 0]) # predictions outside 0, 1 range
torch.where(t==1, 1-p2, p2)
The Sigmoid function
- This function will constrain our numbers between 0 and 1.
- It squashes any input in the range (-inf, inf) to some value in the range (0, 1)
def sigmoid(x): return 1 / (1 + torch.exp(-x))
plot_function(torch.sigmoid, title='Sigmoid', min=-4, max=4)
# MNIST loss with sigmoid
def mnist_loss(predictions, targets):
    preds = predictions.sigmoid()
    return torch.where(targets==1., 1.-preds, preds).mean()
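Now even the badly behaved predictions from before give a sensible loss, because the sigmoid squashes them into (0, 1) first (a quick check of my own, reusing p2 and t from above):
mnist_loss(p2, t) # roughly 0.33, and always between 0 and 1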
SGD and Mini-batches
Batching images and running computations over a whole batch at once is a way to compromise between speed and accuracy: larger batches give more accurate and stable gradient estimates, while smaller batches let you take more (quicker) steps.
The batch size will affect both the quality of your gradient estimates and the speed at which you can run computations, so it is something to consider during training.
The DataLoader class in PyTorch helps with batching. It returns an iterator which we can loop through.
coll = range(15)
dl = DataLoader(coll, batch_size=5, shuffle=True)
list(dl)
Putting it together
# re-initialise weights and params
weights = init_params((28*28,1))
bias = init_params(1)
# create a data loader
dl = DataLoader(dset, batch_size=256)
# grab the first x and y
xb, yb = first(dl)
# check the shape
xb.shape, yb.shape
# repeat for validation set
valid_dl = DataLoader(valid_dset, batch_size=256)
# grab a mini batch to test on
batch = train_x[:4]
batch.shape
# make some predictions
preds = linear1(batch)
preds
loss = mnist_loss(preds, train_y[:4])
loss
# calculate gradients
loss.backward()
weights.grad.shape, weights.grad.mean(), bias.grad
# take those three steps (predict, compute the loss, backward) and put them in a function
def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward()
# test it
calc_grad(batch, train_y[:4], linear1)
weights.grad.shape, weights.grad.mean(), bias.grad
# zero the gradients
weights.grad.zero_()
bias.grad.zero_()
The last step is to work out how to update the weights and bias based on the gradient and learning rate.
train_epoch loops through the data loader, grabbing a batch of x's and y's at a time; for each batch it makes a prediction, calculates the loss, and calculates the gradients (via calc_grad). It then goes through each parameter (weights and bias), updates it with gradient * lr, and zeroes its gradient in preparation for the next batch.
p.data is used because PyTorch keeps track of all operations so that it can calculate gradients, but we do not want gradients to be calculated for the gradient-descent update step itself.
def train_epoch(model, lr, params):
    for xb, yb in dl:
        calc_grad(xb, yb, model)
        for p in params:
            p.data -= p.grad*lr
            p.grad.zero_()
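As an aside, an equivalent way to stop PyTorch tracking the update step is to wrap it in torch.no_grad() instead of using p.data; this is just a sketch of the same idea (the name train_epoch_no_grad is mine, not from the lesson):
def train_epoch_no_grad(model, lr, params):
    for xb, yb in dl:
        calc_grad(xb, yb, model)
        with torch.no_grad(): # don't record the update itself in the computation graph
            for p in params:
                p -= p.grad*lr
                p.grad.zero_()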
batch_accuracy is similar to the accuracy check we did earlier, but since the model's raw outputs are now passed through a sigmoid, which constrains our preds between 0 and 1, we need to check whether preds > 0.5.
def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds>0.5) == yb # check predictions against target
    return correct.float().mean()
batch_accuracy(linear1(train_x[:4]), train_y[:4])
# check accuracy for every batch in the validation set
# torch.stack converts the list of per-batch accuracies into a tensor so we can take the mean
def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb, yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)
validate_epoch(linear1)
This is a starting point, let's train for one epoch and see if accuracy improves.
As a reminder, the linear1 function was...
- def linear1(xb): return xb@weights + bias
lr = 1.
params = weights, bias
train_epoch(linear1, lr, params)
validate_epoch(linear1)
for i in range(20):
    train_epoch(linear1, lr, params)
    print(validate_epoch(linear1), end=' ')
Accuracy has indeed improved! We have built an SGD optimizer that has reached about 97% accuracy.
Refactor and clean up
- create an optimiser
- use PyTorch modules and functions where available
- like nn.Linear, which "Applies a linear transformation to the incoming data: $y = xA^T + b$"
nn.Linear?
# replace our own linear function
# with the equivalent PyTorch module
# nn.Linear creates the weight matrix (28*28 inputs, 1 output)
# and the bias for us
linear_model = nn.Linear(28*28,1)
# check model params
w,b = linear_model.parameters()
w.shape, b.shape
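To convince ourselves that nn.Linear is doing the same thing as our hand-rolled linear1, we can check its output against y = x @ w.T + b directly (my own check, reusing the w, b and train_x from above):
manual = train_x[:4]@w.T + b # the same linear transformation written out by hand
torch.allclose(linear_model(train_x[:4]), manual) # True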
Create a basic optimiser
- pass in params to optimise and lr
- store these away
- step through each param (weights and bias) and, for each, update with gradient * lr
- zero the gradients in prep for the next step
class BasicOptim:
    def __init__(self, params, lr):
        self.params, self.lr = list(params), lr
    def step(self, *args, **kwargs):
        for p in self.params:
            p.data -= p.grad.data * self.lr
    def zero_grad(self, *args, **kwargs):
        for p in self.params:
            p.grad = None
# create an optimiser by passing in parameters from model
opt = BasicOptim(linear_model.parameters(), lr)
# simplify the training loop
def train_epoch(model):
    for xb, yb in dl:
        calc_grad(xb, yb, model)
        opt.step()
        opt.zero_grad()
validate_epoch(linear_model)
Now create a function train_model that will call train_epoch on our model for the specified number of epochs.
def train_model(model, epochs):
    for i in range(epochs):
        train_epoch(model)
        print(validate_epoch(model), end=' ')
train_model(linear_model, 20)
The results are very similar to what we have seen before.
fastai provides an SGD class that we can use instead of writing our own; again, the results are very similar.
linear_model = nn.Linear(28*28, 1)
opt = SGD(linear_model.parameters(), lr)
train_model(linear_model, 20)
Let's refactor some more, using some fastai classes. The Learner class implements everything we have just built manually.
# previously we created individual DataLoader objects;
# DataLoaders (plural) bundles the training and validation loaders together
dls = DataLoaders(dl, valid_dl)
learn = Learner(dls, nn.Linear(28*28,1), opt_func=SGD,
loss_func=mnist_loss, metrics=batch_accuracy)
learn.fit(10)
The results again are very similar, but with some additional functionality (like printing out the results in a pretty table!).
Non-Linearity
To create a simple neural net, using a linear function like we did before is not enough. We need to add in a non-linearity between two linear functions.
This is the basic definition of a neural net...
The universal approximation theorem says that, given any arbitrarily complex continuous function, we can approximate it with a neural network. I found this useful for visualising how this works. This is what we are trying to do.
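As a toy illustration of that idea (my own sketch, not part of the lesson, and names like toy_net are mine): a single hidden layer with a ReLU can learn to approximate a sine wave.
# approximate sin(x) with a tiny net: linear -> ReLU -> linear
xs = torch.linspace(-3, 3, 200).unsqueeze(1)
ys = torch.sin(xs)
toy_net = nn.Sequential(nn.Linear(1, 50), nn.ReLU(), nn.Linear(50, 1))
toy_opt = torch.optim.Adam(toy_net.parameters(), lr=0.01)
for _ in range(1000):
    toy_loss = ((toy_net(xs) - ys)**2).mean() # mean squared error
    toy_loss.backward()
    toy_opt.step()
    toy_opt.zero_grad()
plt.plot(xs, ys, label='sin(x)')
plt.plot(xs, toy_net(xs).detach(), label='network approximation')
plt.legend();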
In our basic_net, each line represents a layer in our network: the first and third are known as linear layers, and the second as a nonlinearity, or activation function.
res.max(tensor(0.0)) takes the result of our linear function and sets any negative value to 0.0 while keeping any positive values.
def basic_net(xb):
    res = xb@w1 + b1
    res = res.max(tensor(0.0))
    res = res@w2 + b2
    return res
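This max-with-zero is exactly the ReLU (rectified linear unit) activation, which PyTorch provides as F.relu and nn.ReLU; a one-line check on a toy tensor of my own:
r = tensor([-2., -0.5, 0., 1., 3.])
r.max(tensor(0.0)), F.relu(r) # both give tensor([0., 0., 0., 1., 3.])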
plot_function(F.relu)
As we have seen previously...
- w1 and w2 are weight tensors
- b1 and b2 are bias tensors
We can initialise these the same way as we have done previously. w1 has 30 output activations, so in order for w2 to match, it requires 30 input activations.
w1 = init_params((28*28,30))
b1 = init_params(30)
w2 = init_params((30,1))
b2 = init_params(1)
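With the parameters initialised, a quick sanity check of my own that the composed network produces one prediction per image:
basic_net(train_x[:4]).shape # torch.Size([4, 1])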
We can simplify further using PyTorch...
What we did in basic_net is called function composition: we passed the result of one function into another function, and then into another. This is what neural nets are doing with their linear layers and activation functions. nn.Sequential() will do this composition for us...
simple_net = nn.Sequential(
nn.Linear(28*28, 30), # 28*28 in, 30 out
nn.ReLU(),
nn.Linear(30,1) # 30 in 1 out
)
learn = Learner(dls, simple_net, opt_func=SGD,
loss_func=mnist_loss, metrics=batch_accuracy)
learn.fit(40,0.1)
# this is what our model now looks like
learn.model
# plot the loss
learn.recorder.plot_loss()
# learn.recorder.values holds the values from the table above
# let's plot the accuracy, which itemgot(2) grabs
plt.plot(L(learn.recorder.values).itemgot(2));
Looking inside...
# let's visualise some of the parameters
# 1. grab your model
m = learn.model # (0): Linear(in_features=784, out_features=30, bias=True)
# 2. look inside and grab the weights and biases
w,b = m[0].parameters()
# 3. grab first (or any) row, reshape, and plot
show_image(w[0].view(28,28), figsize=(4,4))
fastai in full
from fastai.vision.all import *
from pathlib import Path
path = Path.cwd()/'datasets/fastai/mnist_sample'
dls = ImageDataLoaders.from_folder(path)
learn = cnn_learner(dls, resnet18, pretrained=False,
loss_func=F.cross_entropy, metrics=accuracy)
learn.fit_one_cycle(1, 0.1)
Summary
We have gone over creating and training a neural network from scratch using the simple example of a digit classifier. The key idea for the last few notebooks was to start with planning out the problem and identifying a way to solve it using a simple common sense solution - the pixel similarity model.
This proved successful but it was not really robust beyond the straightforward example we chose - identifying 3s and 7s. We then implemented a more complex solution that could be applied to more complicated problems.
After each step or concept had been implemented manually, we refactored the code to use convenient PyTorch functions and modules, eventually ending up with fastai's implementation, which abstracts away all of the underlying heavy lifting. This is done for convenience and, in my opinion, to help lower the barrier to entry into deep learning.
Ultimately I believe it is fundamentally important to understand the concepts and implementation if your goal (and this is my goal) is to implement deep learning solutions to solve business problems within your industry.