Backward propagation, usually called backpropagation or “backprop” for short, is the central method by which deep learning models learn from their errors. It computes the gradient of the loss function with respect to the model’s parameters. An optimizer then uses these gradients to update the parameters so that the loss decreases.
The key tool that makes backpropagation possible is the chain rule from calculus, which gives the derivative of a composite function. If a function is built from other functions, like f(g(h(x))), its derivative with respect to x is the product of the derivatives of the component functions, each evaluated at the output of the function nested inside it: f'(g(h(x))) * g'(h(x)) * h'(x).
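To make the chain rule concrete, here is a minimal sketch that compares the chain-rule product against a finite-difference estimate. The particular choices of f, g, and h (exp, sin, and squaring) are assumptions made only for illustration, as is the helper name numerical_derivative.

```python
import math

# Hypothetical component functions, chosen only for illustration.
def h(x):
    return x ** 2

def g(u):
    return math.sin(u)

def f(v):
    return math.exp(v)

# Their derivatives, written out by hand.
def dh(x):
    return 2 * x

def dg(u):
    return math.cos(u)

def df(v):
    return math.exp(v)

def composite(x):
    return f(g(h(x)))

def composite_derivative(x):
    # Chain rule: each derivative is evaluated at the output
    # of the function nested inside it.
    u = h(x)
    v = g(u)
    return df(v) * dg(u) * dh(x)

def numerical_derivative(fn, x, eps=1e-6):
    # Central finite difference, used here only as a sanity check.
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)

x0 = 0.7
chain_rule_value = composite_derivative(x0)
finite_diff_value = numerical_derivative(composite, x0)
print(chain_rule_value, finite_diff_value)
```

The two printed values should agree to several decimal places, which is a common way to check a hand-derived gradient.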
Let’s consider a simple example of a neural network with just one neuron and one input. The output of this neuron, y, is calculated by applying an activation function a to the weighted input w*x (we’re ignoring the bias for simplicity):
y = a(w*x)
Suppose we use the mean squared error as our loss function. With a single training example there is nothing to average over, so the loss L is just the squared error:
L = (y - t)^2
where t is the target output. The derivative of L with respect to w is what we need for backpropagation. We can compute this using the chain rule:
dL/dw = dL/dy * dy/dw
dL/dy = 2(y - t)
dy/dw = a'(w*x) * x
where a' is the derivative of the activation function. Multiplying the two factors gives the full gradient:
dL/dw = 2*(y - t) * a'(w*x) * x
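The derivation above can be checked numerically. The sketch below assumes a sigmoid activation for a (the text does not fix a particular activation) and uses arbitrary example values for w, x, and t. The function names are hypothetical.

```python
import math

def sigmoid(z):
    # Assumed activation function a(z).
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_prime(z):
    # Derivative of the sigmoid: a'(z) = a(z) * (1 - a(z)).
    s = sigmoid(z)
    return s * (1.0 - s)

def neuron_loss(w, x, t):
    # L = (y - t)^2 with y = a(w*x)
    y = sigmoid(w * x)
    return (y - t) ** 2

def neuron_grad(w, x, t):
    # dL/dw = dL/dy * dy/dw = 2*(y - t) * a'(w*x) * x
    z = w * x
    y = sigmoid(z)
    dL_dy = 2 * (y - t)
    dy_dw = sigmoid_prime(z) * x
    return dL_dy * dy_dw

# Arbitrary example values.
w, x, t = 0.5, 2.0, 1.0

analytic_grad = neuron_grad(w, x, t)

# Central finite difference as an independent check.
eps = 1e-6
numeric_grad = (neuron_loss(w + eps, x, t) - neuron_loss(w - eps, x, t)) / (2 * eps)

print(analytic_grad, numeric_grad)
```

With these values the output y is below the target t, so the gradient is negative: increasing w would reduce the loss.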
In practice you rarely derive these expressions by hand. In frameworks such as PyTorch, calling loss.backward() computes dL/dw for every trainable parameter w in your model and stores each gradient in w.grad. It does this through automatic differentiation: the framework records the operations that produced the loss and then applies the chain rule to them programmatically, working backward from the loss to the parameters.
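To illustrate the idea of applying the chain rule programmatically, here is a toy reverse-mode automatic differentiation sketch in plain Python. It is not how any real framework is implemented. The Value class and its fields are hypothetical, and the method names backward and grad are borrowed only for familiarity. A sigmoid activation is assumed.

```python
import math

class Value:
    """A scalar that remembers how it was computed (toy example)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self.parents = parents          # Values this one was computed from
        self.local_grads = local_grads  # d(self)/d(parent), one per parent

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other),
                     (other.data, self.data))

    def __sub__(self, other):
        return Value(self.data - other.data, (self, other), (1.0, -1.0))

    def sigmoid(self):
        s = 1.0 / (1.0 + math.exp(-self.data))
        return Value(s, (self,), (s * (1.0 - s),))

    def backward(self):
        # Build an ordering in which every node appears after its parents.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # dL/dL
        # Walk from the loss back to the inputs, applying the chain rule:
        # each parent receives (local derivative) * (gradient of the child).
        for node in reversed(order):
            for parent, local in zip(node.parents, node.local_grads):
                parent.grad += local * node.grad

# Same one-neuron example: y = a(w*x), L = (y - t)^2
w, x, t = Value(0.5), Value(2.0), Value(1.0)
y = (w * x).sigmoid()
diff = y - t
loss = diff * diff
loss.backward()

# Hand-derived gradient for comparison: 2*(y - t) * a'(w*x) * x
manual_grad = 2 * (y.data - t.data) * y.data * (1 - y.data) * x.data
print(w.grad, manual_grad)
```

Each operation only needs to know its own local derivative. The backward pass multiplies and sums those local derivatives, which is exactly the chain rule applied one step at a time.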
The gradients computed by the backward function tell us how quickly the loss changes as each parameter changes, and in which direction. A gradient points toward increasing loss, so moving a parameter the opposite way reduces the loss. When you call optimizer.step(), the optimizer uses these gradients to update the parameters. The simplest way to do this is with gradient descent, where each parameter is updated as follows:
w = w - lr * w.grad
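where lr is the learning rate, a small positive number that controls the step size. Because the step is taken against the gradient, this update moves w in the direction that reduces the loss, provided lr is small enough that the step does not overshoot.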
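As a rough sketch of the update rule in action, the loop below applies plain gradient descent to the one-neuron example. The sigmoid activation, the starting values, the learning rate of 0.1, and the number of steps are all assumptions chosen for illustration.

```python
import math

def sigmoid(z):
    # Assumed activation function.
    return 1.0 / (1.0 + math.exp(-z))

def loss_and_grad(w, x, t):
    # Forward pass: y = a(w*x), L = (y - t)^2
    y = sigmoid(w * x)
    loss = (y - t) ** 2
    # Backward pass: dL/dw = 2*(y - t) * a'(w*x) * x,
    # using a'(z) = a(z) * (1 - a(z)) for the sigmoid.
    grad = 2 * (y - t) * y * (1 - y) * x
    return loss, grad

x, t = 2.0, 1.0   # one training example (arbitrary values)
w = 0.5           # initial weight
lr = 0.1          # learning rate

losses = []
for step in range(100):
    loss, grad = loss_and_grad(w, x, t)
    losses.append(loss)
    w = w - lr * grad  # the gradient descent update

print(losses[0], losses[-1], w)
```

The recorded loss shrinks from step to step, and w grows because the gradient is negative throughout this example.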