Lesson 3: Under the Hood: Training a Digit Classifier

26-09-2020

This notebook will go over some of the practical material discussed in lesson 3 of the fastai 2020 course, namely, some different ways of training a digit classifier using the MNIST data set.

In part 1 we used a simple model that had no learned components. In this notebook, we will explore a smarter solution. We will apply this method back to the MNIST problem in the next notebook.

Stochastic Gradient Descent (SGD)

Instead of measuring how close something is to an "ideal" image, we could find a set of weights for each pixel. The highest weights will be associated with the pixels that are most likely to be black for a particular category (in our case, a digit).

This can be represented by a function and a set of weight values for each possible category, for example the probability of an image being an 8.

def prob_eight(x,w): return (x*w).sum()

Here x is a vector that represents an image (with all rows stacked up) and w is a vector of weights.
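As a tiny worked example of this weighted sum, here is a sketch in plain Python, with lists standing in for tensors. The four-pixel "image" and the weight values are made up purely for illustration.

```python
# Plain-Python stand-in for prob_eight: multiply each pixel by its weight, then sum.
def prob_eight(x, w):
    return sum(xi * wi for xi, wi in zip(x, w))

# Hypothetical 4-pixel image (1.0 = dark pixel) and made-up weights.
x = [0.0, 1.0, 1.0, 0.0]
w = [-1.0, 2.0, 3.0, -1.0]

score = prob_eight(x, w)
print(score)  # 5.0
```

Only the dark pixels contribute here, so the score is the sum of their weights (2.0 + 3.0).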

With this function, we now just need a way to gradually update the weights to make them better and better until they are as good as they can get.

In other words, we want to find the specific values of w that will cause the result of our function to be high when passed images of 8s and low for other digits. So by updating w we are optimising the function to recognise 8s.

The steps we will follow are:
1. Initialize the weights: start out with a random guess.
2. For each image, use these weights to predict whether it appears to be a 3 or a 7.
3. Based on these predictions, calculate how good the model is (its loss).
4. Calculate the gradient, which measures, for each weight, how changing that weight would change the loss.
5. Step (that is, update) all the weights based on that calculation.
6. Go back to step 2 and repeat the process.
7. Iterate until you decide to stop the training process (for instance, because the model is good enough or you don't want to wait any longer).

source

We will use a very simple example for illustration purposes.

from fastai.vision.all import *
from utils import *
# define a quadratic function 

def f(x): return x**2

I had some trouble finding the plot_function function, so I found the source code on the forums.

def plot_function(f, tx=None, ty=None, title=None, min=-2, max=2, figsize=(6,4)):
    x = torch.linspace(min,max)
    fig,ax = plt.subplots(figsize=figsize)
    ax.plot(x,f(x))
    if tx is not None: ax.set_xlabel(tx)
    if ty is not None: ax.set_ylabel(ty)
    if title is not None: ax.set_title(title)
# plot that function
plot_function(f, 'x', 'x**2')

# start with a random value for a parameter
plt.scatter(-1.5, f(-1.5), color='red');

Now we need to see what would happen if we increased or decreased our parameter by a small amount. The goal is to find the lowest point on the curve. We do this by calculating the gradient at a particular point. We can then change our weight by a small amount in the downhill direction (opposite to the sign of the slope), calculate the loss again, make another adjustment, and repeat until we reach our goal.

Calculating Gradients

The slope of a line can be described as the rate of change of a vertical variable with respect to a horizontal variable. This is the gradient. Calculating the gradient tells us how much we have to change each weight to make our model better.

A derivative is the instantaneous rate of change at a particular point, that is, how much y is changing with respect to x at that point. You can find it by calculating the slope of the tangent line at that point. This video provides a good explanation of the concept.

We can calculate the derivative for any function. For the quadratic above, the derivative is another function that calculates change, rather than the value. If we know how our function changes at a particular value, then we know how to minimize it.
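To make the idea of "a function that calculates change" concrete, here is a minimal sketch that estimates the slope of the quadratic numerically and compares it with the known derivative 2*x. The helper name numerical_slope and the step size h are assumptions chosen for illustration, not part of the course code.

```python
def f(x):
    return x ** 2

def numerical_slope(f, x, h=1e-6):
    # Central difference: slope of the line through two nearby points on the curve.
    return (f(x + h) - f(x - h)) / (2 * h)

x = -1.5
approx = numerical_slope(f, x)
exact = 2 * x  # the derivative of x**2 evaluated at x
print(approx, exact)
```

The slope at x = -1.5 is negative, which tells us that increasing x a little will decrease f(x).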

"This is the key to machine learning: having a way to change the parameters of a function to make it smaller. Calculus provides us with a computational shortcut, the derivative, which lets us directly calculate the gradients of our functions" source

PyTorch helps us do this using vector calculus. For example:

# define a vector xt
# requires_grad_ lets PyTorch know we need to calculate gradients

xt = tensor([3.,4.,10.]).requires_grad_()
xt
tensor([ 3.,  4., 10.], requires_grad=True)
# define a function f that takes a vector (rank 1 tensor)
# and returns a scalar (rank 0 tensor)
# do this by summing the result of x**2

def f(x): return (x**2).sum()

yt = f(xt)
yt
tensor(125., grad_fn=<SumBackward0>)

Calling backward runs backpropagation, which is the process of calculating the derivative of each layer. We can then use .grad to view the gradients.

We can confirm that the derivative of x**2 is 2*x:

yt.backward()
xt.grad
tensor([ 6.,  8., 20.])
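To see where these numbers come from, the derivative can also be worked out by hand. Writing the entries of xt as x_1, x_2, x_3 (notation introduced here only for this derivation):

```latex
y = \sum_{i=1}^{3} x_i^2
\quad\Longrightarrow\quad
\frac{\partial y}{\partial x_i} = 2 x_i ,
\qquad
\nabla y\,(3, 4, 10) = (6, 8, 20)
```

Each entry of xt.grad is simply twice the matching entry of xt.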

Stepping with a learning rate

Now that we know how to calculate the slope of our function, we know how our output will change if we change our input a little bit.

Let's call the weights w. To update them, we do the following:

w -= gradient(w) * lr

So we update w by subtracting the gradient of w multiplied by the learning rate lr. This process is known as stepping your parameters, using an optimiser step.

Choosing a good learning rate is one of the key decisions in machine learning. An intuitive way to think about it: if lr is too small, it will take forever to reach the goal; if it is too big, you will overshoot the target.
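The effect of the learning rate can be sketched on the quadratic x**2, whose derivative is 2*x. The function name descend, the starting point and the three learning rates below are illustrative assumptions, not values from the course.

```python
def descend(lr, x=-1.5, steps=20):
    # Minimise f(x) = x**2 by repeatedly stepping against the gradient 2*x.
    for _ in range(steps):
        grad = 2 * x
        x -= lr * grad
    return x

too_small = descend(lr=0.001)  # barely moves away from the start at -1.5
reasonable = descend(lr=0.1)   # ends close to the minimum at 0
too_big = descend(lr=1.1)      # every step overshoots, so |x| keeps growing
print(too_small, reasonable, too_big)
```

With lr=1.1 each update jumps past the minimum and lands further away than it started, which is the overshooting described above.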

End-to-end Gradient Descent example

Let's use gradient descent to see how finding a minimum can be used to train a model to fit data better.

Let's measure the speed of a rollercoaster as it goes over a rise - starting fast, then slowing down at the peak, then speeding up again.

If we measured the speed every second for 20 seconds, it might look something like this.

time = torch.arange(0,20).float()

# speed is a quadratic with some added noise
speed = torch.randn(20)*3 + 0.75*(time-9.5)**2 + 1

plt.scatter(time,speed);

We need to create a function that estimates the speed of the rollercoaster at any time.

Start with a guess about the form of the function. Here, let's use a quadratic:

a*(time**2) + (b*time) + c

# here the input t is time, params will be a list a,b,c

def f(t, params):
    a,b,c = params
    return a*(t**2) + (b*t) + c

So the goal is to find the function that best fits the data. We have simplified this to finding the best quadratic function, which is defined by the params a, b and c, so finding the best f comes down to finding the best values of a, b and c.

We will need a loss function for this task.

def mse(preds, targets): return ((preds-targets)**2).mean()
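As a quick check of what mean squared error measures, here is a plain-Python stand-in for mse applied to three made-up predictions. The lists replace tensors and the numbers are illustrative only.

```python
def mse(preds, targets):
    # Mean of the squared differences between predictions and targets.
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)

preds = [1.0, 2.0, 3.0]
targets = [1.0, 2.0, 6.0]

loss = mse(preds, targets)
print(loss)  # 3.0
```

Only the last prediction is wrong, by 3, so the squared errors are 0, 0 and 9, and their mean is 3.0.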

Implement the 7 steps

1. Initialize the weights (params) to random values. Let PyTorch know that we will need to calculate the gradients.

params = torch.randn(3).requires_grad_()
orig_params = params.clone() # save these to check later

2. Calculate the predictions using our function

Then check how close/far the predictions are from the targets.

preds = f(time, params)
def show_preds(preds, ax=None):
    if ax is None: ax=plt.subplots()[1]
    ax.scatter(time, speed)
    ax.scatter(time, to_np(preds), color='r')
    ax.set_ylim(-300,100)
show_preds(preds)

3. Calculate the loss

The goal is to improve the loss, so we will need to know the gradients.

loss = mse(preds, speed)
loss
tensor(4422.5112, grad_fn=<MeanBackward0>)

4. Calculate the gradients

Then use these to improve the parameters. We need a learning rate for this.

loss.backward()
params.grad
tensor([-20635.2617,  -1319.6385,   -108.2016])
lr = 1e-5
params.grad * lr
tensor([-0.2064, -0.0132, -0.0011])

5. Step the weights

Update the parameters based on the gradients we have calculated.

Stepping the weights is w -= gradient(w) * lr

.data is a special attribute in PyTorch that gives access to a tensor's values without gradient tracking. Here, we do not want gradients calculated for the update step itself; we only want the gradients that come from evaluating f and the loss.

params.data -= lr * params.grad.data
params.grad = None # delete the gradients we already had
# check if loss has improved
# previous was 4422.5

preds = f(time, params)
mse(preds, speed)
tensor(1354.7021, grad_fn=<MeanBackward0>)
show_preds(preds)
# the preds have indeed improved!
# repeat a few times

def apply_step(params, prn=True):
    preds = f(time, params)
    loss = mse(preds, speed)
    loss.backward()
    params.data -= lr * params.grad.data
    params.grad = None
    if prn: print(loss.item())
    return preds
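The update above writes through .data. A commonly used alternative in PyTorch is to wrap the update in torch.no_grad(), sketched below under stated assumptions: the function name apply_step_no_grad and the fixed seed are hypothetical, and the data and helper definitions are repeated only so that the block runs on its own.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability

time = torch.arange(0, 20).float()
speed = torch.randn(20) * 3 + 0.75 * (time - 9.5) ** 2 + 1

def f(t, params):
    a, b, c = params
    return a * (t ** 2) + (b * t) + c

def mse(preds, targets):
    return ((preds - targets) ** 2).mean()

lr = 1e-5
params = torch.randn(3).requires_grad_()

def apply_step_no_grad(params):
    # Hypothetical variant of apply_step; returns the loss measured before the update.
    preds = f(time, params)
    loss = mse(preds, speed)
    loss.backward()
    with torch.no_grad():       # do not track the update itself
        params -= lr * params.grad
    params.grad = None          # clear the gradients before the next step
    return loss.item()

losses = [apply_step_no_grad(params) for _ in range(10)]
print(losses[0], losses[-1])
```

Both versions perform the same arithmetic; the difference is only in how autograd is told to ignore the in-place update.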

6. Repeat the process by looping through and making improvements

# repeat 10 times

for i in range(10):
    apply_step(params)
1354.7021484375
774.1762084960938
664.3204956054688
643.5296630859375
639.5927734375
638.8450317382812
638.7008666992188
638.6709594726562
638.6625366210938
638.6583862304688

We can visualise this to see that for each step, an entirely different quadratic function is being tried

params = orig_params.detach().requires_grad_()
_,axs = plt.subplots(1,4,figsize=(12,3))

for ax in axs: 
    show_preds(apply_step(params, False), ax)
plt.tight_layout()

7. Stop

Summary

We have just seen that by comparing the outputs of our model to our targets using a loss function, we are able to minimize the loss by gradually improving our weights (params).