Makemore Subreddits - Part 3 Activations and Gradients

nlp
makemore
Published

January 7, 2024

Let’s dive deep into the Activations and Gradients in a Multi-Layer Perceptron language model for subreddit names.

This loosely follows part 3 of Andrej Karpathy’s excellent makemore; go and check that out first. However he used a list of US names, whereas we’re going to use subreddit names. See Makemore Subreddits - Part 2 MLP for the original Multi-Layer Perceptron following Bengio et al. 2003.

Note

This is a Jupyter notebook; you can download the notebook or view it on Kaggle.

Loading the Data

This is largely similar to Part 1 where we get the most common subreddit names from All Subreddits and Relations Between Them.

import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import random
import csv
from pathlib import Path
from collections import Counter
from tqdm.auto import trange, tqdm

Filter to subreddits that:

  • Have at least 1000 subscribers
  • Are not archived
  • Are safe for work
  • And are not quarantined

Note that you need to have downloaded subreddits.csv (and uncompressed it if appropriate).

data_path = Path('./data')
min_subscribers = 1_000

with open(data_path / 'subreddits.csv', 'r') as f:
    names = [d['name'] for d in csv.DictReader(f) 
             if int(d['subscribers'] or 0) >= min_subscribers
             and d['description']
             and d['type'] != 'archived'
             and d['nsfw'] == 'f'
             and d['quarantined'] == 'f']
    
len(names)
random.seed(42)
random.shuffle(names)

N = len(names)

names_train = names[:int(0.8*N)]
names_val = names[int(0.8*N):int(0.9*N)]
names_test = names[int(0.9*N):]


for name in names_train[:10]:
    print(name)

len(names_train), len(names_val), len(names_test)
splunk
thenwa
soylent
factorio
christinaricci
blues
vegancheesemaking
goldredditsays
reformed
nagoya
(26876, 3359, 3360)

Compile the Data

Now convert the dataset into something that the model can easily work with. First represent all the character tokens as consecutive integers. We create a special PAD_CHAR with index 0 to represent tokens outside of the sequence.

PAD_CHAR = '.'
PAD_IDX = 0

i2s = sorted(set(''.join(names_train)))
assert PAD_CHAR not in i2s
i2s.insert(PAD_IDX, PAD_CHAR)

s2i = {s:i for i, s in enumerate(i2s)}

V = len(i2s)

def compile_dataset(names, block_size, PAD_CHAR=PAD_CHAR, s2i=s2i):
    X, y = [], []
    for name in names:
        padded_name = PAD_CHAR * block_size + name + PAD_CHAR
        padded_tokens = [s2i[c] for c in padded_name]
        for *context, target in zip(*[padded_tokens[i:] for i in range(block_size+1)]):
            X.append(context)
            y.append(target)
    return torch.tensor(X), torch.tensor(y)
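
The zip(*[...]) expression above slides a window of block_size + 1 tokens along each padded name, producing (context, target) pairs. As a quick illustration, here’s a tiny sketch with made-up token ids (a hypothetical mapping, not the real vocabulary):

toks = [0, 0, 0, 3, 1, 20, 0]  # '...cat.' under a hypothetical mapping, block_size = 3
list(zip(*[toks[i:] for i in range(3 + 1)]))
[(0, 0, 0, 3), (0, 0, 3, 1), (0, 3, 1, 20), (3, 1, 20, 0)]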


block_size = 3

X, y = compile_dataset(names_train, block_size)
X_val, y_val = compile_dataset(names_val, block_size)
X_test, y_test = compile_dataset(names_test, block_size)

X.shape, y.shape
(torch.Size([330143, 3]), torch.Size([330143]))

Review: Multi-Layer Perceptron

We will start with the Multi-Layer Perceptron implementation from part 2.

default_m = 30
default_h = 200

class MLP:
    def __init__(self, m=default_m, h=default_h, V=V, block_size=block_size):
        self.m = m
        self.h = h
        self.V = V
        self.block_size = block_size
        
        # Word embedding layer
        self.C = torch.randn(V, m)

        # First hidden layer
        self.H = torch.randn(block_size * m, h)
        self.d = torch.randn(h)

        # Second hidden layer
        self.U = torch.randn(h, V)
        self.b = torch.randn(V)
        
    def parameters(self):
        return [self.C, self.H, self.d, self.U, self.b]
    
    def requires_grad_(self, requires_grad=True):
        for p in self.parameters():
            p.requires_grad_(requires_grad)
        return self
    
    def zero_grad(self):
        for p in self.parameters():
            p.grad = None
        return self
    
    def forward(self, X):
        self.embeddings = self.C[X]
        self.hidden_layer = self.embeddings.view(X.shape[0], self.block_size * self.m) @ self.H + self.d
        self.hidden_activations =  torch.tanh(self.hidden_layer)
        self.output_logits = self.hidden_activations @ self.U + self.b
        return self.output_logits
    
    def __call__(self, X):
        return self.forward(X)

So we can start from a randomly initialised MLP:

mlp = MLP().requires_grad_()

with torch.no_grad():
    preds_val = mlp(X_val)
    val_loss = F.cross_entropy(preds_val, y_val).item()
val_loss
27.776723861694336

And code to sample from it (which gives random output):

def sample(mlp, pad_idx=PAD_IDX, block_size=block_size, i2s=i2s, generator=None):
    ans = []
    state = torch.tensor([[pad_idx] * block_size])
    while True:
        probs = mlp(state).softmax(axis=1)
        next_idx = torch.multinomial(probs, 1, generator=generator)
        state = torch.concat([state, next_idx], axis=1)[:,1:]
        
        next_idx = next_idx[0,0].item()
        if next_idx == pad_idx:
            return ''.join(ans)
        ans.append(i2s[next_idx])
        
sample(mlp)
'qaxtgkseh9ekyf9uokp29euoy4a7igiareg'
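
Passing in a seeded generator makes the sampling reproducible, for example (a usage sketch; the seed is arbitrary):

g = torch.Generator().manual_seed(1234)
sample(mlp, generator=g)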

And code to train it:

batch_size = 32
val_step = 100

def train(model, n_step, lr, batch_size=batch_size, val_step=val_step, X=X, y=y, X_val=X_val, y_val=y_val, callback=None):
    losses, val_losses = [], []
    
    for step in trange(n_step):
        model.training = True  # NEW: support models that are different at train and inference time
        idx = torch.randint(0, len(X), (batch_size,))
    
        model.zero_grad()
        logits = model(X[idx])
        loss =  F.cross_entropy(input=logits, target=y[idx])
        losses.append((step, loss.item()))
        loss.backward()

        # NEW: Support injecting a callback to do some mutation
        if callback is not None:
            callback(**locals())
        
        for p in model.parameters():
            p.data -= p.grad * lr(step, n_step)
        
        if step % val_step == 0:
            model.training = False # NEW: support models that are different at train and inference time
            with torch.no_grad():
                preds_val = model(X_val)
                val_loss = F.cross_entropy(preds_val, y_val).item()
                val_losses.append((step, val_loss))

    model.training = False # NEW
    return losses, val_losses
losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1)

The loss decreases quickly from a very high value, and then slowly descends.

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

baseline_loss = val_loss_value[-1]
baseline_loss
2.6191823482513428

We get samples that are much less random, but still don’t really seem like subreddit names.

for _ in range(20):
    print(sample(mlp))
lolleckingurs
r64onfrimen
caryimalcuraogardcartfgatulgeterheisting
clan
voumg
hvur
stagb
monconts
vart
scaromark
uprycherce
wing
lok
crteyrar
ling
jerkeymentfumearnight
kurcudver
learovedx
osarenetroor
ct2ivzenmirdur34therrertifar

Fixing Initialisation

The loss starts off very high giving the “hockey-stick” shaped loss curve. We can fix this by starting the weights in a better place.

The initial loss is very high:

mlp = MLP()

with torch.no_grad():
    preds_val = mlp(X_val)
    val_loss = F.cross_entropy(preds_val, y_val).item()
val_loss
29.860868453979492

A good baseline model is the constant model that assigns each token in the vocabulary an equal probability.

constant_model = 1/V * torch.ones(size=(len(X_val), V))

Such a model assigns a probability of 1/V to the correct token, and so the loss will be -log(1/V)

-torch.log(torch.tensor(1/V)).item()
3.6375861167907715

This is what we get to a very good approximation, and it’s much better than the random weights. (Note that F.cross_entropy expects logits, but constant logits also give a uniform distribution after the softmax, so the value is the same.)

F.cross_entropy(constant_model, y_val).item()
3.637585163116455

Walking through the model

We want to initialise the weights so that the model predicts close to a random distribution of outputs.

Let’s step through the layers of our current model for a batch of training data to understand what is currently happening:

self.embeddings = self.C[X]
self.hidden_layer = self.embeddings.view(X.shape[0], self.block_size * self.m) @ self.H + self.d
self.hidden_activations =  torch.tanh(self.hidden_layer)
self.output_logits = self.hidden_activations @ self.U + self.b
with torch.no_grad():
    batch = mlp(X[:32])

The embeddings have a roughly standard normal distribution, slightly distorted by the token frequencies.

self.C = torch.randn(V, m)
...
self.embeddings = self.C[X]
plt.hist(mlp.embeddings.view(-1), bins=50);
mlp.embeddings.shape, mlp.embeddings.std()
(torch.Size([32, 3, 30]), tensor(0.9805))

The first hidden layer performs a linear transformation.

# First hidden layer
self.H = torch.randn(block_size * m, h)
self.d = torch.randn(h)
...
self.hidden_layer = self.embeddings.view(X.shape[0], self.block_size * self.m) @ self.H + self.d

In index notation the hidden layer \(h\) looks like:

\[h_{i,k} = \sum_{j=1}^{m \times {\rm block\_size}} e_{i,j} H_{j,k} + d_{k}\]

And assuming the variables are all independent, and that the embeddings e and matrix H each consist of elements of zero mean and equal standard deviation then \(\mathbb{E}(h) = 0\) and

\[ \mathbb{V}(h_{i,k}) = (m \times {\rm block\_size}) \mathbb{V}(e) \mathbb{V}(H) + \mathbb{V}(d)\]

So in particular here we’ve set all the element variances to 1, and so the output variance should be:

mlp.block_size * mlp.m + 1
91

It’s pretty close to this (with some random error)

plt.hist(mlp.hidden_layer.view(-1), bins=50);
mlp.hidden_layer.shape, mlp.hidden_layer.var()
(torch.Size([32, 200]), tensor(83.0976))

We then apply a tanh transformation, which squishes values far from 0 towards ±1.

self.hidden_activations =  torch.tanh(self.hidden_layer)
x = torch.arange(-3, 3, step=0.01)
plt.plot(x, torch.tanh(x));

Consequently we get all our values squished around -1 and 1

plt.hist(mlp.hidden_activations.view(-1), bins=50);
mlp.hidden_activations.shape, mlp.hidden_activations.var()
(torch.Size([32, 200]), tensor(0.9154))

Many of the activations are above 0.99 in absolute value, where the gradient \(\tanh'(x) = 1-\tanh^2(x)\) is below 0.02, so almost no gradient flows back through these units.

plt.matshow(mlp.hidden_activations.abs() > 0.99, cmap='gray', interpolation='nearest');
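
We can confirm the effect on the gradients directly with autograd (a small sketch with arbitrarily chosen inputs):

x = torch.tensor([0.0, 2.0, 5.0], requires_grad=True)
torch.tanh(x).sum().backward()
x.grad  # roughly [1.0, 0.07, 0.0002]: saturated units pass back almost no gradient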

The output logits are then roughly sums of the columns of U with random ±1 signs, plus b, and so they are approximately normal.

self.output_logits = self.hidden_activations @ self.U + self.b

The variance is huge, approximately \[{\rm hidden\_size} \times \mathbb{V}(U) + \mathbb{V}(b) = {\rm hidden\_size} + 1\]

mlp.h + 1
201
plt.hist(mlp.output_logits.view(-1), bins=50);
mlp.output_logits.shape, mlp.output_logits.var()
(torch.Size([32, 38]), tensor(190.1738))

This means the logits fluctuate wildly and the predictions are very extreme with a very high probability prediction:

plt.matshow(mlp.output_logits.softmax(axis=1), cmap='gray');

We can fix this simply by scaling down the initial weights; this is much more important in deep networks, where these effects compound, often leading to exploding or vanishing gradients.

This is explained clearly in Xavier Glorot and Yoshua Bengio’s Understanding the difficulty of training deep feedforward neural networks where they derive the variance ignoring non-linearities for the backward and forward pass and suggest initialising with

\[ w \sim U\left(-\frac{\sqrt{6}}{\sqrt{n_{\rm in} + n_{\rm out}}}, \frac{\sqrt{6}}{\sqrt{n_{\rm in} + n_{\rm out}}}\right)\]

Alternatively their analysis suggests you could also use

\[ w \sim \mathcal{N}\left(\mu=0, \sigma = \frac{\sqrt{2}}{\sqrt{n_{\rm in} + n_{\rm out}}}\right)\]

or approximately \(w \sim \mathcal{N}\left(\mu=0, \sigma = 1/\sqrt{n_{\rm in}}\right)\); this latter form is called Xavier or Glorot initialisation. They show these allow training CNNs up to 9 layers deep, which was difficult without this.

In Delving Deep into Rectifiers, He, Zhang, Ren, and Sun take into account the ReLU non-linearity and show you need to introduce a gain of \(\sqrt{2}\) (He initialisation or Kaiming initialisation, after the first author), which allows them to go from 22 layers to 30.
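
PyTorch exposes these schemes in nn.init; a quick sketch of how they could be applied to a weight matrix shaped like ours (the shapes here are just illustrative):

w = torch.empty(block_size * default_m, default_h)
nn.init.xavier_uniform_(w)                       # Glorot / Xavier uniform
nn.init.kaiming_normal_(w, nonlinearity='relu')  # He / Kaiming normal for ReLU
nn.init.calculate_gain('tanh')                   # 5/3, discussed below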

For other non-linearities, Siddharth Krishna Kumar derives the variance for a differentiable activation function \(g\) in On weight initialization in deep neural networks and uses a local expansion to derive an initialisation of

\[ w \sim \mathcal{N}\left(\mu=0, \sigma = \frac{1}{\left|g'(0)\right|\sqrt{n (1 + g(0))^2}}\right) \]

For tanh this suggests a gain of 1, but as Andrej Karpathy argues in this lecture since tanh is contractive the gain must be more than 1, but this would require a higher-order approximation.

PyTorch has a gain of 5/3, but no one remembers why. It’s likely an estimate of the variance of \(\tanh\) under a standard normally distributed input:

\[ \begin{align} \mathbb{V}\left[\tanh\right] &= \mathbb{E}\left[\left(\tanh - \mathbb{E}(\tanh)\right)^2\right] \\ &= \mathbb{E}\left[\tanh^2\right] \\ &= \int_{-\infty}^{\infty} \left(\frac{1}{\sqrt{2 \pi}} e^{-x^2/2} \right) \tanh^2(x)\,{\rm d}x \\ & \approx 0.394 \end{align}\]

To normalise the gain we need to divide by the standard deviation (the square root of the variance) \({\rm gain} \approx 1/\sqrt{0.394} \approx 1.59 \approx 5/3\) (where the last term introduces an error of around 5%).
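
We can sanity-check that number with a quick Monte Carlo estimate (a sketch; the sample size is arbitrary):

x = torch.randn(1_000_000)
v = torch.tanh(x).var()
v, 1 / v.sqrt()   # roughly (0.394, 1.59), close to the 5/3 gain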

In general even if it is difficult to calculate the variance (for example because we’re unsure of the input distribution), it can be empirically derived, as in All you Need is a Good Init by Mishkin and Matas with their iterative Layer-sequential Unit Variance (LSUV) method.

In any case it won’t matter much for our 1-hidden-layer MLP, and other methods such as batch/layer normalisation, skip connections, and better optimisers have made careful initialisation less important (though with a careful enough initialisation you can still train a transformer without the normalisation and skip connections).

For simplicity we’ll use Kaiming initialisation, setting the initialisation of the biases to a very low number and further scaling down the output logits to get a better initial loss.

def fix_init(mlp, bias_variance=1e-4, output_variance=1e-1):
    with torch.no_grad():
        mlp.H *= 1/(mlp.block_size*mlp.m)**0.5
        mlp.d *= bias_variance ** 0.5

        mlp.U *= output_variance**0.5/(mlp.h**0.5)
        mlp.b *= bias_variance ** 0.5
    return mlp
        
mlp = fix_init(MLP()).requires_grad_()

We can see our initial loss is much lower, and close to the uniform baseline.

with torch.no_grad():
    preds_val = mlp(X_val)
    val_loss = F.cross_entropy(preds_val, y_val).item()
    
val_loss
3.64780330657959
F.cross_entropy(torch.ones_like(preds_val), y_val).item()
3.637585163116455

As before we can step through the layers

with torch.no_grad():
    preds = mlp(X[:32])

The hidden pre-activations are now approximately standard normal

plt.hist(mlp.hidden_layer.view(-1), bins=50);
mlp.hidden_layer.shape, mlp.hidden_layer.var()
(torch.Size([32, 200]), tensor(1.0259))

The activations are slightly saturated, but uniform-ish.

plt.hist(mlp.hidden_activations.view(-1), bins=50);
mlp.hidden_activations.shape, mlp.hidden_activations.var()
(torch.Size([32, 200]), tensor(0.3962))

The output logits are now approximately normal with the small variance we set.

plt.hist(mlp.output_logits.view(-1), bins=50);
mlp.output_logits.shape, mlp.output_logits.std()
(torch.Size([32, 38]), tensor(0.1852))

The predicted probabilities are much more uniform.

plt.matshow(mlp.output_logits.softmax(axis=1), cmap='gray')

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1)

The loss doesn’t have as sharp an initial drop-off and reaches a lower value (before it was ~2.6)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

fix_init_loss = val_loss_value[-1]

fix_init_loss, f'{fix_init_loss / baseline_loss:0.2%} of baseline loss'
(2.4480857849121094, '93.47% of baseline loss')

We can run samples but they’re not qualitatively better than before:

for _ in range(20):
    print(sample(mlp))
cata
sokemolos
clsybeafwensterthewlukening
therez
breadrycoonn
kritefactreamancoonscotterry
balarcare
coreporn
clang
neleng
cretages
reddits
comporn
curans
ristoryardsteton
nongoreddit
terbateas
millippertoockpharnothenesek
acepapertrole
deundumpstbroknv

Batch Norm

Another way to control the distribution of the pre-activations is to rescale them into that distribution directly. The challenge is that we need to estimate the distribution of the pre-activations somehow. Batch Norm does this by calculating the statistics across a batch, which works both to normalise and to regularise, but coupling examples across a batch makes the process more complicated and error-prone.
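
As a minimal sketch of the normalisation step itself (toy numbers, separate from the class below), we standardise each feature using its batch mean and variance:

xb = torch.randn(32, 4) * 3 + 5                       # toy batch with non-standard statistics
xb_norm = (xb - xb.mean(dim=0)) / (xb.var(dim=0) + 1e-8) ** 0.5
xb_norm.mean(dim=0), xb_norm.var(dim=0)               # roughly 0 and 1 per feature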

To implement it we need to add:

  • Learnable shift and scale parameters
  • Fixed eps
  • In the forward
    • normalise based on batch statistics
    • rescale with shift and scale parameters

Let’s also track how the statistics change over time

class MLPBatchNorm(MLP):
    def __init__(self, m=default_m, h=default_h, V=V, block_size=block_size, bn_eps=1e-8):
        super().__init__(m=m, h=h, V=V, block_size=block_size)
        # New stuff
        self.bn_scale = torch.ones((1,h))
        self.bn_shift = torch.zeros((1,h))
        self.bn_eps = bn_eps

        # Track statistics for debugging
        self.bn_means = []
        self.bn_vars = []
        
    def parameters(self):
        return super().parameters() + [self.bn_scale, self.bn_shift]
    
    def forward(self, X, bn_mean=None, bn_var=None):
        self.embeddings = self.C[X]
        self.hidden_layer = self.embeddings.view(X.shape[0], self.block_size * self.m) @ self.H + self.d
        # New stuff; allow passing in a batch norm mean and variance for debugging
        μ = self.hidden_layer.mean(dim=0, keepdim=True) if bn_mean is None else bn_mean
        σ2 = self.hidden_layer.var(dim=0, keepdim=True) if bn_var is None else bn_var
        self.hidden_bn = self.bn_scale * (self.hidden_layer - μ) / (σ2 + self.bn_eps) ** 0.5 + self.bn_shift
        self.hidden_activations =  torch.tanh(self.hidden_bn)

        # Track statistics only in training (not validation) for debugging
        if self.training:
            self.bn_means.append(μ.detach())
            self.bn_vars.append(σ2.detach())
        # End new stuff
        self.output_logits = self.hidden_activations @ self.U + self.b
        return self.output_logits
    
mlp = fix_init(MLPBatchNorm()).requires_grad_()
losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1)

Training this we get a similar loss value as before:

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

bn_loss = val_loss_value[-1]

bn_loss, f'{bn_loss / fix_init_loss:0.2%} of fix_init loss'
(2.4308383464813232, '99.30% of fix_init loss')

But there’s a problem: we can no longer make predictions on a single example, because we normalise it away when calculating the batch statistics:

mlp(X_val[:1])
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       grad_fn=<AddBackward0>)
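
The NaNs come from the batch variance: with a single example the (unbiased) batch variance is undefined. A quick check of this (sketch):

torch.randn(1, 5).var(dim=0)
tensor([nan, nan, nan, nan, nan])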

One way to handle this is to use the batch norm mean and variance computed over the whole training set

with torch.no_grad():
    preds = mlp(X)
    preact = mlp.hidden_layer
    bn_mean = preact.mean(dim=0)
    bn_var = preact.var(dim=0)

For each one of the 200 hidden dimensions we have a mean and a variance

plt.bar(range(len(bn_mean)), bn_mean);

plt.bar(range(len(bn_mean)), bn_var);

We can use these population statistics as the batch norm mean and variance. This also means that the predictions will be independent of the other items in the batch.

mlp.forward(X_val[:1], bn_mean, bn_var)
tensor([[-2.2420, -4.6427, -2.1474, -1.2763, -2.2146, -2.6077, -3.2863, -3.3461,
         -3.4700, -3.0928, -2.5705, -2.5806,  2.0783,  2.1519,  3.0222,  1.6728,
          1.3640,  2.1493,  1.9894,  1.2325,  1.4388,  0.6066,  0.8470,  1.6142,
          2.5117,  1.8061,  0.8732,  1.9734, -1.2449,  1.3586,  2.1918,  2.8410,
          1.2704,  0.9921,  1.4396, -0.9299, -0.2703, -0.8756]],
       grad_fn=<AddBackward0>)

We get essentially the same loss here as well (though if there were distribution shift between the training and validation sets, using the validation statistics could give a substantially different result, and they may not be available in an online setting).

with torch.no_grad():
    preds_val = mlp.forward(X_val, bn_mean, bn_var)
    val_loss = F.cross_entropy(preds_val, y_val).item()

val_loss, f'{val_loss / bn_loss:0.2%} of bn_loss loss'
(2.4395534992218018, '100.36% of bn_loss loss')

Running statistics

Calculating the training statistics afterwards is an additional step; could we estimate them on the fly?

Let’s look at the batch norm means

train_bn_means = torch.concat(mlp.bn_means)
train_bn_vars = torch.concat(mlp.bn_vars)
train_bn_means.shape
torch.Size([50000, 200])

Let’s look at a single dimension of the hidden layer. Each point is the mean value over a single batch; it changes over time because of:

  • random variations between batches
  • changes in the parameters of the embeddings and hidden layer

We can see here that both the value and its variance change over time (the final training set mean, m_inf, is in red)

idx = 0
m = train_bn_means[:,0]
m_inf = bn_mean[0]
plt.plot(m)
plt.hlines(m_inf, 0, len(m), color='r');

This means that if we took the simple average of the points we would severely mis-estimate the final value, because it’s moving.

m.mean(), m_inf
(tensor(-0.5176), tensor(-0.9794))

Similarly if we just took the last value we may mis-estimate it because of the batch-to-batch variance

m[-1], m_inf
(tensor(-0.6533), tensor(-0.9794))

One option between the two extremes is to use an Exponential Moving Average to track the value over time:

momentum = 0.1
m_ema = []
ema = m[0]
for m_i in m:
    ema = ema * (1 - momentum) + m_i * momentum
    m_ema.append(ema)

plt.plot(m)
plt.plot(m_ema)
plt.hlines(m_inf, 0, len(m), color='r')

ema.item(), m_inf.item()
(-0.978546679019928, -0.9794382452964783)

The one hyperparameter we have to tune is the momentum; how much of the previous value do we keep in each step.

  • A momentum that is too high means we will get too much variance (momentum = 1 gives the last value)
  • A momentum that is too low will not respond quickly enough to the changes of parameters (too much bias)
momentum = 0.00001
m_ema = []
ema = m[0]
for m_i in m:
    ema = ema * (1 - momentum) + m_i * momentum
    m_ema.append(ema)

plt.plot(m)
plt.plot(m_ema)
plt.hlines(m_inf, 0, len(m), color='r')

m_ema[-1].item(), m_inf.item()
(-0.43218255043029785, -0.9794382452964783)

The way this works is we get an exponentially decaying weight on old values; with a momentum \(\alpha\), the exponential moving average estimates \(y_i\) are given recursively as the weighted interpolation of the last estimate and the next value, \(y_i = (1-\alpha)\, y_{i-1} + \alpha x_i\), and so:

\[\begin{eqnarray} y_0 &=& x_0 \\ y_1 &=& \alpha x_1 + (1 - \alpha) y_0 \\ &=& \alpha x_1 + (1 - \alpha) x_0 \\ y_2 &=& \alpha x_2 + (1 - \alpha) y_1 \\ &=& \alpha x_2 + (1 - \alpha) \alpha x_1 + (1- \alpha) ^2 x_0 \\ &\vdots& \\ y_n &=& \alpha x_n + (1 - \alpha) y_{n-1} \\ &=& \alpha x_n + (1 - \alpha) \alpha x_{n-1} + (1 - \alpha)^2 \alpha x_{n-2} + \cdots + (1-\alpha)^{n-1} \alpha x_1 + (1- \alpha)^n x_0 \end{eqnarray}\]

The weights sum to 1 by the geometric series:

\[ 1 + (1 - \alpha) + (1-\alpha)^2 + \ldots + (1-\alpha)^{n-1} = \frac{1 - (1 - \alpha)^n}{\alpha}\]
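
We can verify this sum numerically for some arbitrary values (a quick sketch):

alpha, n = 0.1, 50
sum((1 - alpha) ** k for k in range(n)), (1 - (1 - alpha) ** n) / alpha
# both roughly 9.95 ≈ 1/alpha, so multiplied by alpha the weights sum to about 1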

For large enough \(n\) (relative to \(1/\alpha\)) we can ignore the last term and the weights approximately sum to 1. With the tiny momentum from the previous cell this isn’t the case: the weights only sum to about 0.39, which is another way of seeing why that estimate was so biased towards the initial value.

sum((1-momentum) ** torch.arange(len(m) - 1, -1, -1)) * momentum
tensor(0.3933)

We can then calculate the exponential moving average quickly using the formula:

\[ y_n = \alpha \left((1-\alpha)^0 x_n + (1-\alpha)^1 x_{n-1} + \cdots + (1-\alpha)^n x_0 \right) + (1-\alpha)^{n+1} x_0\]

def fastema(z, momentum):
    weights = (1-momentum) ** torch.arange(z.shape[-1] - 1, -1, -1)
    return momentum * (z * weights).sum(axis=-1) + (1 - momentum)**(len(weights)) * z[...,0]

Check this gives the right answer

z = torch.rand(5)

momentum = 0.1
ema = z[0]
for m_i in z:
    ema = ema * (1 - momentum) + m_i * momentum

ema.item(), fastema(z, momentum).item(), torch.allclose(ema, fastema(z, momentum))
(0.5154919624328613, 0.5154920220375061, True)

It even works in two dimensions

z2 = torch.rand(5)
zz = torch.stack([z, z2])
zz.shape
torch.Size([2, 5])

Giving the same result over two dimensions

fastema(zz, momentum), fastema(z, momentum), fastema(z2, momentum)
(tensor([0.5155, 0.5326]), tensor(0.5155), tensor(0.5326))

We can now compare how good the estimate is for different values of momentum; in this case it’s best around \([10^{-3}, 10^{-2}]\)

momentums_log10 = torch.arange(-5, 0, 0.1)
momentums = 10**momentums_log10
emas = [fastema(train_bn_means.transpose(0, 1), m) for m in momentums]

emas_rms_error = [((ema - bn_mean)**2).mean()**0.5 for ema in emas]

plt.plot(momentums_log10, emas_rms_error);

We get a similar result for the variances

emas = [fastema(train_bn_vars.transpose(0, 1), m) for m in momentums]

emas_rms_error = [((ema - bn_var)**2).mean()**0.5 for ema in emas]

plt.plot(momentums_log10, emas_rms_error);

We can also look at the errors

ema = fastema(train_bn_means.transpose(0, 1), momentum=0.001)
plt.hist(ema - bn_mean, bins=20);
((ema - bn_mean)**2).mean() ** 0.5
tensor(0.0327)

Note that the optimum momentum will depend on the batch size; as batch size increases:

  • the variance within each step will decrease
  • the number of steps in an epoch will decrease
  • optimum momentum will increase
mlp = fix_init(MLPBatchNorm()).requires_grad_()
losses, val_losses = train(mlp, n_step=50_000//10, lr=lambda step, n_step: 0.1, batch_size=batch_size*10)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value);

train_bn_means = torch.concat(mlp.bn_means)
train_bn_vars = torch.concat(mlp.bn_vars)

with torch.no_grad():
    preds = mlp(X)
    preact = mlp.hidden_layer
    bn_mean = preact.mean(dim=0)
    bn_var = preact.var(dim=0)

We can see that the optimum momentum gets higher

momentums_log10 = torch.arange(-5, 0, 0.1)
momentums = 10**momentums_log10
emas = [fastema(train_bn_means.transpose(0, 1), m) for m in momentums]

emas_rms_error = [((ema - bn_mean)**2).mean()**0.5 for ema in emas]

plt.plot(momentums_log10, emas_rms_error);

emas = [fastema(train_bn_vars.transpose(0, 1), m) for m in momentums]

emas_rms_error = [((ema - bn_var)**2).mean()**0.5 for ema in emas]

plt.plot(momentums_log10, emas_rms_error);

Running statistics during training

We can now wrap this in our MLP:

  • at training time collect the running statistics
  • at inference time use the running statistics

Note that this only requires 2 extra variables per hidden dimension (1 for mean and 1 for variance).

class MLPBatchNorm(MLP):
    def __init__(self, m=default_m, h=default_h, V=V, block_size=block_size,
                 bn_eps=1e-8, bn_momentum = 0.001):
        super().__init__(m=m, h=h, V=V, block_size=block_size)
        self.bn_scale = torch.ones((1,h))
        self.bn_shift = torch.zeros((1,h))
        self.bn_eps = bn_eps
        # New stuff
        self.training = True
        self.bn_runvar =  torch.ones((1,h))
        self.bn_runmean = torch.zeros((1,h))
        self.bn_momentum = bn_momentum

        
    def parameters(self):
        return super().parameters() + [self.bn_scale, self.bn_shift]
    
    def forward(self, X):
        self.embeddings = self.C[X]
        self.hidden_layer = self.embeddings.view(X.shape[0], self.block_size * self.m) @ self.H + self.d
        if self.training:
            # Estimate batch mean and variance
            μ = self.hidden_layer.mean(dim=0, keepdim=True)
            σ2 = self.hidden_layer.var(dim=0, keepdim=True)
            
            # Update running totals
            with torch.no_grad():                
                self.bn_runmean = (1 - self.bn_momentum) * self.bn_runmean + \
                                  self.bn_momentum * μ
                self.bn_runvar = (1 - self.bn_momentum) * self.bn_runvar + \
                                 self.bn_momentum * σ2

        else:
            μ = self.bn_runmean
            σ2 = self.bn_runvar
            
        self.hidden_bn = self.bn_scale * (self.hidden_layer - μ) / (σ2 + self.bn_eps) ** 0.5 + self.bn_shift
        self.hidden_activations =  torch.tanh(self.hidden_bn)
        self.output_logits = self.hidden_activations @ self.U + self.b
        return self.output_logits

This gives a similar loss as before

mlp = fix_init(MLPBatchNorm()).requires_grad_()
losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

bn_run_loss = val_loss_value[-1]

bn_run_loss, f'{bn_run_loss / bn_loss:0.2%} of batch norm loss'
(2.427150011062622, '99.85% of batch norm loss')

But now we can evaluate it on single examples

mlp(X_val[:1])
tensor([[-1.8919, -4.8599, -1.5227, -1.5331, -1.6830, -2.4410, -3.0484, -2.9386,
         -3.2406, -3.1063, -2.3455, -3.1913,  2.2928,  1.8896,  2.5483,  2.3309,
          1.4567,  2.0108,  0.8814,  1.8238,  0.9417,  0.6373,  0.9153,  1.5311,
          2.3831,  1.3458,  1.2359,  2.2754, -0.9118,  1.6478,  2.0318,  2.5369,
          0.3390,  0.4008,  1.6127, -1.4685, -0.6556, -0.8556]],
       grad_fn=<AddBackward0>)

And in evaluation mode the results are independent of the batch size:

torch.allclose(mlp(X_val[:1]), mlp(X_val[:100])[:1])
True

Let’s check our running statistics are similar to calculating them after the fact.

with torch.no_grad():
    preds = mlp(X)
    preact = mlp.hidden_layer
    bn_mean = preact.mean(dim=0, keepdim=True)
    bn_var = preact.var(dim=0, keepdim=True)

They’re mostly similar, though some are substantially different.

plt.hist(bn_mean - mlp.bn_runmean, bins=20);

((bn_mean - mlp.bn_runmean)**2).mean() ** 0.5
tensor(0.0261)

plt.hist(mlp.bn_runvar - bn_var, bins=20)

((bn_var - mlp.bn_runvar)**2).mean() ** 0.5
tensor(0.1085)

Pytorchifying

With Batchnorm it’s getting hard to maintain all this spaghetti code, so let’s make it more modular like PyTorch.

We’ll start off with a MyModule class that’s a simplified version of PyTorch’s nn.Module

class MyModule:
    def __init__(self):
        self.training = True
        self._parameters = []
        
    def train(self, mode=True):
        self.training = mode
        return self
    
    def parameters(self):
        return self._parameters
        
    def requires_grad_(self, requires_grad=True):
        for p in self.parameters():
            p.requires_grad_(requires_grad)
        return self
    
    def zero_grad(self):
        for p in self.parameters():
            p.grad = None
        return self
    
    def forward(self, X):
        raise NotImplementedError()
    
    def __call__(self, X):
        return self.forward(X)

    def __repr__(self):
        return f"{type(self).__name__}"

Linear Layer

Then for our MLP we’ll need a Linear layer, and we’ll copy their initialisation:

\[ w \sim U\left(-1/\sqrt{\rm in\_features}, 1/\sqrt{\rm in\_features}\right) \]

Note this is missing the \(\sqrt{3}\) from being uniform, and any activation specific gain.

Torch doesn’t have a handy way of building a uniform distribution, so we will roll our own:

def rand_unif(shape, min_val, max_val):
    return torch.rand(shape) * (max_val - min_val) + min_val

plt.hist(rand_unif((10_000,), -1, 2), bins=20);
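
A quick check of the missing √3 mentioned above: a uniform distribution on (-a, a) has standard deviation a/√3 (sketch with an arbitrary sample size):

rand_unif((100_000,), -1, 1).std(), 1 / 3 ** 0.5   # both roughly 0.577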

class MyLinear(MyModule):
    def __init__(self, in_features: int, out_features: int,
                 bias: bool = True):
        super().__init__()
        
        scale = 1/(in_features)**0.5
        
        self.weight = rand_unif((in_features, out_features), -scale, scale)
        self._parameters = [self.weight]

        if bias:
            self.bias = rand_unif(out_features, -scale, scale)
            self._parameters.append(self.bias)

    def __repr__(self):
        return f"{type(self).__name__}({self.weight.shape[0]}, {self.weight.shape[1]}, bias={hasattr(self, 'bias')})"
        
    def forward(self, X):
        self.out = X @ self.weight
        if hasattr(self, "bias"):
             self.out += self.bias
        return self.out

We can create a linear layer, and it has the appropriate mean and standard deviation

linear = MyLinear(100, 200)
linear
MyLinear(100, 200, bias=True)
linear.weight.mean(), 10 * (3)**0.5 * linear.weight.std()
(tensor(0.0002), tensor(0.9963))

As do the biases

linear.bias.mean(), 10 * (3)**0.5 * linear.bias.std()
(tensor(-0.0005), tensor(1.0303))

And it has the required parameters

[p.shape for p in linear.parameters()]
[torch.Size([100, 200]), torch.Size([200])]

And it converts a batch of 100-dimensional vectors into a batch of 200-dimensional vectors

linear(torch.randn(32, 100)).shape
torch.Size([32, 200])
linear_nobias = MyLinear(200, 100, bias=False)
[p.shape for p in linear_nobias.parameters()]
[torch.Size([200, 100])]

Embedding Layer

We can similarly abstract the embedding layer, à la torch.nn.Embedding, which is much simpler:

class MyEmbedding(MyModule):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.weight = torch.randn(size=(num_embeddings, embedding_dim))
        
        self._parameters = [self.weight]

    def __repr__(self):
        return f"{type(self).__name__}{tuple(self.weight.shape)}"
        
    def forward(self, X):
        self.out = self.weight[X]
        return self.out

MyEmbedding(2,3)
MyEmbedding(2, 3)

Batch Norm

We can similarly implement BatchNorm1d, which is more complex, but hides all the state inside the object which makes for a cleaner abstraction.

class MyBatchNorm1d(MyModule):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        
        self.affine = affine
        self.weight = torch.ones(num_features)
        self.bias = torch.zeros(num_features)
        if affine:
            self._parameters = [self.weight, self.bias]
        else:
            self._parameters = []
            
        self.track_running_stats = track_running_stats
        if track_running_stats:
            self.running_mean = torch.zeros(1, num_features)
            self.running_var = torch.ones(1, num_features)
        else:
            self.running_mean = None
            self.running_var = None

    def __repr__(self):
        return f"{type(self).__name__}({self.num_features}, eps={self.eps}, affine={self.affine})"
        
    def forward(self, X):
        if self.training:
            batch_mean = X.mean(dim=0, keepdim=True)
            batch_var = X.var(dim=0, keepdim=True, correction=0)
            
            if self.track_running_stats:
                with torch.no_grad():
                    self.running_mean *= 1 - self.momentum
                    self.running_mean += self.momentum * batch_mean.view(-1)
                    
                    self.running_var *= 1 - self.momentum
                    # Following documentation in Pytorch BatchNorm1D
                    self.running_var += self.momentum * X.var(dim=0, keepdim=True, correction=1)
        else:
            batch_mean = self.running_mean
            batch_var = self.running_var
            
        self.out = self.weight * (X - batch_mean) / (batch_var + self.eps) ** 0.5 + self.bias
        return self.out

MyBatchNorm1d(5)
MyBatchNorm1d(5, eps=1e-05, affine=True)

MLP

We now have most of the pieces; we just need to add a few more to create our MLP.

Firstly we will need our activations such as nn.Tanh:

class MyTanh(MyModule):
    def forward(self, X):
        self.out = torch.tanh(X)
        return self.out

And a way to flatten our embeddings from each token (up to block_size) into a single tensor:

class MyFlatten(MyModule):
    def __init__(self, start_dim=1, end_dim=-1):
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, X):
        return X.flatten(self.start_dim, self.end_dim)
embedding = MyEmbedding(V, default_m)(X[:32])
embedding.shape
torch.Size([32, 3, 30])
MyFlatten()(embedding).shape
torch.Size([32, 90])

And then we just need to stack them together in a Sequential container of layers:

class MySequential(MyModule):
    def __init__(self, *args):
        super().__init__()
        self.layers = args

        self._parameters = [params for layer in self.layers for params in layer.parameters()]

    def __repr__(self):
        return f"{type(self).__name__}({self.layers})"

    def __getitem__(self, idx):
        return self.layers[idx]

    def forward(self, X):
        result = X
        for layer in self.layers:
            result = layer(result)
        return result

We can then build an MLP for a given embedding dimension, and set of hidden dimensions:

def get_mlp(m=default_m, hs=(default_h,), batch_norm=False, bias=False, V=V, block_size=block_size, activation_factory=lambda: MyTanh()):
    # First we embed the vectors and then flatten them
    layers = [MyEmbedding(V, m), MyFlatten()]
    
    # Then add the hidden layers
    in_sizes = [block_size * m] + list(hs)
    out_sizes = list(hs) + [V]
    for h_in, h_out in zip(in_sizes, out_sizes):
        layers.append(MyLinear(h_in, h_out, bias=bias))
        if batch_norm:
            layers.append(MyBatchNorm1d(num_features=h_out))
        layers.append(activation_factory())
    # Drop the last activation, since this is passed to Softmax
    layers.pop()

    return MySequential(*layers)

mlp = get_mlp(bias=True).requires_grad_()
mlp
MySequential((MyEmbedding(38, 30), MyFlatten, MyLinear(90, 200, bias=True), MyTanh, MyLinear(200, 38, bias=True)))

We can train this to get similar results to before with fixed initialisation

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

mymlp_loss = val_loss_value[-1]

mymlp_loss, f'{mymlp_loss / fix_init_loss:0.2%} of fixed init loss'
(2.4358863830566406, '99.50% of fixed init loss')

Going deeper

Now that we have the framework for building deeper MLPs, let’s train and analyse some

Without correction

Let’s start with a plain 5 layer MLP without gain correction:

mlp = get_mlp(bias=True, m=10, hs=[100]*5).requires_grad_()
mlp
MySequential((MyEmbedding(38, 10), MyFlatten, MyLinear(30, 100, bias=True), MyTanh, MyLinear(100, 100, bias=True), MyTanh, MyLinear(100, 100, bias=True), MyTanh, MyLinear(100, 100, bias=True), MyTanh, MyLinear(100, 100, bias=True), MyTanh, MyLinear(100, 38, bias=True)))
with torch.no_grad():
    print('Initial Loss: %0.2f' % F.cross_entropy(mlp(X[:1000]), y[:1000]).item())
Initial Loss: 3.63

Let’s look at the initial activations

def show_layers(mlp, backward=False, classes=(MyTanh,), saturation_threshold = 0.97, figsize=(20,4), X=X_val[:32], y=y_val[:32]):
    preds = mlp(X)
    for layer in mlp:
        if hasattr(layer, 'out'):
            layer.out.retain_grad()
    loss = F.cross_entropy(input=preds, target=y)
    loss.backward()
    
    plt.figure(figsize=figsize) # width and height of the plot
    legends = []
    with torch.no_grad():
        for i, layer in enumerate(mlp):
            if isinstance(layer, classes):
                t = layer.out
                if backward:
                    t = t.grad
                print(f'layer {i} ({layer}): mean {t.mean():0.2f}, std {t.std():0.2f}, saturated: {((t.abs() > saturation_threshold) * 1.0).mean():0.2%}')
                hy, hx = torch.histogram(t, density=True)
                plt.plot(hx[:-1].detach(), hy.detach())
                legends.append(f'layer {i} ({layer})')

    
    plt.legend(legends);
    plt.title(('gradient' if backward else 'activation') + ' distribution')

    mlp.zero_grad()
show_layers(mlp)
layer 3 (MyTanh): mean 0.02, std 0.46, saturated: 0.00%
layer 5 (MyTanh): mean -0.00, std 0.27, saturated: 0.00%
layer 7 (MyTanh): mean 0.00, std 0.15, saturated: 0.00%
layer 9 (MyTanh): mean 0.01, std 0.10, saturated: 0.00%
layer 11 (MyTanh): mean 0.01, std 0.08, saturated: 0.00%

And gradients

show_layers(mlp, backward=True)
layer 3 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 5 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 7 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 9 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%

And the weights

def show_weights(mlp, figsize=(20, 4), skip_embedding_layer=True, X=X_val[:32], y=y_val[:32]):
    preds = mlp(X)
    for layer in mlp:
        if hasattr(layer, 'out'):
            layer.out.retain_grad()
    loss = F.cross_entropy(input=preds, target=y)
    loss.backward()
    
    plt.figure(figsize=figsize)
    legends = []
    for i, p in enumerate(mlp.parameters()):
        if skip_embedding_layer and i == 0:
            continue
        t = p.grad
        if p.ndim == 2:
            print('weight %10s | mean %+f | std %e | grad:data ratio %e' % (tuple(p.shape), t.mean(), t.std(), t.std() / p.std()))
            hy, hx = torch.histogram(t, density=True)
            plt.plot(hx[:-1].detach(), hy.detach())
            legends.append(f'{i} {tuple(p.shape)}')
    plt.legend(legends)
    plt.title('weights gradient distribution')

    mlp.zero_grad()
show_weights(mlp)
weight  (30, 100) | mean -0.000015 | std 8.145859e-04 | grad:data ratio 7.634363e-03
weight (100, 100) | mean -0.000002 | std 8.332252e-04 | grad:data ratio 1.447422e-02
weight (100, 100) | mean -0.000002 | std 9.669473e-04 | grad:data ratio 1.674323e-02
weight (100, 100) | mean +0.000016 | std 1.014291e-03 | grad:data ratio 1.749655e-02
weight (100, 100) | mean -0.000012 | std 1.328168e-03 | grad:data ratio 2.303185e-02
weight  (100, 38) | mean +0.000000 | std 3.048425e-03 | grad:data ratio 5.262388e-02

Let’s try training it for 1000 steps and see how it changes; we will track the update-to-data ratio during training:

class UDCallback():
    def __init__(self):
        self.ud = []

    def __call__(self, model, lr, step, n_step, **kwargs):
        lr_value = lr(step, n_step)
        self.ud.append([(lr_value * p.grad.std() / p.data.std()).log10().item() for p in model.parameters()])

    def plot(self, figsize=(20, 4)):
        ud = self.ud
        plt.figure(figsize=figsize)
        legends = []
        for i,p in enumerate(mlp.parameters()):
          if p.ndim == 2:
            plt.plot([ud[j][i] for j in range(len(ud))])
            legends.append('param %d' % i)
        plt.plot([0, len(ud)], [-3, -3], 'k') # these ratios should be ~1e-3, indicate on plot
        plt.legend(legends)
ud = UDCallback()

losses, val_losses = train(mlp, n_step=1_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

The weights initially train at very different rates, but all except the output and embedding layers converge (to a slightly too-high rate)

ud.plot()

The activations and gradients have largely sorted themselves out

show_layers(mlp)
layer 3 (MyTanh): mean 0.02, std 0.47, saturated: 0.00%
layer 5 (MyTanh): mean 0.00, std 0.33, saturated: 0.00%
layer 7 (MyTanh): mean -0.01, std 0.30, saturated: 0.00%
layer 9 (MyTanh): mean 0.02, std 0.33, saturated: 0.00%
layer 11 (MyTanh): mean 0.07, std 0.42, saturated: 0.00%

show_layers(mlp, backward=True)
layer 3 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 5 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 7 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 9 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%

And the weight gradients are more uniform

show_weights(mlp)
weight  (30, 100) | mean +0.000075 | std 7.344535e-03 | grad:data ratio 6.576340e-02
weight (100, 100) | mean -0.000005 | std 3.532875e-03 | grad:data ratio 5.843860e-02
weight (100, 100) | mean +0.000001 | std 2.672232e-03 | grad:data ratio 4.419371e-02
weight (100, 100) | mean -0.000003 | std 2.625167e-03 | grad:data ratio 4.310565e-02
weight (100, 100) | mean +0.000025 | std 3.583072e-03 | grad:data ratio 5.814888e-02
weight  (100, 38) | mean -0.000000 | std 1.289676e-02 | grad:data ratio 1.660140e-01

Even so it successfully trains

mlp = get_mlp(bias=True, m=10, hs=[100]*5).requires_grad_()

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

l5_loss = val_loss_value[-1]

l5_loss, f'{l5_loss / fix_init_loss:0.2%} of fixed init loss'
(2.4107613563537598, '98.48% of fixed init loss')

ud.plot()

show_weights(mlp)
weight  (30, 100) | mean -0.000254 | std 1.778142e-02 | grad:data ratio 5.134834e-02
weight (100, 100) | mean +0.000179 | std 1.043383e-02 | grad:data ratio 5.559371e-02
weight (100, 100) | mean -0.000025 | std 1.019966e-02 | grad:data ratio 5.705851e-02
weight (100, 100) | mean +0.000066 | std 1.106223e-02 | grad:data ratio 6.375358e-02
weight (100, 100) | mean +0.000054 | std 1.232356e-02 | grad:data ratio 8.169784e-02
weight  (100, 38) | mean -0.000000 | std 2.402255e-02 | grad:data ratio 1.571375e-01

However when we get to 20 layers it fails to train

mlp = get_mlp(bias=True, m=10, hs=[100]*20).requires_grad_()

ud = UDCallback()

losses, val_losses = train(mlp, n_step=5_000, lr=lambda step, n_step: 0.1, callback=ud, val_step=500)
val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

l20_loss = val_loss_value[-1]

l20_loss, f'{l20_loss / fix_init_loss:0.2%} of fixed init loss'
(3.050510883331299, '124.61% of fixed init loss')

The gradients are spread all over the place

ud.plot()

Fixing Gain

We can fix the gain within layers and reduce the gain in the final layer to get a better initialisation:

def add_gain(mlp, gain=5/3, output_gain=0.1, update_layers=(MyLinear,)):
    with torch.no_grad():
        for layer in mlp[:-1]:
            if isinstance(layer, update_layers):
                layer.weight *= gain
        mlp[-1].weight *= output_gain

mlp = get_mlp(bias=True, m=10, hs=[100]*5).requires_grad_()
add_gain(mlp, output_gain=0.1)
with torch.no_grad():
    print('Initial Loss: %0.2f' % F.cross_entropy(mlp(X[:1000]), y[:1000]).item())
Initial Loss: 3.65

The layers are more uniform now

show_layers(mlp)
layer 3 (MyTanh): mean -0.01, std 0.62, saturated: 2.72%
layer 5 (MyTanh): mean -0.02, std 0.47, saturated: 0.03%
layer 7 (MyTanh): mean 0.01, std 0.39, saturated: 0.00%
layer 9 (MyTanh): mean 0.01, std 0.34, saturated: 0.00%
layer 11 (MyTanh): mean -0.01, std 0.30, saturated: 0.00%

show_layers(mlp, backward=True)
layer 3 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 5 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 7 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 9 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%

The weights are more uniform except the output layer (because of the multiplication)

show_weights(mlp)
weight  (30, 100) | mean +0.000014 | std 3.260477e-04 | grad:data ratio 1.849884e-03
weight (100, 100) | mean +0.000001 | std 3.304202e-04 | grad:data ratio 3.432189e-03
weight (100, 100) | mean -0.000002 | std 3.086994e-04 | grad:data ratio 3.204230e-03
weight (100, 100) | mean +0.000001 | std 3.112402e-04 | grad:data ratio 3.216269e-03
weight (100, 100) | mean -0.000001 | std 3.279222e-04 | grad:data ratio 3.374243e-03
weight  (100, 38) | mean -0.000000 | std 9.346377e-03 | grad:data ratio 1.604743e+00

Let’s train it for a little while:

ud = UDCallback()

losses, val_losses = train(mlp, n_step=1_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

The weight updates are similar to before but the weights move more in lockstep during the initial optimisation period

ud.plot()

The output layer gradients come down across these iterations

show_weights(mlp, skip_embedding_layer=True)
weight  (30, 100) | mean -0.000254 | std 1.086286e-02 | grad:data ratio 6.070493e-02
weight (100, 100) | mean +0.000004 | std 7.106286e-03 | grad:data ratio 7.206615e-02
weight (100, 100) | mean +0.000033 | std 5.773973e-03 | grad:data ratio 5.865905e-02
weight (100, 100) | mean -0.000028 | std 5.451859e-03 | grad:data ratio 5.512737e-02
weight (100, 100) | mean +0.000017 | std 5.011663e-03 | grad:data ratio 5.035378e-02
weight  (100, 38) | mean -0.000000 | std 2.294768e-02 | grad:data ratio 3.940181e-01

And the activations and gradients still look good

show_layers(mlp)
layer 3 (MyTanh): mean -0.00, std 0.62, saturated: 2.94%
layer 5 (MyTanh): mean -0.04, std 0.51, saturated: 0.25%
layer 7 (MyTanh): mean -0.00, std 0.49, saturated: 0.34%
layer 9 (MyTanh): mean 0.01, std 0.51, saturated: 0.19%
layer 11 (MyTanh): mean 0.02, std 0.55, saturated: 0.66%

show_layers(mlp, backward=True)
layer 3 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 5 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 7 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 9 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 11 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%

For 5 layers we get a similar result as before

mlp = get_mlp(bias=True, m=10, hs=[100]*5).requires_grad_()
add_gain(mlp, output_gain=0.1)

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

l5_fix_loss = val_loss_value[-1]

l5_fix_loss, f'{l5_fix_loss / fix_init_loss:0.2%} of fixed init loss'
(2.4011518955230713, '98.08% of fixed init loss')

However at 20 layers we get the loss decreasing to a similar level as our 2 layer MLP

mlp = get_mlp(bias=True, m=10, hs=[100]*20).requires_grad_()
add_gain(mlp)

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud, val_step=500)
val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

l20_fix_loss = val_loss_value[-1]

l20_fix_loss, f'{l20_fix_loss / fix_init_loss:0.2%} of fixed init loss'
(2.4759957790374756, '101.14% of fixed init loss')

There is still a bit more variance with the updates across layers in this deeper net, but with the better initialisation it is stable

ud.plot()

Batchnorm

Rather than being careful with initialisation we can use batch norm

mlp = get_mlp(bias=False, m=10, hs=[100]*5, batch_norm=True).requires_grad_()
#add_gain(mlp)

mlp
MySequential((MyEmbedding(38, 10), MyFlatten, MyLinear(30, 100, bias=False), MyBatchNorm1d(100, eps=1e-05, affine=True), MyTanh, MyLinear(100, 100, bias=False), MyBatchNorm1d(100, eps=1e-05, affine=True), MyTanh, MyLinear(100, 100, bias=False), MyBatchNorm1d(100, eps=1e-05, affine=True), MyTanh, MyLinear(100, 100, bias=False), MyBatchNorm1d(100, eps=1e-05, affine=True), MyTanh, MyLinear(100, 100, bias=False), MyBatchNorm1d(100, eps=1e-05, affine=True), MyTanh, MyLinear(100, 38, bias=False), MyBatchNorm1d(38, eps=1e-05, affine=True)))
show_layers(mlp)
layer 4 (MyTanh): mean -0.01, std 0.80, saturated: 32.00%
layer 7 (MyTanh): mean -0.00, std 0.83, saturated: 38.72%
layer 10 (MyTanh): mean -0.00, std 0.82, saturated: 34.44%
layer 13 (MyTanh): mean -0.01, std 0.82, saturated: 34.31%
layer 16 (MyTanh): mean -0.01, std 0.82, saturated: 36.94%

show_layers(mlp, backward=True)
layer 4 (MyTanh): mean 0.00, std 0.04, saturated: 0.00%
layer 7 (MyTanh): mean 0.00, std 0.03, saturated: 0.00%
layer 10 (MyTanh): mean 0.00, std 0.02, saturated: 0.00%
layer 13 (MyTanh): mean -0.00, std 0.01, saturated: 0.00%
layer 16 (MyTanh): mean 0.00, std 0.01, saturated: 0.00%

show_weights(mlp)
weight  (30, 100) | mean -0.006068 | std 3.976782e-01 | grad:data ratio 3.796145e+00
weight (100, 100) | mean +0.002900 | std 3.224502e-01 | grad:data ratio 5.585790e+00
weight (100, 100) | mean +0.002294 | std 2.273206e-01 | grad:data ratio 3.929580e+00
weight (100, 100) | mean +0.000491 | std 1.619342e-01 | grad:data ratio 2.798622e+00
weight (100, 100) | mean -0.000863 | std 1.270476e-01 | grad:data ratio 2.212344e+00
weight  (100, 38) | mean +0.002701 | std 1.532580e-01 | grad:data ratio 2.657254e+00

ud = UDCallback()

losses, val_losses = train(mlp, n_step=1_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

The training looks more stable

ud.plot()

The saturation quickly reduces

show_layers(mlp)
layer 4 (MyTanh): mean -0.00, std 0.76, saturated: 20.53%
layer 7 (MyTanh): mean -0.01, std 0.68, saturated: 8.31%
layer 10 (MyTanh): mean 0.01, std 0.74, saturated: 16.50%
layer 13 (MyTanh): mean 0.00, std 0.73, saturated: 12.87%
layer 16 (MyTanh): mean -0.01, std 0.74, saturated: 12.53%

And we get a similar result when we train for a long time

mlp = get_mlp(batch_norm=True, bias=False, m=10, hs=[100]*5).requires_grad_()
add_gain(mlp, output_gain=0.1)

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

l5_bn_loss = val_loss_value[-1]

l5_bn_loss, f'{l5_bn_loss / fix_init_loss:0.2%} of fixed init loss'
(2.4429738521575928, '99.79% of fixed init loss')

ud.plot()

And it successfully trains a 20 layer network too

mlp = get_mlp(batch_norm=True, bias=False, m=10, hs=[100]*20).requires_grad_()
add_gain(mlp, output_gain=0.1)

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

l20_bn_loss = val_loss_value[-1]

l20_bn_loss, f'{l20_bn_loss / fix_init_loss:0.2%} of fixed init loss'
(2.5574026107788086, '104.47% of fixed init loss')

Using PyTorch nn Library

We can easily reuse PyTorch’s implementations, which are likely more efficient

from torch.nn import Module, Linear, Embedding, Tanh, BatchNorm1d, Sequential, Flatten

def get_mlp(m=default_m, hs=(default_h,), batch_norm=False, bias=False, V=V, block_size=block_size, activation_factory=lambda: Tanh()):
    # First we embed the vectors and then flatten them
    layers = [Embedding(V, m), Flatten()]
    
    # Then add the hidden layers
    in_sizes = [block_size * m] + list(hs)
    out_sizes = list(hs) + [V]
    for h_in, h_out in zip(in_sizes, out_sizes):
        layers.append(Linear(h_in, h_out, bias=bias))
        if batch_norm:
            layers.append(BatchNorm1d(num_features=h_out))
        layers.append(activation_factory())
    # Drop the last activation, since this is passed to Softmax
    layers.pop()

    return Sequential(*layers)

mlp = get_mlp(bias=True, m=10, hs=[100]*5).requires_grad_()
mlp
Sequential(
  (0): Embedding(38, 10)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=30, out_features=100, bias=True)
  (3): Tanh()
  (4): Linear(in_features=100, out_features=100, bias=True)
  (5): Tanh()
  (6): Linear(in_features=100, out_features=100, bias=True)
  (7): Tanh()
  (8): Linear(in_features=100, out_features=100, bias=True)
  (9): Tanh()
  (10): Linear(in_features=100, out_features=100, bias=True)
  (11): Tanh()
  (12): Linear(in_features=100, out_features=38, bias=True)
)

However we can’t show the layers the same way as before, because PyTorch doesn’t store results in .out

try:
    show_layers(mlp, classes=(Tanh,))
except AttributeError as e:
    print(e)
'Tanh' object has no attribute 'out'
<Figure size 2000x400 with 0 Axes>

However PyTorch provides a mechanism to modify the behaviour of a module without editing the source code: hooks.

We can register a forward hook to capture the output:

def log_output(module, args, output):
    module.out = output

for layer in mlp:
    if isinstance(layer, Tanh):
        layer.register_forward_hook(log_output)

And then use our functions as before:

show_layers(mlp, classes=(Tanh,))
layer 3 (Tanh()): mean -0.03, std 0.49, saturated: 0.37%
layer 5 (Tanh()): mean -0.01, std 0.26, saturated: 0.00%
layer 7 (Tanh()): mean -0.01, std 0.16, saturated: 0.00%
layer 9 (Tanh()): mean -0.01, std 0.11, saturated: 0.00%
layer 11 (Tanh()): mean -0.02, std 0.08, saturated: 0.00%

show_layers(mlp, backward=True, classes=(Tanh,))
layer 3 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 5 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 7 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 9 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%

show_weights(mlp)
weight  (100, 30) | mean -0.000007 | std 1.011561e-03 | grad:data ratio 9.519964e-03
weight (100, 100) | mean -0.000002 | std 1.124272e-03 | grad:data ratio 1.947545e-02
weight (100, 100) | mean -0.000002 | std 1.041845e-03 | grad:data ratio 1.801275e-02
weight (100, 100) | mean -0.000002 | std 1.207655e-03 | grad:data ratio 2.074841e-02
weight (100, 100) | mean -0.000012 | std 1.380317e-03 | grad:data ratio 2.384296e-02
weight  (38, 100) | mean -0.000000 | std 3.135142e-03 | grad:data ratio 5.340258e-02

We can fix the weights as before, updating Linear (rather than MyLinear) layers:

add_gain(mlp, output_gain=0.1, update_layers=(Linear,))
show_layers(mlp, classes=(Tanh,))
layer 3 (Tanh()): mean -0.03, std 0.64, saturated: 5.59%
layer 5 (Tanh()): mean -0.00, std 0.47, saturated: 0.12%
layer 7 (Tanh()): mean -0.01, std 0.39, saturated: 0.00%
layer 9 (Tanh()): mean -0.01, std 0.33, saturated: 0.00%
layer 11 (Tanh()): mean -0.01, std 0.30, saturated: 0.00%

show_layers(mlp, backward=True, classes=(Tanh,))
layer 3 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 5 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 7 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 9 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%

show_weights(mlp)
weight  (100, 30) | mean -0.000002 | std 4.227582e-04 | grad:data ratio 2.387187e-03
weight (100, 100) | mean -0.000001 | std 4.307418e-04 | grad:data ratio 4.476971e-03
weight (100, 100) | mean +0.000000 | std 3.819067e-04 | grad:data ratio 3.961735e-03
weight (100, 100) | mean +0.000001 | std 3.799068e-04 | grad:data ratio 3.916249e-03
weight (100, 100) | mean -0.000004 | std 3.589212e-04 | grad:data ratio 3.719902e-03
weight  (38, 100) | mean +0.000000 | std 9.675195e-03 | grad:data ratio 1.648029e+00

And training is exactly the same:

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

l5_torch_loss = val_loss_value[-1]

l5_torch_loss, f'{l5_torch_loss / l5_fix_loss:0.2%} of l5 fixed loss'
(2.4006731510162354, '99.98% of l5 fixed loss')

We can plot the update dynamics and it looks similar to before

ud.plot()

Batchnorm

Similarly we can train with batchnorm

mlp = get_mlp(batch_norm=True, bias=False, m=10, hs=[100]*5).requires_grad_()

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

l5_torch_bn_loss = val_loss_value[-1]

l5_torch_bn_loss, f'{l5_torch_bn_loss / l5_bn_loss:0.2%} of l5 batchnorm loss'
(2.3883166313171387, '97.76% of l5 batchnorm loss')

Training a 5 layer network

Putting it all together, let’s see how low we can make the loss

mlp = get_mlp(bias=True, m=default_m, hs=[default_h]*5).requires_grad_()
add_gain(mlp, update_layers=(Linear,))

Let’s check the initialisation:

hooks = []

for layer in mlp:
    if isinstance(layer, Tanh):
        hooks.append(layer.register_forward_hook(log_output))
show_layers(mlp, classes=(Tanh,))
layer 3 (Tanh()): mean 0.00, std 0.61, saturated: 2.58%
layer 5 (Tanh()): mean -0.00, std 0.46, saturated: 0.02%
layer 7 (Tanh()): mean -0.01, std 0.37, saturated: 0.00%
layer 9 (Tanh()): mean 0.01, std 0.32, saturated: 0.00%
layer 11 (Tanh()): mean -0.00, std 0.28, saturated: 0.00%

show_layers(mlp, classes=(Tanh,), backward=True)
layer 3 (Tanh()): mean -0.00, std 0.00, saturated: 0.00%
layer 5 (Tanh()): mean -0.00, std 0.00, saturated: 0.00%
layer 7 (Tanh()): mean -0.00, std 0.00, saturated: 0.00%
layer 9 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (Tanh()): mean -0.00, std 0.00, saturated: 0.00%

Remove the hooks so we don’t need to store any unnecessary outputs

for hook in hooks:
    hook.remove()

We’ll train it for a lot longer and get a slightly lower loss

ud = UDCallback()

losses, val_losses = train(mlp, n_step=200_000, lr=lambda step, n_step: 0.1, callback=ud)
val_loss_step, val_loss_value = zip(*val_losses)

plt.plot(val_loss_step, val_loss_value)

aloss = val_loss_value[-1]

aloss, f'{aloss / fix_init_loss:0.2%} of fixed init loss'
(2.3479623794555664, '95.91% of fixed init loss')

At the end the training has converged

val_loss_ema = [val_loss_value[0]]
momentum = 0.01

for v in val_loss_value[1:]:
    val_loss_ema.append(val_loss_ema[-1] * (1 - momentum) + v * momentum)

plt.plot(val_loss_step[-200:], val_loss_value[-200:])
plt.plot(val_loss_step[-200:], val_loss_ema[-200:])

The samples are looking slightly more human-readable, and some are good, but in general they don’t have long-range coherence, suggesting that context length is the bottleneck.

for _ in range(20):
    print(sample(mlp))
star
worldonteslangendownpolittersing
wegamepiand_ete
bey_babybugramberdenvanceanimedynews
wresiranugue
disabamazmongonpraxgifsonavioloningwael
windian_botlauker
digifs
kawallorestages
mildlanenoumusicess
slustsurveeupjobbyhbosfronthavirthoppromartyle
eding
blianquely
lone
djerkfries
blacebakinds
knerdups
mints
thedomniada
tecons

What next?

We have looked closer at how the weights change during training, and how initialisation and batch norm can keep the weights in a better range during training. We’ve also changed everything into pure PyTorch and added some instrumentation for checking how weights and gradients change over time.

What we haven’t done is substantially reduce the loss; for this we’re likely going to need more than 3 characters of context.