Activations et initialisations#

Dans ce cours, nous allons reprendre le modĂšle Fully Connected prĂ©sentĂ© dans le cours 5 sur les NLP. Nous allons Ă©tudier le comportement des activations tout au long du rĂ©seau lors de son initialisation. Ce cours s’inspire du cours d’Andrej Karpathy, intitulĂ© Building makemore Part 3: Activations & Gradients, BatchNorm.

Les réseaux de neurones présentent plusieurs avantages :

  • Ils sont trĂšs flexibles et peuvent rĂ©soudre de nombreux problĂšmes.

  • Ils sont assez simples Ă  implĂ©menter.

Cependant, il est souvent complexe de les optimiser, surtout lorsqu’il s’agit de rĂ©seaux profonds.

Reprise du code#

Nous reprenons le code du notebook 3 du cours 5 sur les NLP.

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
words = open('../05_NLP/prenoms.txt', 'r').read().splitlines()
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

Pour des raisons pĂ©dagogiques, nous n’utiliserons pas le dataset et le dataloader de PyTorch. Nous Ă©valuerons le loss au dĂ©but de l’entraĂźnement aprĂšs le premier batch. Globalement, cela fonctionne de la mĂȘme maniĂšre, sauf que nous prenons un batch alĂ©atoire Ă  chaque itĂ©ration au lieu de parcourir l’ensemble du dataset Ă  chaque epoch.

block_size = 3 # Contexte

def build_dataset(words):  
  X, Y = [], []
  
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] 

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%
torch.Size([180834, 3]) torch.Size([180834])
torch.Size([22852, 3]) torch.Size([22852])
torch.Size([22639, 3]) torch.Size([22639])
embed_dim=10 # Dimension de l'embedding de C
hidden_dim=200 # Dimension de la couche cachée

C = torch.randn((46, embed_dim))
W1 = torch.randn((block_size*embed_dim, hidden_dim))
b1 = torch.randn(hidden_dim)
W2 = torch.randn((hidden_dim, 46))
b2 = torch.randn(46)
parameters = [C, W1, b1, W2, b2]
for p in parameters:
  p.requires_grad = True
max_steps = 200000
batch_size = 32
lossi = []

for i in range(max_steps):
  # Permet de construire un mini-batch
  ix = torch.randint(0, Xtr.shape[0], (batch_size,))
  
  # Forward
  Xb, Yb = Xtr[ix], Ytr[ix] 
  emb = C[Xb] 
  embcat = emb.view(emb.shape[0], -1)
  hpreact = embcat @ W1 + b1 

  h = torch.tanh(hpreact) 
  logits = h @ W2 + b2 
  loss = F.cross_entropy(logits, Yb)
  
  # Retropropagation
  for p in parameters:
    p.grad = None
  
  loss.backward()
  # Mise Ă  jour des paramĂštres
  lr = 0.1 if i < 100000 else 0.01 # On descend le learning rate d'un facteur 10 aprÚs 100000 itérations
  for p in parameters:
    p.data += -lr * p.grad

  if i % 10000 == 0:
    print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
  lossi.append(loss.log10().item())
      0/ 200000: 21.9772
  10000/ 200000: 2.9991
  20000/ 200000: 2.5258
  30000/ 200000: 1.9657
  40000/ 200000: 2.4326
  50000/ 200000: 1.7670
  60000/ 200000: 2.1324
  70000/ 200000: 2.4160
  80000/ 200000: 2.2237
  90000/ 200000: 2.3905
 100000/ 200000: 1.9304
 110000/ 200000: 2.1710
 120000/ 200000: 2.3444
 130000/ 200000: 2.0970
 140000/ 200000: 1.8623
 150000/ 200000: 1.9792
 160000/ 200000: 2.4602
 170000/ 200000: 2.0968
 180000/ 200000: 2.0466
 190000/ 200000: 2.3746
plt.plot(lossi)
[<matplotlib.lines.Line2D at 0x7f028467b990>]
../_images/f9f24259f59ae48dd736be57df89e1b176e2c4aa450e77192ec2f594d4986101.png

Il y a beaucoup de “bruit” car nous calculons le loss Ă  chaque fois sur des petits batchs par rapport Ă  l’ensemble des donnĂ©es d’entraĂźnement.

Loss anormalement Ă©levĂ© Ă  l’initialisation#

L’entraĂźnement se dĂ©roule correctement. Cependant, on remarque quelque chose d’étrange : le loss au dĂ©but de l’entraĂźnement est anormalement Ă©levĂ©. On s’attendrait Ă  obtenir une valeur correspondant Ă  un cas oĂč chaque lettre a une probabilitĂ© uniforme d’apparition (soit \(\frac{1}{46}\)).

Dans ce cas, le negative log likelihood serait : \(-ln(\frac{1}{46})=3.83\)

Il serait donc logique d’obtenir une valeur de cet ordre lors du premier calcul du loss.

Petit exemple illustrant le problĂšme#

Pour comprendre ce qui se passe, utilisons un petit exemple et observons les valeurs de loss en fonction de l’initialisation. Imaginons que tous les poids dans logits sont initialisĂ©s Ă  0. Dans ce cas, on obtiendrait des probabilitĂ©s uniformes.

logits=torch.tensor([0.0,0.0,0.0,0.0])
probs=torch.softmax(logits,dim=0)
loss= -probs[1].log()
probs,loss
(tensor([0.2500, 0.2500, 0.2500, 0.2500]), tensor(1.3863))

Cependant, il n’est pas conseillĂ© d’initialiser les poids d’un rĂ©seau de neurones Ă  0. Nous avons utilisĂ© une initialisation alĂ©atoire basĂ©e sur une gaussienne centrĂ©e rĂ©duite.

logits=torch.randn(4)
probs=torch.softmax(logits,dim=0)
loss= -probs[1].log()
probs,loss
(tensor([0.3143, 0.0607, 0.3071, 0.3178]), tensor(2.8012))

On voit rapidement le problĂšme : l’alĂ©atoire de la gaussienne fait pencher la balance d’un cĂŽtĂ© ou de l’autre (vous pouvez lancer plusieurs fois le code prĂ©cĂ©dent pour vous en assurer).

Alors, que peut-on faire ? Il suffit de multiplier notre vecteur logit par une petite valeur pour diminuer la valeur initiale des poids et rendre le softmax plus uniforme.

logits=torch.randn(4)*0.01
probs=torch.softmax(logits,dim=0)
loss= -probs[1].log()
probs,loss
(tensor([0.2489, 0.2523, 0.2495, 0.2493]), tensor(1.3772))

On obtient, Ă  peu de choses prĂšs, le mĂȘme loss que pour des probabilitĂ©s uniformes.

Note : En revanche, on peut initialiser la valeur du biais Ă  zĂ©ro, car cela n’a pas de sens d’avoir un biais positif ou nĂ©gatif Ă  l’initialisation.

Entraünement avec l’ajustement de l’initialisation#

Reprenons le code prĂ©cĂ©dent, mais avec les nouvelles valeurs d’initialisation.

C = torch.randn((46, embed_dim))
W1 = torch.randn((block_size*embed_dim, hidden_dim))*0.01 # On initialise les poids Ă  une petite valeur
b1 = torch.randn(hidden_dim) *0 # On initialise les biais Ă  0
W2 = torch.randn((hidden_dim, 46))*0.01
b2 = torch.randn(46)*0 
parameters = [C, W1, b1, W2, b2]
for p in parameters:
  p.requires_grad = True
lossi = []

for i in range(max_steps):
  ix = torch.randint(0, Xtr.shape[0], (batch_size,))
  Xb, Yb = Xtr[ix], Ytr[ix] 
  emb = C[Xb] 
  embcat = emb.view(emb.shape[0], -1)
  hpreact = embcat @ W1 + b1 
  h = torch.tanh(hpreact) 
  logits = h @ W2 + b2 
  loss = F.cross_entropy(logits, Yb)
  
  for p in parameters:
    p.grad = None
  loss.backward()
  lr = 0.1 if i < 100000 else 0.01 
  for p in parameters:
    p.data += -lr * p.grad
  if i % 10000 == 0:
    print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
  lossi.append(loss.log10().item())
      0/ 200000: 3.8304
  10000/ 200000: 2.4283
  20000/ 200000: 2.0651
  30000/ 200000: 2.1124
  40000/ 200000: 2.3158
  50000/ 200000: 2.2752
  60000/ 200000: 2.1887
  70000/ 200000: 2.1783
  80000/ 200000: 1.8120
  90000/ 200000: 2.3178
 100000/ 200000: 2.0973
 110000/ 200000: 1.8992
 120000/ 200000: 1.6917
 130000/ 200000: 2.2747
 140000/ 200000: 1.8054
 150000/ 200000: 2.3569
 160000/ 200000: 2.4231
 170000/ 200000: 2.0711
 180000/ 200000: 2.1379
 190000/ 200000: 1.8419
plt.plot(lossi)
[<matplotlib.lines.Line2D at 0x7f0278310110>]
../_images/274b51e7627a1763fa99e1b031ad6f827df67595fbd791710bcfa568811a447d.png

Nous avons maintenant une courbe de loss qui ne commence pas Ă  une valeur aberrante, ce qui accĂ©lĂšre l’optimisation.

Autre problĂšme#

On peut penser qu’un loss Ă©levĂ© n’est pas forcĂ©ment un problĂšme. Cependant, une mauvaise initialisation des poids peut poser d’autres problĂšmes.

ConsidĂ©rons la premiĂšre itĂ©ration de l’entraĂźnement avec des valeurs initialisĂ©es sans le facteur 0.01.

C = torch.randn((46, embed_dim))
W1 = torch.randn((block_size*embed_dim, hidden_dim)) 
b1 = torch.randn(hidden_dim) 
W2 = torch.randn((hidden_dim, 46))
b2 = torch.randn(46)
parameters = [C, W1, b1, W2, b2]
for p in parameters:
  p.requires_grad = True
ix = torch.randint(0, Xtr.shape[0], (batch_size,))
Xb, Yb = Xtr[ix], Ytr[ix] 
emb = C[Xb] 
embcat = emb.view(emb.shape[0], -1)
hpreact = embcat @ W1 + b1 
h = torch.tanh(hpreact) 
logits = h @ W2 + b2 
loss = F.cross_entropy(logits, Yb)
  
for p in parameters:
  p.grad = None
loss.backward()

Nous regardons l’histogramme des valeurs aprùs la fonction d’activation tanh.

plt.hist(h.view(-1).tolist(),50);
../_images/26b88fddfbb66837e96c8c3d11a891c516b52dd14c4e49629d1b85562c160f6a.png

On observe que la majorité des valeurs sont autour de 1 ou -1.

En quoi cela pose-t-il problĂšme ? Lors du calcul du gradient, avec la rĂšgle de la chaĂźne, on multiplie les gradients des diffĂ©rentes Ă©tapes de calcul. La dĂ©rivĂ©e de la fonction tanh est : \(tanh'(t)= 1 - t^2\) Si les valeurs de \(t\) sont Ă  1 ou -1, alors le gradient sera extrĂȘmement faible (jamais nul, car c’est une asymptote). Cela signifie que le gradient ne se propage pas, et donc l’optimisation ne peut pas fonctionner de maniĂšre optimale au dĂ©but de l’entraĂźnement.

On peut visualiser les valeurs de chaque neurone.

plt.figure(figsize=(20,10))
plt.imshow(h.abs()>0.99,cmap='gray',interpolation='nearest')
<matplotlib.image.AxesImage at 0x7f02780ae550>
../_images/f1d34dca2ef3cb2f2f05716107320504d6500d05b81000257cf6dd8446104fa3.png

Chaque point blanc correspond à un neurone dont le gradient est à peu prÚs égal à 0.

Neurone mort : Si une de ces colonnes Ă©tait entiĂšrement blanche, cela signifierait que le neurone ne s’active sur aucun Ă©lĂ©ment (du batch). Cela signifie qu’il s’agit d’un neurone inutile, qui n’aura aucun impact sur le rĂ©sultat et qu’on ne peut pas optimiser (sur les valeurs prĂ©sentes dans ce batch).

Notes :

  • Ce type de comportement n’est pas exclusif Ă  la tanh : la sigmoid et la ReLU peuvent avoir le mĂȘme problĂšme.

  • Le problĂšme ne nous a pas empĂȘchĂ© d’entraĂźner notre rĂ©seau correctement, car il s’agit d’un petit modĂšle. Sur des rĂ©seaux plus profonds, c’est un gros problĂšme, et il est conseillĂ© de vĂ©rifier les activations de votre rĂ©seau aux diffĂ©rentes Ă©tapes.

  • Les neurones morts peuvent apparaĂźtre Ă  l’initialisation, mais aussi pendant l’entraĂźnement si le learning rate est trop Ă©levĂ©, par exemple.

Comment résoudre ce problÚme ?#

Heureusement, ce problĂšme peut se rĂ©soudre exactement de la mĂȘme maniĂšre que le problĂšme du loss trop Ă©levĂ©. Pour nous en assurer, regardons les valeurs des activations et les neurones inactifs Ă  l’initialisation avec nos nouvelles valeurs.

C = torch.randn((46, embed_dim))
W1 = torch.randn((block_size*embed_dim, hidden_dim)) *0.01# On initialise les poids Ă  une petite valeur
b1 = torch.randn(hidden_dim) *0 # On initialise les biais Ă  0
W2 = torch.randn((hidden_dim, 46)) *0.01
b2 = torch.randn(46)*0 
parameters = [C, W1, b1, W2, b2]
for p in parameters:
  p.requires_grad = True
ix = torch.randint(0, Xtr.shape[0], (batch_size,))
Xb, Yb = Xtr[ix], Ytr[ix] 
emb = C[Xb] 
embcat = emb.view(emb.shape[0], -1)
hpreact = embcat @ W1 + b1 
h = torch.tanh(hpreact) 
logits = h @ W2 + b2 
loss = F.cross_entropy(logits, Yb)
  
for p in parameters:
  p.grad = None
loss.backward()
plt.hist(h.view(-1).tolist(),50);
../_images/b85d5e3d1ce32df3dc45c1f53b46dbc362ba9084e4cfc3d929e79d59c2dd1b9a.png
plt.figure(figsize=(20,10))
plt.imshow(h.abs()>0.99,cmap='gray',interpolation='nearest')
<matplotlib.image.AxesImage at 0x7f025c538190>
../_images/84c12afa3ac7227e6d395ebef0559cb95007a9789b0fa4665129d62d0e7d813f.png

Tout va pour le mieux !

Valeurs optimales à l’initialisation#

Ce problĂšme Ă©tant trĂšs important, de nombreuses recherches se sont dirigĂ©es sur ce sujet. Une publication notable est Delving Deep into Rectifiers, qui introduit la Kaiming initialization. Le papier propose des valeurs d’initialisation pour chaque fonction d’activation qui garantissent une distribution centrĂ©e rĂ©duite sur l’ensemble du rĂ©seau.

Cette méthode est implémentée en PyTorch, et les couches que nous allons créer en PyTorch sont directement initialisées de cette maniÚre.

Pourquoi ce cours est dans les bonus alors qu’il semble trùs important ?#

Ce problĂšme est en effet un problĂšme majeur. Cependant, lorsque l’on utilise PyTorch, tout est dĂ©jĂ  initialisĂ© correctement, et il n’est gĂ©nĂ©ralement pas nĂ©cessaire de modifier ces valeurs.

De plus, de nombreuses méthodes ont été proposées pour atténuer ce problÚme, principalement :

  • La batch norm, que nous verrons dans le notebook suivant, qui consiste Ă  normaliser les valeurs avant l’activation tout au long du rĂ©seau.

  • Les connexions rĂ©siduelles, qui permettent de transmettre le gradient dans l’intĂ©gralitĂ© du rĂ©seau sans que celui-ci ne soit trop impactĂ© par les fonctions d’activation.

MalgrĂ© l’importance de ces considĂ©rations, en pratique, il n’est pas forcĂ©ment nĂ©cessaire d’ĂȘtre au courant pour entraĂźner un rĂ©seau de neurones.