
Part 4: Automating the Chain Rule

Unlocking composability by building primitives

In part 3, we built an MLP composed of neurons that each understood how to calculate their own gradients. For the backward / update pass, we called each neuron in turn and updated it.

This procedural approach works in principle for dense / linear layers, but it becomes a problem when we want to build other kinds of architectures that perform different calculations.

We would need to hand-derive the backward pass for that new architecture, and repeat the process if we wanted to change or adjust anything.

However, there is an abstraction we can make: what if each + or * operation (or any other operation) in the calculation knew how to propagate its own gradient?

Then we could string them together and compose any calculation we want.

We would end up with a nice separation of concerns: we have code that describes the architecture of the network, and other code that describes how to backpropagate and update it.

The modern solution is to do exactly this. Once our models start getting bigger, we will need to switch to libraries like TensorFlow to build and train them. Therefore, it pays to learn how this method works.

Let’s build some primitives.

Calculations as graphs

Let’s consider a calculation performed by a single neuron:

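const w = 2, x = 3, b = 4, y = 20 // example inputs
const relu = (v) => Math.max(0, v) // ReLU activation
const z = w * x + b // weighted sum
const a = relu(z) // activation
const L = (a - y) ** 2 // loss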

We want to know things like ‘how does changing w change L?’, but we can’t directly see the relationship, since w does not directly touch L.

We can describe this as a computational graph, with nodes (where the calculations occur) and edges (which join the calculations).

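Note: we are dropping the ReLU for now. ReLU passes its input through unchanged whenever that input is positive, which it is in our example. In the graph, z names the product w · x, and a names the weighted sum z + b.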

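[Figure: the computational graph. w and x feed the node w · x, which outputs z; z and b feed z + b, which outputs a; a and y feed a − y, which outputs e; e is squared to give L.]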

We can evaluate the expression by setting the inputs and computing the nodes up the graph in a forward (left to right) pass:

[Figure: the graph after the forward pass, with values w = 2, x = 3, z = 6, b = 4, a = 10, y = 20, e = -10, L = 100.]
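Written as plain JavaScript, the forward pass is one line per node. Here is a minimal sketch that uses the example inputs from the figure and, as in the graph, leaves out the ReLU:

```javascript
// Inputs (the leaves of the graph)
const w = 2, x = 3, b = 4, y = 20

// Each line computes one node from the nodes to its left
const z = w * x // 6
const a = z + b // 10
const e = a - y // -10
const L = e * e // 100

console.log({ z, a, e, L })
```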

To understand how w affects L, our first option is the obvious one: work forward from w and compute its contributions through the graph all the way to L.

The issue with this is that we would need to repeat this process for every input we wish to examine, so with a million inputs we will need a million passes.

Fortunately, there is a much more efficient approach: start at L and use the chain rule to work backward through the graph, one edge at a time.

As you will see, this gives us the gradient of L with respect to every input in a single backward pass.

The backward pass

1. Fix the root at 1

To start, we fix the gradient at L (the root, or end node) to be 1, since L changes 1:1 with itself.

[Figure: the same graph, with the gradient at L set to 1.]

2. Find the edges

Now, let’s look at L’s parent, e.


When we nudge e by a tiny amount, we can observe what happens to L and so work out the local derivative of L with respect to e:

let e = -10
let h = 1e-6

function getLoss(error) {
  return error * error
}

// Calculate the delta when we nudge by h
let l0 = getLoss(e) // 100
let l1 = getLoss(e + h) // 99.99998...
let delta = l1-l0 // ≈ -0.00002

// Divide by h to get the gradient,
// i.e. the delta per unit of h
let derivativeE = delta / h // -19.999999...

Remember, the idea is to imagine h approaching zero, using smaller and smaller nudges. So the derivative of L with respect to e is -20.
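We can check this against calculus: for L = e², the derivative is 2e, which is -20 at e = -10. The sketch below repeats the nudge with a shrinking h (nudgeDerivative is our own helper name). Algebraically, ((e + h)² - e²) / h works out to 2e + h, so the estimate misses by exactly h and closes in on -20 as h approaches zero.

```javascript
// One-sided "nudge" estimate of a derivative, as in the example above
function nudgeDerivative(f, at, h) {
  return (f(at + h) - f(at)) / h
}

const getLoss = (error) => error * error
const e = -10

// Analytic derivative of e² is 2e
const exact = 2 * e // -20

for (const h of [1e-1, 1e-2, 1e-3, 1e-4]) {
  const estimate = nudgeDerivative(getLoss, e, h)
  // The gap (estimate - exact) is h, up to floating-point rounding
  console.log(h, estimate, estimate - exact)
}
```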

This gives us a ‘bridge’ value (i.e. the value on the edge) of -20, which describes how e changes L — for every increase of h in e, L decreases by 20 * h:

[Figure: the edge from e to L, labelled -20.]

3. Figure out the gradient at the node

Now, to figure out the actual gradient at the e node, we need to ‘cross the bridge’ over that edge value from L to e:

// To get the gradient at e, we cross from the L node
// 'over the edge' into the e node. To do this, we 
// multiply the gradient at L (1) by the edge value:
let parentGrad = 1 // L
let gradientE = derivativeE * parentGrad // -20 * 1 = -20

So the gradient at the e node is -20.

[Figure: gradients so far: L = 1, e = -20.]

4. Rinse and repeat

Let’s look at how y and a affect e next.


First, let’s get the edge values, which are the derivatives of e with respect to y and a:

let y = 20, 
    a = 10,
    h = 1e-6

function getE(a, y) {
  return a - y
}

let e0 = getE(a, y) // -10
// Nudge a
let e1 = getE(a + h, y) // -9.999999 
// Nudge y
let e2 = getE(a, y + h) // -10.000001

let derivativeA = (e1 - e0) / h // ≈ 1
let derivativeY = (e2 - e0) / h // ≈ -1

This means that e goes down by 1 for every unit increase in y, and goes up by 1 for every unit increase in a.

[Figure: the edges from a and y into e, labelled 1 and -1.]

The gradients at y and a are given by their edge values multiplied by the gradient at e, the node they feed into (crossing the bridge).

let parentGrad = gradientE // -20

let gradientA = parentGrad * derivativeA // -20 * 1 = -20
let gradientY = parentGrad * derivativeY // -20 * -1 = 20

[Figure: gradients so far: L = 1, e = -20, a = -20, y = 20.]

Let’s move up to the next node, a = z + b, and figure out how z and b change its value. This is a sum, and with a sum, any change in an input is reflected 1:1 in the output. But let’s step through it to prove it:

let z = 6, b = 4, h = 1e-6

function getWeightedSum(dotProduct, bias) {
  return dotProduct + bias
}

let a0 = getWeightedSum(z,b) // 10
let a1 = getWeightedSum(z+h, b) // 10.000001 
let a2 = getWeightedSum(z, b+h) // 10.000001

let deltaZ = a1 - a0 // 0.000001 
let deltaB = a2 - a0 // 0.000001

// Derivative is 1 for both
let derivativeZ = deltaZ / h // ≈ 1
let derivativeB = deltaB / h // ≈ 1

This gives us our edges:

[Figure: the edges from z and b into a, both labelled 1.]

And we can cross over those edges to get the gradients at b and z:

// Multiply by parent gradient to get the gradients at the nodes:
let parentGrad = gradientA // -20
let gradientZ = derivativeZ * parentGrad // 1 * -20 = -20
let gradientB = derivativeB * parentGrad // 1 * -20 = -20

So the gradients at both nodes will be -20.

[Figure: gradients so far: L = 1, e = -20, a = -20, y = 20, z = -20, b = -20.]

At the z node, we have a multiplication.


We know from previous chapters that with C = A * B, nudging A by h changes C by B * h, and nudging B by h changes C by A * h. In other words, the derivative of C with respect to A is B, and vice versa. So let’s apply that here:

let w = 2, x = 3
let h = 1e-6

function getDotProduct(weight, input) {
  return weight * input
}

let z0 = getDotProduct(w, x) // 6
let z1 = getDotProduct(w + h, x) // 6.000003
let z2 = getDotProduct(w, x + h) // 6.000002

let deltaW = z1 - z0 // 0.000003
let deltaX = z2 - z0 // 0.000002

let derivativeW = deltaW / h // ≈ 3
let derivativeX = deltaX / h // ≈ 2

This gives us the edges that connect to the w and x nodes:

[Figure: the edges from w and x into z, labelled 3 and 2.]

And finally, we finish our back propagation with the gradients at w and x:

let parentGrad = gradientZ // -20

let gradientW = parentGrad * derivativeW // = -20 * 3 = -60
let gradientX = parentGrad * derivativeX // = -20 * 2 = -40

The completed graph now looks like this:

[Figure: the completed graph. Values: w = 2, x = 3, z = 6, b = 4, a = 10, y = 20, e = -10, L = 100. Gradients: w = -60, x = -40, z = -20, b = -20, a = -20, y = 20, e = -20, L = 1.]
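Each gradient is the product of the edge values on the path from L back to that node, which is the chain rule at work. Writing out every path in our graph (with z = w · x, a = z + b, e = a − y and L = e², as in the figure):

```latex
\begin{aligned}
\frac{\partial L}{\partial w} &= \frac{\partial L}{\partial e}\,\frac{\partial e}{\partial a}\,\frac{\partial a}{\partial z}\,\frac{\partial z}{\partial w} = (2e)(1)(1)(x) = (-20)(1)(1)(3) = -60 \\
\frac{\partial L}{\partial x} &= \frac{\partial L}{\partial e}\,\frac{\partial e}{\partial a}\,\frac{\partial a}{\partial z}\,\frac{\partial z}{\partial x} = (2e)(1)(1)(w) = (-20)(1)(1)(2) = -40 \\
\frac{\partial L}{\partial b} &= \frac{\partial L}{\partial e}\,\frac{\partial e}{\partial a}\,\frac{\partial a}{\partial b} = (2e)(1)(1) = -20 \\
\frac{\partial L}{\partial y} &= \frac{\partial L}{\partial e}\,\frac{\partial e}{\partial y} = (2e)(-1) = 20
\end{aligned}
```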

Notice that in our backward pass, we computed everything in one pass, from L all the way back to w, x, b and y. If we tried to compute it the other way (going forward through the graph), we would need to follow each variable individually to see how it affects L. With a million weights and inputs, that means a million separate passes, so back propagation offers an enormous performance advantage.
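For contrast, here is a sketch of the forward approach: nudge one input at a time and rerun the whole calculation. It needs one extra forward pass per input (four here), but it makes a handy sanity check for the gradients we just found. The forward and gradientsByNudging names are our own.

```javascript
// The whole neuron (ReLU dropped) as a single function of its inputs
function forward({ w, x, b, y }) {
  const e = w * x + b - y
  return e * e
}

// Estimate dL/d(input) for every input by nudging each one separately.
// Note the cost: one extra forward pass per input.
function gradientsByNudging(inputs, h = 1e-6) {
  const base = forward(inputs)
  const grads = {}
  for (const name of Object.keys(inputs)) {
    const nudged = { ...inputs, [name]: inputs[name] + h }
    grads[name] = (forward(nudged) - base) / h
  }
  return grads
}

const grads = gradientsByNudging({ w: 2, x: 3, b: 4, y: 20 })
console.log(grads) // approximately { w: -60, x: -40, b: -20, y: 20 }
```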

Automating the calculation

Now we have everything we need to build an engine that computes the back propagation for us.

There are actually a number of ways to approach this, as you might imagine.

The way we will approach it is to follow the mental model we have built so far: create a node primitive that holds the value at a given node, references to its parent nodes (its inputs), its own gradient, and a backward function that passes that gradient back to its parents.

// A wrapper for a value at a given node
function toNode(value, parents = [], backward = () => {}, op = '') {
  return typeof value === 'number' ?
    {
      // Holds the current value
      value,
      // The parent nodes (the inputs to this operation)
      parents,
      // The gradient of the final output with respect to this node
      grad: 0,
      // A function that passes this node's gradient
      // back to its parents
      backward,
      // Nice for introspection
      op
    } : value
}

let w = toNode(2)
let x = toNode(3)

w.value // 2
x.value // 3

Now we can encode the operations:

function multiply(a, b) { 
  // Cast the inputs to nodes if they are not already
  a = toNode(a)
  b = toNode(b)
  
  // Create a new node, passing the 
  // new value and the parents
  const node = toNode( 
    a.value * b.value, 
    [a, b]
  )
  node.op = 'multiply'
  node.backward = () => {
    // Cross each edge backward: d(a*b)/da = b and d(a*b)/db = a,
    // each multiplied by this node's gradient (the chain rule)
    a.grad += b.value * node.grad
    b.grad += a.value * node.grad
  }

  return node
}

function sum (a, b) {
  a = toNode(a)
  b = toNode(b)
  
  const node = toNode( a.value + b.value, [a, b])
  node.op = 'sum'
  node.backward = () => {
    a.grad += node.grad
    b.grad += node.grad
  }

  return node
}

This means we can perform our calculation (the forward pass), and the graph is created for us.

let w = 2, x = 3, b = 4, y = 20

let z = multiply(w, x)
let a = sum(z, b)
let e = sum(a, -y)

let loss = multiply(e, e)

console.log(loss) // {value: 100, parents: Array(2), op: 'multiply', ...}
console.log(loss.parents.map(p=> p.op)) // ['sum', 'sum']

To perform a backward pass, we can step through each node, calling backward, and then doing the same on its parents and grandparents up the tree.

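However, a node should only call backward() once its own gradient is complete, i.e. after every node that uses it has passed back its contribution. So let’s first build a ‘tape’: a flat array of all the nodes, in which every node comes after its parents. We can then ‘rewind the tape’ by looping through it in reverse, calling backward() on each node.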

/**
 Builds a topological ordering of the 
 graph (i.e. a flat array, or 'tape')
*/
function buildTape(v, visited = new Set(), tape = []) {
  if (!visited.has(v)) {
    visited.add(v)
    for (let parent of v.parents) {
      buildTape(parent, visited, tape)
    }
    tape.push(v)
  }
  return tape
}

/**
 A function that kicks off the backward pass
*/
function backward(root) {
  const tape = buildTape(root)
  // Important: always reset first and 
  // ensure all gradients start at 0!
  tape.forEach((node) => (node.grad = 0))
  // Seed the root gradient first
  root.grad = 1.0
  // Traverse the graph in reverse topological order
  for (let node of tape.reverse()) {
    node.backward()
  }
}
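To see what the tape actually contains, here is a self-contained sketch that repeats the engine functions in condensed form (passing backward and op straight to toNode) and lists the tape for our example. The label helper is our own, purely for display. Note that e appears only once even though it is both inputs to the final multiply, thanks to the visited set.

```javascript
// Condensed copies of the engine functions above
function toNode(value, parents = [], backward = () => {}, op = '') {
  return typeof value === 'number'
    ? { value, parents, grad: 0, backward, op }
    : value
}
function multiply(a, b) {
  a = toNode(a); b = toNode(b)
  const node = toNode(a.value * b.value, [a, b], () => {
    a.grad += b.value * node.grad
    b.grad += a.value * node.grad
  }, 'multiply')
  return node
}
function sum(a, b) {
  a = toNode(a); b = toNode(b)
  const node = toNode(a.value + b.value, [a, b], () => {
    a.grad += node.grad
    b.grad += node.grad
  }, 'sum')
  return node
}
function buildTape(v, visited = new Set(), tape = []) {
  if (!visited.has(v)) {
    visited.add(v)
    for (const parent of v.parents) buildTape(parent, visited, tape)
    tape.push(v)
  }
  return tape
}

// Our example: loss = (2 * 3 + 4 - 20)^2
const z = multiply(2, 3)
const a = sum(z, 4)
const e = sum(a, -20)
const loss = multiply(e, e)

// Leaves show their value, operations show their name
const label = (node) => node.op || node.value
const order = buildTape(loss).map(label)
console.log(order) // [2, 3, 'multiply', 4, 'sum', -20, 'sum', 'multiply']
```

Reading the array backward gives the order in which backward() runs: the loss first, then e, and so on out to the leaves.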

Now, let’s run our backward pass automatically.

let w = 2, x = 3, b = 4, y = 20

let z = multiply(w, x)

let a = sum(z, b)
let e = sum(a, -y)

let loss = multiply(e, e)

backward(loss)
// Gradients are all available
loss.parents.map(p => p.grad) // [-20, -20]
// w and x gradients are correct
z.parents.map(p => p.grad) // [-60, -40]

Other operations

Some operations, like subtraction or whole-number powers, can be built out of sum and multiply alone, but others, like ReLU, can’t. Either way, in practice we want to expose other primitives to make life easier for the consumer. Here are some examples:

function power(a, exponent) {
  a = toNode(a)
  exponent = typeof exponent === 'number' ? exponent : exponent.value

  const node = toNode(
    a.value ** exponent,
    [a],
  )
  node.op = `**${exponent}`
  /*
    Power rule: the derivative of xⁿ is nxⁿ⁻¹
    Examples (nudging a = 2 by h = 0.001):
    let a = 2, h = 0.001
    function power(n, exp) { return n ** exp }

    let square0 = power(a, 2) // 4
    let square1 = power(a + h, 2) // 4.004001
    (square1 - square0) / h // ≈ 4

    let cubed0 = power(a, 3) // 8
    let cubed1 = power(a + h, 3) // 8.012006...
    (cubed1 - cubed0) / h // ≈ 12

    x² → 2 × 2¹ = 4
    x³ → 3 × 2² = 3 × 4 = 12

    So the rule is: for xⁿ, nxⁿ⁻¹
  */
  node.backward = function () {
    // Local derivative times this node's gradient (the chain rule)
    a.grad += exponent * a.value ** (exponent - 1) * node.grad
  }
  return node
}

function subtract(a, b) {
  return sum(a, multiply(b, -1))
}

function relu(a) {
   a = toNode(a)
   const node = toNode(
     a.value < 0 ? 0 : a.value, 
     [a], 
   )
   node.op = 'ReLU'
   node.backward = function () {
     a.grad += (node.value > 0 ? 1.0 : 0.0) * node.grad
   }
   return node
}
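With these primitives we can put the ReLU back and rebuild the full neuron from the start of this part. Below is a minimal sketch that repeats the engine in condensed form so it runs on its own. Wrapping the inputs with toNode up front lets us read each input’s gradient directly instead of digging through parents. Because subtract multiplies y by -1 inside the graph, y also gets its own gradient of 20.

```javascript
// Condensed copy of the engine from this part
function toNode(value, parents = [], backward = () => {}, op = '') {
  return typeof value === 'number'
    ? { value, parents, grad: 0, backward, op }
    : value
}
function multiply(a, b) {
  a = toNode(a); b = toNode(b)
  const node = toNode(a.value * b.value, [a, b], () => {
    a.grad += b.value * node.grad
    b.grad += a.value * node.grad
  }, 'multiply')
  return node
}
function sum(a, b) {
  a = toNode(a); b = toNode(b)
  const node = toNode(a.value + b.value, [a, b], () => {
    a.grad += node.grad
    b.grad += node.grad
  }, 'sum')
  return node
}
function power(a, exponent) {
  a = toNode(a)
  const node = toNode(a.value ** exponent, [a], () => {
    a.grad += exponent * a.value ** (exponent - 1) * node.grad
  }, `**${exponent}`)
  return node
}
const subtract = (a, b) => sum(a, multiply(b, -1))
function relu(a) {
  a = toNode(a)
  const node = toNode(Math.max(0, a.value), [a], () => {
    a.grad += (node.value > 0 ? 1 : 0) * node.grad
  }, 'ReLU')
  return node
}
function buildTape(v, visited = new Set(), tape = []) {
  if (!visited.has(v)) {
    visited.add(v)
    for (const parent of v.parents) buildTape(parent, visited, tape)
    tape.push(v)
  }
  return tape
}
function backward(root) {
  const tape = buildTape(root)
  tape.forEach((node) => (node.grad = 0))
  root.grad = 1
  for (const node of tape.reverse()) node.backward()
}

// The neuron from the start of this part, ReLU included
const w = toNode(2), x = toNode(3), b = toNode(4), y = toNode(20)
const z = sum(multiply(w, x), b)   // weighted sum: 10
const a = relu(z)                  // activation: 10
const L = power(subtract(a, y), 2) // loss: (10 - 20)^2 = 100

backward(L)
console.log(L.value, w.grad, x.grad, b.grad, y.grad) // 100 -60 -40 -20 20
```

If you change b to -10 so that the weighted sum is negative, the ReLU outputs 0 and w, x and b all receive a gradient of 0.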

Performance

It’s important to note that this approach, although solid for learning, will run into performance issues once models grow beyond a certain size. Every scalar operation creates a new object and closure, so memory use grows with the number of operations, and training and inference on the CPU become quite slow.

There are lots of ways to mitigate this, including storing the values and gradients in typed arrays. The numbers are then packed into contiguous memory instead of being scattered across many small objects, and the backward pass becomes a simple loop over flat arrays rather than a chain of closure calls.
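As a rough illustration of the typed-array idea, here is a hypothetical sketch that supports only sum and multiply. Each node is an index into a few preallocated arrays. The names leaf, add, mul and backwardFlat are our own, and the fixed capacity is a simplifying assumption. A nice side effect is that a node is always created after its parents, so creation order is already a topological order and no separate tape-building step is needed.

```javascript
// Fixed-capacity storage; a real implementation would grow these arrays
const CAPACITY = 1024
const values = new Float64Array(CAPACITY)
const grads = new Float64Array(CAPACITY)
const ops = new Uint8Array(CAPACITY)  // 0 = leaf, 1 = sum, 2 = multiply
const left = new Int32Array(CAPACITY)  // index of the first parent
const right = new Int32Array(CAPACITY) // index of the second parent
let size = 0

// Each function appends a node and returns its index
function leaf(v) {
  values[size] = v
  ops[size] = 0
  return size++
}
function binary(op, i, j, value) {
  values[size] = value
  ops[size] = op
  left[size] = i
  right[size] = j
  return size++
}
const add = (i, j) => binary(1, i, j, values[i] + values[j])
const mul = (i, j) => binary(2, i, j, values[i] * values[j])

function backwardFlat(root) {
  grads.fill(0, 0, size)
  grads[root] = 1
  // Parents always have smaller indices than their children,
  // so walking the indices downwards is a reverse topological order
  for (let n = root; n >= 0; n--) {
    const g = grads[n]
    if (ops[n] === 1) {
      grads[left[n]] += g
      grads[right[n]] += g
    } else if (ops[n] === 2) {
      grads[left[n]] += values[right[n]] * g
      grads[right[n]] += values[left[n]] * g
    }
  }
}

// Same example as before: (w * x + b - y)^2
const w = leaf(2), x = leaf(3), b = leaf(4), negY = leaf(-20)
const e = add(add(mul(w, x), b), negY)
const loss = mul(e, e)
backwardFlat(loss)
console.log(values[loss], grads[w], grads[x], grads[b]) // 100 -60 -40 -20
```

The trade-off is flexibility: every new operation needs its own op code and its own branch in the backward loop, rather than a closure.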

In reality, once we have a large number of operations, we will switch to a library that allows them to be performed in parallel on the GPU.

For now, let’s stick with our engine and re-write our MLP to solve XOR.

XOR revisited

import { data, getXandY } from './data.js'
import { mlp } from './model.js'
import { 
  toNode,
  sum,
  multiply,
  power,
  subtract,
  backward,
  relu,
  identity
} from './engine.js'

// 4 hidden neurons using ReLU, then 1 linear output
// using identity
const model = mlp(2, [4, 1], [relu, identity])

function train(model, data, epochs = 1000, learningRate = 0.1) {
  const [Xs, Ys] = getXandY(data)

  console.log(
    'Training model with ' +
    model.parameters().length +
    ' parameters...'
  )

  for (let epoch = 0; epoch < epochs; epoch++) {
    // We will use MSE loss
    let totalLoss = toNode(0)

    Xs.forEach((x, i) => {
      const y = Ys[i]
      const prediction = model.forward(x)
      // Square the error: (prediction - y)**2
      const loss = power(subtract(prediction, y), 2)
      totalLoss = sum(totalLoss, loss)
    })

    // Mean the squared error 
    totalLoss = multiply(totalLoss, 1 / Xs.length)

    // Run the backward pass to compute the gradients
    backward(totalLoss)

    // Update all the parameters
    model.parameters().forEach(p => {
      p.value -= learningRate * p.grad
    })

    if (epoch % 100 === 0) {
      console.log(
        "Epoch " + epoch + ", loss: " + totalLoss.value
      )
    }
  }

  return model
}

train(model, data)

// Test final predictions
console.log("\nFinal predictions:")
for (let { x, y } of data) {
  const prediction = model.forward(x)
  const classification = prediction.value > 0.5 ? 1 : 0
  console.log(
    "x: [" + x + "]" +
    ", expected: " + y +
    ", predicted: " + classification +
    " (raw: " + prediction.value.toFixed(4) + ")"
  )
}

Notes on the code

Identity activation function

This model makes use of an activation function we haven’t covered yet: the identity function. It’s just a pass-through:

function identity(x) {
  return x
}

Its significance is that it allows neurons / layers to just send their output onwards as-is, which is what we want in the output layer here for simplicity.

In a real application, a binary classifier like this XOR model would typically use an output activation such as sigmoid or tanh, which squashes the output into a fixed range (0 to 1 for sigmoid, -1 to 1 for tanh).
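A sigmoid output slots into the engine like any other primitive. Here is a minimal sketch, assuming the toNode shape from this part (repeated so the block runs on its own). The forward value is 1 / (1 + e^(-x)), and its local derivative can be written in terms of the output s as s(1 - s).

```javascript
// Condensed copy of toNode from the engine
function toNode(value, parents = [], backward = () => {}, op = '') {
  return typeof value === 'number'
    ? { value, parents, grad: 0, backward, op }
    : value
}

function sigmoid(a) {
  a = toNode(a)
  const s = 1 / (1 + Math.exp(-a.value))
  const node = toNode(s, [a], () => {
    // d(sigmoid)/da = s * (1 - s), times this node's gradient
    a.grad += s * (1 - s) * node.grad
  }, 'sigmoid')
  return node
}

const input = toNode(0)
const out = sigmoid(input)
out.grad = 1 // seed the gradient by hand for this single-node example
out.backward()
console.log(out.value, input.grad) // 0.5 0.25
```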

We’d also use a different loss function, binary cross-entropy loss. However, for this simple example the focus is on the autograd implementation.
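For reference, binary cross-entropy for one example with label y (0 or 1) and predicted probability p is -(y·log(p) + (1 - y)·log(1 - p)). Building it from our primitives would need a log operation, whose local derivative is 1/p. Below is a hypothetical sketch, with log and binaryCrossEntropy as our own names and the engine repeated in condensed form.

```javascript
// Condensed copies of the engine functions from this part
function toNode(value, parents = [], backward = () => {}, op = '') {
  return typeof value === 'number'
    ? { value, parents, grad: 0, backward, op }
    : value
}
function sum(a, b) {
  a = toNode(a); b = toNode(b)
  const node = toNode(a.value + b.value, [a, b], () => {
    a.grad += node.grad
    b.grad += node.grad
  }, 'sum')
  return node
}
function multiply(a, b) {
  a = toNode(a); b = toNode(b)
  const node = toNode(a.value * b.value, [a, b], () => {
    a.grad += b.value * node.grad
    b.grad += a.value * node.grad
  }, 'multiply')
  return node
}
function buildTape(v, visited = new Set(), tape = []) {
  if (!visited.has(v)) {
    visited.add(v)
    for (const parent of v.parents) buildTape(parent, visited, tape)
    tape.push(v)
  }
  return tape
}
function backward(root) {
  const tape = buildTape(root)
  tape.forEach((node) => (node.grad = 0))
  root.grad = 1
  for (const node of tape.reverse()) node.backward()
}

// Hypothetical log primitive: d(log p)/dp = 1 / p
function log(a) {
  a = toNode(a)
  const node = toNode(Math.log(a.value), [a], () => {
    a.grad += (1 / a.value) * node.grad
  }, 'log')
  return node
}

// -(y * log(p) + (1 - y) * log(1 - p)), where y is a plain 0 or 1
function binaryCrossEntropy(p, y) {
  const oneMinusP = sum(1, multiply(p, -1))
  const total = sum(multiply(y, log(p)), multiply(1 - y, log(oneMinusP)))
  return multiply(total, -1)
}

const p = toNode(0.8)
const loss = binaryCrossEntropy(p, 1)
backward(loss)
console.log(loss.value, p.grad) // roughly 0.2231 and -1.25
```

In practice, libraries also guard against log(0), for example by clipping p or by computing the loss directly from the value before the sigmoid.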

Using += in the engine

You’ll notice that the grad calculations in the engine use += rather than =:

// ...
node.backward = () => {
  a.grad += b.value * node.grad
  b.grad += a.value * node.grad
}
// ...

This is really important, because the gradients need to accumulate. Consider let loss = multiply(e, e): both inputs are the same node, so multiply’s backward function updates e.grad twice. If we used =, the second update would overwrite the first instead of adding to it.
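To make the effect concrete, this sketch compares our multiply with a hypothetical multiplyOverwrite that swaps += for =. With e = -10, the true gradient of e * e is 2e = -20, but the overwriting version reports -10. The same accumulation is what lets each XOR parameter collect gradient from all four training examples in a single backward pass.

```javascript
// Minimal node: just enough for this comparison
const makeNode = (value) => ({ value, grad: 0 })

// Engine version: gradients accumulate with +=
function multiply(a, b) {
  const node = { value: a.value * b.value, grad: 0 }
  node.backward = () => {
    a.grad += b.value * node.grad
    b.grad += a.value * node.grad
  }
  return node
}

// Hypothetical broken version: gradients overwrite with =
function multiplyOverwrite(a, b) {
  const node = { value: a.value * b.value, grad: 0 }
  node.backward = () => {
    a.grad = b.value * node.grad
    b.grad = a.value * node.grad
  }
  return node
}

// Square e = -10 with the given operation and return e's gradient
function gradOfSquare(op) {
  const e = makeNode(-10)
  const loss = op(e, e) // e is both inputs
  loss.grad = 1 // seed the root
  loss.backward()
  return e.grad
}

console.log(gradOfSquare(multiply)) // -20 (correct: 2e)
console.log(gradOfSquare(multiplyOverwrite)) // -10 (half the true gradient)
```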

The tape is built depth-first

Possibly an obvious point, but worth mentioning: when we traverse a graph, we can do so breadth-first (visit all of a node’s immediate parents, then all of their parents, one layer at a time) or depth-first (follow each branch as far back as it goes before backtracking to the next one).

The buildTape function uses depth-first traversal: it follows each branch all the way back to the leaf nodes before moving on. Crucially, a node is only pushed onto the tape after the calls for all of its parents have returned, so every parent lands on the tape before its children, giving us a valid topological order.
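The ordering matters as soon as a node feeds into more than one operation. In the sketch below, s = x + 1 is used twice, so out = 2s + 3s and the true gradient with respect to x is 5. A naive recursive walk (naiveBackward, our own name for "call backward, then recurse into the parents") passes on s’s gradient before it has finished accumulating and ends up with 7. Rewinding the topologically ordered tape gives 5.

```javascript
// Condensed copies of the engine functions from this part
function toNode(value, parents = [], backward = () => {}, op = '') {
  return typeof value === 'number'
    ? { value, parents, grad: 0, backward, op }
    : value
}
function sum(a, b) {
  a = toNode(a); b = toNode(b)
  const node = toNode(a.value + b.value, [a, b], () => {
    a.grad += node.grad
    b.grad += node.grad
  }, 'sum')
  return node
}
function multiply(a, b) {
  a = toNode(a); b = toNode(b)
  const node = toNode(a.value * b.value, [a, b], () => {
    a.grad += b.value * node.grad
    b.grad += a.value * node.grad
  }, 'multiply')
  return node
}
function buildTape(v, visited = new Set(), tape = []) {
  if (!visited.has(v)) {
    visited.add(v)
    for (const parent of v.parents) buildTape(parent, visited, tape)
    tape.push(v)
  }
  return tape
}
function backward(root) {
  const tape = buildTape(root)
  tape.forEach((node) => (node.grad = 0))
  root.grad = 1
  for (const node of tape.reverse()) node.backward()
}

// Hypothetical naive alternative: call backward on a node, then
// immediately recurse into its parents, with no ordering
function naiveBackward(root) {
  root.grad = 1
  const visit = (node) => {
    node.backward()
    for (const parent of node.parents) visit(parent)
  }
  visit(root)
}

// s = x + 1 is used twice: out = s * 2 + s * 3, so d(out)/dx = 5
function makeGraph() {
  const x = toNode(2)
  const s = sum(x, 1)
  const out = sum(multiply(s, 2), multiply(s, 3))
  return { x, out }
}

const naive = makeGraph()
naiveBackward(naive.out)

const taped = makeGraph()
backward(taped.out)

console.log(naive.x.grad, taped.x.grad) // 7 5
```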

Putting it all together

Here is our computational graph, animated through both the forward and backward passes. You can see how the values propagate during the forward pass, and how the edges and gradients are evaluated during the backward pass.

[Figure: animation of the forward and backward passes, starting from the inputs w = 2, x = 3, b = 4, y = 20.]

Next up

We’ve built out an MLP and then refactored it so that we can start to play with different layer architectures. Next we will look at an example of a problem that is extremely unwieldy for the MLP, and for which a completely different algorithm needed to be devised.