Batch Normalization#
La Batch Normalization (ou normalisation par lot) a Ă©tĂ© introduite en 2015 dans lâarticle Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. Elle a eu un impact majeur dans le domaine du Deep Learning. Aujourdâhui, la normalisation est utilisĂ©e presque systĂ©matiquement, quâil sâagisse de BatchNorm, LayerNorm ou GroupNorm (et dâautres).
LâidĂ©e de la BatchNorm est simple et liĂ©e au notebook prĂ©cĂ©dent. On cherche Ă obtenir des preactivations suivant une distribution gaussienne Ă chaque couche du rĂ©seau. On a vu quâune bonne initialisation permet dâavoir ce comportement, mais elle nâest pas toujours Ă©vidente, surtout avec de nombreuses couches diffĂ©rentes.
La BatchNorm normalise les preactivations par rapport Ă la dimension du batch avant de les passer dans les fonctions dâactivation. Cela garantit une distribution gaussienne Ă chaque Ă©tape.
Cette normalisation nâaffecte pas lâoptimisation car il sâagit dâune fonction dĂ©rivable.
Implémentation#
Reprise du code#
On va reprendre le code du notebook précédent pour implémenter la batch normalization.
import torch
import torch.nn.functional as F
%matplotlib inline
words = open('../05_NLP/prenoms.txt', 'r').read().splitlines()
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
block_size = 3 # Contexte
def build_dataset(words):
X, Y = [], []
for w in words:
context = [0] * block_size
for ch in w + '.':
ix = stoi[ch]
X.append(context)
Y.append(ix)
context = context[1:] + [ix]
X = torch.tensor(X)
Y = torch.tensor(Y)
print(X.shape, Y.shape)
return X, Y
import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))
Xtr, Ytr = build_dataset(words[:n1]) # 80%
Xdev, Ydev = build_dataset(words[n1:n2]) # 10%
Xte, Yte = build_dataset(words[n2:]) # 10%
torch.Size([180834, 3]) torch.Size([180834])
torch.Size([22852, 3]) torch.Size([22852])
torch.Size([22639, 3]) torch.Size([22639])
embed_dim=10 # Dimension de l'embedding de C
hidden_dim=200 # Dimension de la couche cachée
C = torch.randn((46, embed_dim))
W1 = torch.randn((block_size*embed_dim, hidden_dim))*0.01 # On initialise les poids Ă une petite valeur
b1 = torch.randn(hidden_dim) *0 # On initialise les biais Ă 0
W2 = torch.randn((hidden_dim, 46))*0.01
b2 = torch.randn(46)*0
parameters = [C, W1, b1, W2, b2]
for p in parameters:
p.requires_grad = True
Voici notre code de propagation avant :
batch_size = 32
ix = torch.randint(0, Xtr.shape[0], (batch_size,))
# Forward
Xb, Yb = Xtr[ix], Ytr[ix]
emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)
hpreact = embcat @ W1 + b1
h = torch.tanh(hpreact)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Yb)
Implémentation de la BatchNorm#
DâaprĂšs lâarticle, voici les informations :

Dans un premier temps, il sâagit de normaliser.
Pour cela, on calcule la moyenne et lâĂ©cart type de hpreact puis on normalise avec ces valeurs :
epsilon=1e-6
hpreact_mean = hpreact.mean(dim=0, keepdim=True)
hpreact_std= hpreact.std(dim=0, keepdim=True)
hpreact_norm = (hpreact - hpreact_mean) / (hpreact_std+epsilon)
On peut maintenant intégrer cette normalisation à la propagation avant.
Avant cela, notons quâon nâa pas encore implĂ©mentĂ© la partie scale and shift :

Ă quoi ça sert ? : La normalisation confine les poids Ă prendre des valeurs dâune gaussienne centrĂ©e rĂ©duite. Cela rĂ©duit les capacitĂ©s dâexpression du modĂšle. Les paramĂštres apprenables \(\gamma\) et \(\beta\) permettent de contourner ce problĂšme en ajoutant un shift avec \(\beta\) et un scale avec \(\gamma\).
Comme il sâagit de paramĂštres apprenables, on doit aussi les ajouter aux paramĂštres du modĂšle :
C = torch.randn((46, embed_dim))
W1 = torch.randn((block_size*embed_dim, hidden_dim))*0.01 # On initialise les poids Ă une petite valeur
b1 = torch.randn(hidden_dim) *0 # On initialise les biais Ă 0
W2 = torch.randn((hidden_dim, 46))*0.01
b2 = torch.randn(46)*0
# ParamĂštres de batch normalization
bngain = torch.ones((1, hidden_dim))
bnbias = torch.zeros((1, hidden_dim))
parameters = [C, W1, b1, W2, b2, bngain, bnbias]
for p in parameters:
p.requires_grad = True
Et en propagation avant, on aura donc :
batch_size = 32
ix = torch.randint(0, Xtr.shape[0], (batch_size,))
# Forward
Xb, Yb = Xtr[ix], Ytr[ix]
emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)
hpreact = embcat @ W1 + b1
# Batch normalization
bnmean = hpreact.mean(0, keepdim=True)
bnstd = hpreact.std(0, keepdim=True)
hpreact = bngain * (hpreact - bnmean) / bnstd + bnbias
h = torch.tanh(hpreact)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Yb)
Le problĂšme de la Batch Normalization#
En y réfléchissant un peu, on peut identifier des problÚmes potentiels liés à la BatchNorm :
Un exemple est impactĂ© par les autres Ă©lĂ©ments du batch : Normaliser selon la dimension du batch signifie que les valeurs de chaque exemple au sein du batch sont influencĂ©es par les autres exemples. Cela pourrait sembler problĂ©matique, mais en pratique, câest plutĂŽt une bonne chose. Lâutilisation de batchs alĂ©atoires Ă chaque Ă©poque permet une rĂ©gularisation, ce qui rĂ©duit le risque de overfit sur les donnĂ©es. NĂ©anmoins, si on veut Ă©viter ce problĂšme, on peut utiliser dâautres mĂ©thodes de normalisation qui ne normalisent pas selon la dimension du batch. En pratique, la BatchNorm est encore largement utilisĂ©e car elle fonctionne trĂšs bien empiriquement.
Phase de test sur un seul Ă©lĂ©ment : Pendant lâentraĂźnement, chaque Ă©lĂ©ment est influencĂ© par les autres Ă©lĂ©ments de son batch. Cependant, en phase dâinfĂ©rence, lorsquâon utilise le modĂšle sur un seul Ă©lĂ©ment, on ne peut plus appliquer la BatchNorm. Câest un problĂšme car on veut Ă©viter un comportement diffĂ©rent pendant lâentraĂźnement et lâinfĂ©rence.
Pour résoudre ce problÚme, on a deux solutions :
On peut calculer la moyenne et la variance sur lâensemble des Ă©lĂ©ments Ă la fin de lâentraĂźnement et utiliser ces valeurs. En pratique, on ne veut pas faire une itĂ©ration supplĂ©mentaire sur lâensemble du dataset juste pour ça, donc personne ne fait comme ça.
Une autre solution consiste Ă mettre Ă jour la moyenne et la variance tout au long de lâentraĂźnement grĂące Ă un EMA (exponential moving average). Ă la fin de lâentraĂźnement, on aura une bonne approximation de la moyenne et de la variance de lâensemble des Ă©lĂ©ments dâentraĂźnement.
En pratique, on peut lâimplĂ©menter comme ça en Python :
C = torch.randn((46, embed_dim))
W1 = torch.randn((block_size*embed_dim, hidden_dim))*0.01 # On initialise les poids Ă une petite valeur
b1 = torch.randn(hidden_dim) *0 # On initialise les biais Ă 0
W2 = torch.randn((hidden_dim, 46))*0.01
b2 = torch.randn(46)*0
# ParamĂštres de batch normalization
bngain = torch.ones((1, hidden_dim))
bnbias = torch.zeros((1, hidden_dim))
bnmean_running = torch.zeros((1, hidden_dim))
bnstd_running = torch.ones((1, hidden_dim))
parameters = [C, W1, b1, W2, b2, bngain, bnbias]
for p in parameters:
p.requires_grad = True
batch_size = 32
ix = torch.randint(0, Xtr.shape[0], (batch_size,))
# Forward
Xb, Yb = Xtr[ix], Ytr[ix]
emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)
hpreact = embcat @ W1 + b1
# Batch normalization
bnmeani = hpreact.mean(0, keepdim=True)
bnstdi = hpreact.std(0, keepdim=True)
hpreact = bngain * (hpreact - bnmeani) / bnstdi + bnbias
with torch.no_grad(): # On ne veut pas calculer de gradient pour ces opérations
bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi
h = torch.tanh(hpreact)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Yb)
Note : Dans notre implĂ©mentation, on a choisi 0.001 pour notre EMA. Dans la couche BatchNorm de PyTorch, ce paramĂštre est dĂ©fini par momentum et sa valeur par dĂ©faut est 0.1. En pratique, le choix de cette valeur dĂ©pend de la taille du batch par rapport Ă la taille du jeu de donnĂ©es dâentraĂźnement. Pour un gros batch avec un petit jeu de donnĂ©es, on peut prendre 0.1, par exemple. Pour un petit batch avec un gros jeu de donnĂ©es, on prend plutĂŽt une plus petite valeur.
Testons maintenant lâentraĂźnement de notre modĂšle pour vĂ©rifier que la couche fonctionne. Pour ce petit modĂšle, on nâaura pas de diffĂ©rence de performance.
lossi = []
max_steps = 200000
for i in range(max_steps):
ix = torch.randint(0, Xtr.shape[0], (batch_size,))
Xb, Yb = Xtr[ix], Ytr[ix]
emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)
hpreact = embcat @ W1 + b1
# Batch normalization
bnmeani = hpreact.mean(0, keepdim=True)
bnstdi = hpreact.std(0, keepdim=True)
hpreact = bngain * (hpreact - bnmeani) / bnstdi + bnbias
with torch.no_grad(): # On ne veut pas calculer de gradient pour ces opérations
bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi
h = torch.tanh(hpreact)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Yb)
for p in parameters:
p.grad = None
loss.backward()
lr = 0.1 if i < 100000 else 0.01
for p in parameters:
p.data += -lr * p.grad
if i % 10000 == 0:
print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
lossi.append(loss.log10().item())
0/ 200000: 3.8241
10000/ 200000: 1.9756
20000/ 200000: 2.7151
30000/ 200000: 2.3287
40000/ 200000: 2.1411
50000/ 200000: 2.3207
60000/ 200000: 2.3250
70000/ 200000: 2.0320
80000/ 200000: 2.0615
90000/ 200000: 2.2468
100000/ 200000: 2.2081
110000/ 200000: 2.1418
120000/ 200000: 1.9665
130000/ 200000: 1.8572
140000/ 200000: 2.0577
150000/ 200000: 2.1804
160000/ 200000: 1.8604
170000/ 200000: 1.9810
180000/ 200000: 1.8228
190000/ 200000: 1.9977
Considérations supplémentaires#
Biais : La batch norm normalise les preactivations des poids. Cette normalisation annule le biais (car celui-ci dĂ©cale la distribution, alors que nous la recentrons). Lorsquâon utilise la BatchNorm, on peut se passer du biais. En pratique, si on laisse un biais, ça ne pose pas de problĂšme, mais câest un paramĂštre du rĂ©seau qui sera inutile.
Placement de la BatchNorm : DâaprĂšs ce quâon a vu, il est logique de placer la BatchNorm avant la fonction dâactivation. En pratique, certains prĂ©fĂšrent la placer aprĂšs la couche dâactivation, donc ne soyez pas Ă©tonnĂ© si vous tombez sur ça dans la littĂ©rature ou dans un code.
Autres normalisation#
Nous allons faire un tour rapide des autres normalisations utilisĂ©es pour lâentraĂźnement des rĂ©seaux de neurones.

Figure extraite de lâarticle
Layer Normalization : Cette couche de normalisation est Ă©galement trĂšs frĂ©quemment utilisĂ©e, notamment dans les modĂšles de langage (GPT, Llama). Il sâagit de normaliser sur lâensemble des activations de la couche plutĂŽt que sur lâaxe du batch. Dans notre implĂ©mentation, cela reviendrait simplement Ă changer :
# Batch normalization
bnmeani = hpreact.mean(0, keepdim=True)
bnstdi = hpreact.std(0, keepdim=True)
# Layer normalization
bnmeani = hpreact.mean(1, keepdim=True)
bnstdi = hpreact.std(1, keepdim=True)
Instance Normalization : Cette couche normalise les activations sur chaque canal de chaque élément indépendamment.
Group Normalization : Cette couche est une sorte de fusion entre la LayerNorm et lâInstanceNorm, puisquâon calcule la normalisation sur des groupes de canaux (si la taille dâun groupe vaut 1, câest lâInstanceNorm et si la taille dâun groupe vaut C, câest la LayerNorm)