Activaciones e inicializaciones#

En este curso, retomaremos el modelo Fully Connected presentado en el curso 5 sobre PNL (Procesamiento del Lenguaje Natural). Analizaremos el comportamiento de las activaciones a lo largo de la red durante su inicialización. Este contenido está inspirado en la clase de Andrej Karpathy titulada Building makemore Part 3: Activations & Gradients, BatchNorm.

Las redes neuronales ofrecen varias ventajas:

  • Son muy flexibles y pueden resolver una amplia variedad de problemas.

  • Su implementación es relativamente sencilla.

Sin embargo, su optimización suele ser compleja, especialmente cuando se trata de redes profundas.

Recuperación de código#

Estamos usando el código del cuaderno 3 del curso 5 sobre PNL.

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
words = open('../05_NLP/prenoms.txt', 'r').read().splitlines()
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

Por razones pedagógicas, no utilizaremos el dataset ni el dataloader de PyTorch. En su lugar, evaluaremos la pérdida (loss) al inicio del entrenamiento, tras procesar el primer batch. El procedimiento es similar al estándar, con una diferencia clave: en cada iteración, seleccionamos un batch aleatorio en lugar de recorrer todo el dataset en cada epoch.

block_size = 3 # Contexte

def build_dataset(words):  
  X, Y = [], []
  
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] 

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%
torch.Size([180834, 3]) torch.Size([180834])
torch.Size([22852, 3]) torch.Size([22852])
torch.Size([22639, 3]) torch.Size([22639])
embed_dim=10 # Dimension de l'embedding de C
hidden_dim=200 # Dimension de la couche cachée

C = torch.randn((46, embed_dim))
W1 = torch.randn((block_size*embed_dim, hidden_dim))
b1 = torch.randn(hidden_dim)
W2 = torch.randn((hidden_dim, 46))
b2 = torch.randn(46)
parameters = [C, W1, b1, W2, b2]
for p in parameters:
  p.requires_grad = True
max_steps = 200000
batch_size = 32
lossi = []

for i in range(max_steps):
  # Permet de construire un mini-batch
  ix = torch.randint(0, Xtr.shape[0], (batch_size,))
  
  # Forward
  Xb, Yb = Xtr[ix], Ytr[ix] 
  emb = C[Xb] 
  embcat = emb.view(emb.shape[0], -1)
  hpreact = embcat @ W1 + b1 

  h = torch.tanh(hpreact) 
  logits = h @ W2 + b2 
  loss = F.cross_entropy(logits, Yb)
  
  # Retropropagation
  for p in parameters:
    p.grad = None
  
  loss.backward()
  # Mise à jour des paramètres
  lr = 0.1 if i < 100000 else 0.01 # On descend le learning rate d'un facteur 10 après 100000 itérations
  for p in parameters:
    p.data += -lr * p.grad

  if i % 10000 == 0:
    print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
  lossi.append(loss.log10().item())
      0/ 200000: 21.9772
  10000/ 200000: 2.9991
  20000/ 200000: 2.5258
  30000/ 200000: 1.9657
  40000/ 200000: 2.4326
  50000/ 200000: 1.7670
  60000/ 200000: 2.1324
  70000/ 200000: 2.4160
  80000/ 200000: 2.2237
  90000/ 200000: 2.3905
 100000/ 200000: 1.9304
 110000/ 200000: 2.1710
 120000/ 200000: 2.3444
 130000/ 200000: 2.0970
 140000/ 200000: 1.8623
 150000/ 200000: 1.9792
 160000/ 200000: 2.4602
 170000/ 200000: 2.0968
 180000/ 200000: 2.0466
 190000/ 200000: 2.3746
plt.plot(lossi)
[<matplotlib.lines.Line2D at 0x7f028467b990>]
../_images/f9f24259f59ae48dd736be57df89e1b176e2c4aa450e77192ec2f594d4986101.png

Hay mucho “ruido” en los resultados, ya que calculamos la pérdida (loss) en batches pequeños en comparación con el conjunto completo de datos de entrenamiento.

Pérdida anormalmente alta en la inicialización#

El entrenamiento transcurre correctamente. No obstante, observamos algo anómalo: la pérdida inicial (loss) es excesivamente alta. Lo esperado sería un valor equivalente a un escenario donde cada letra tenga una probabilidad uniforme de aparición (es decir, \(p = \frac{1}{46}\)).

En este caso, el negative log likelihood sería: \(-ln\left(\frac{1}{46}\right) = 3.83\)

Por lo tanto, sería lógico obtener un valor cercano a este en el primer cálculo de loss.

Pequeño ejemplo que ilustra el problema#

Para entender qué está ocurriendo, utilicemos un ejemplo sencillo y analicemos los valores de loss según la inicialización. Supongamos que todos los pesos en logits se inicializan en 0. En este caso, obtendríamos probabilidades uniformes.

logits=torch.tensor([0.0,0.0,0.0,0.0])
probs=torch.softmax(logits,dim=0)
loss= -probs[1].log()
probs,loss
(tensor([0.2500, 0.2500, 0.2500, 0.2500]), tensor(1.3863))

Sin embargo, no es recomendable inicializar los pesos de una red neuronal en 0. En su lugar, hemos utilizado una inicialización aleatoria basada en una distribución gaussiana estándar (media = 0, desviación estándar = 1).

logits=torch.randn(4)
probs=torch.softmax(logits,dim=0)
loss= -probs[1].log()
probs,loss
(tensor([0.3143, 0.0607, 0.3071, 0.3178]), tensor(2.8012))

El problema es evidente: el azar de la distribución gaussiana desequilibra los valores hacia un extremo u otro (puedes ejecutar el código anterior varias veces para verificarlo).

¿Cómo solucionarlo? Basta con multiplicar nuestro vector logit por un valor pequeño (ej. 0.01) para reducir la magnitud inicial de los pesos y lograr que la salida de softmax sea más uniforme.

logits=torch.randn(4)*0.01
probs=torch.softmax(logits,dim=0)
loss= -probs[1].log()
probs,loss
(tensor([0.2489, 0.2523, 0.2495, 0.2493]), tensor(1.3772))

Obtenemos un valor de loss muy cercano al esperado para probabilidades uniformes.

Nota: En cambio, sí es válido inicializar los sesgos (bias) en 0, ya que no tiene sentido asignarles un valor positivo o negativo al inicio.

Entrenamiento con ajuste de inicialización#

Retomemos el código anterior, pero aplicando los nuevos valores de inicialización.

C = torch.randn((46, embed_dim))
W1 = torch.randn((block_size*embed_dim, hidden_dim))*0.01 # On initialise les poids à une petite valeur
b1 = torch.randn(hidden_dim) *0 # On initialise les biais à 0
W2 = torch.randn((hidden_dim, 46))*0.01
b2 = torch.randn(46)*0 
parameters = [C, W1, b1, W2, b2]
for p in parameters:
  p.requires_grad = True
lossi = []

for i in range(max_steps):
  ix = torch.randint(0, Xtr.shape[0], (batch_size,))
  Xb, Yb = Xtr[ix], Ytr[ix] 
  emb = C[Xb] 
  embcat = emb.view(emb.shape[0], -1)
  hpreact = embcat @ W1 + b1 
  h = torch.tanh(hpreact) 
  logits = h @ W2 + b2 
  loss = F.cross_entropy(logits, Yb)
  
  for p in parameters:
    p.grad = None
  loss.backward()
  lr = 0.1 if i < 100000 else 0.01 
  for p in parameters:
    p.data += -lr * p.grad
  if i % 10000 == 0:
    print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
  lossi.append(loss.log10().item())
      0/ 200000: 3.8304
  10000/ 200000: 2.4283
  20000/ 200000: 2.0651
  30000/ 200000: 2.1124
  40000/ 200000: 2.3158
  50000/ 200000: 2.2752
  60000/ 200000: 2.1887
  70000/ 200000: 2.1783
  80000/ 200000: 1.8120
  90000/ 200000: 2.3178
 100000/ 200000: 2.0973
 110000/ 200000: 1.8992
 120000/ 200000: 1.6917
 130000/ 200000: 2.2747
 140000/ 200000: 1.8054
 150000/ 200000: 2.3569
 160000/ 200000: 2.4231
 170000/ 200000: 2.0711
 180000/ 200000: 2.1379
 190000/ 200000: 1.8419
plt.plot(lossi)
[<matplotlib.lines.Line2D at 0x7f0278310110>]
../_images/274b51e7627a1763fa99e1b031ad6f827df67595fbd791710bcfa568811a447d.png

Ahora obtenemos una curva de pérdida (loss) que no parte de un valor anómalo, lo que acelera significativamente el proceso de optimización.

Otro problema#

Podría pensarse que un loss elevado no es necesariamente un problema. Sin embargo, una mala inicialización de los pesos puede generar otros inconvenientes.

Analicemos la primera iteración del entrenamiento con valores inicializados sin el factor 0.01.

C = torch.randn((46, embed_dim))
W1 = torch.randn((block_size*embed_dim, hidden_dim)) 
b1 = torch.randn(hidden_dim) 
W2 = torch.randn((hidden_dim, 46))
b2 = torch.randn(46)
parameters = [C, W1, b1, W2, b2]
for p in parameters:
  p.requires_grad = True
ix = torch.randint(0, Xtr.shape[0], (batch_size,))
Xb, Yb = Xtr[ix], Ytr[ix] 
emb = C[Xb] 
embcat = emb.view(emb.shape[0], -1)
hpreact = embcat @ W1 + b1 
h = torch.tanh(hpreact) 
logits = h @ W2 + b2 
loss = F.cross_entropy(logits, Yb)
  
for p in parameters:
  p.grad = None
loss.backward()

Observamos el histograma de valores después de la función de activación tanh.

plt.hist(h.view(-1).tolist(),50);
../_images/26b88fddfbb66837e96c8c3d11a891c516b52dd14c4e49629d1b85562c160f6a.png

Observamos que la mayoría de los valores se concentran alrededor de 1 o -1.

¿Por qué es esto un problema? Durante el cálculo del gradiente, al aplicar la regla de la cadena, multiplicamos los gradientes de las distintas etapas. La derivada de la función tanh es: \(tanh'(t) = 1 - t^2\)

Si los valores de \(t\) están cerca de 1 o -1, el gradiente será extremadamente pequeño (nunca cero, pues es una asíntota). Esto implica que el gradiente no se propaga correctamente, dificultando la optimización al inicio del entrenamiento.

plt.figure(figsize=(20,10))
plt.imshow(h.abs()>0.99,cmap='gray',interpolation='nearest')
<matplotlib.image.AxesImage at 0x7f02780ae550>
../_images/f1d34dca2ef3cb2f2f05716107320504d6500d05b81000257cf6dd8446104fa3.png

Cada punto blanco representa un neurona cuyo gradiente es aproximadamente 0.

Neurona muerta: Si una de estas columnas estuviera completamente en blanco, significaría que la neurona no se activa para ningún elemento del batch. Esto implica que se trata de una neurona inútil, sin impacto en el resultado y que no puede optimizarse con los valores actuales del batch.

Notas importantes:

  • Este comportamiento no es exclusivo de tanh: funciones como sigmoid y ReLU también pueden presentarlo.

  • El problema no impidió entrenar correctamente nuestra red, ya que se trata de un modelo pequeño. En redes más profundas, es un problema grave, y se recomienda verificar las activaciones en las distintas capas.

  • Las neuronas muertas pueden aparecer durante la inicialización o durante el entrenamiento (por ejemplo, si el learning rate es demasiado alto).

¿Como solucionar este problema?#

Afortunadamente, este problema puede resolverse de la misma manera que el del loss demasiado elevado. Para confirmarlo, analicemos los valores de las activaciones y las neuronas inactivas al inicio, aplicando los nuevos valores de inicialización.

C = torch.randn((46, embed_dim))
W1 = torch.randn((block_size*embed_dim, hidden_dim)) *0.01# On initialise les poids à une petite valeur
b1 = torch.randn(hidden_dim) *0 # On initialise les biais à 0
W2 = torch.randn((hidden_dim, 46)) *0.01
b2 = torch.randn(46)*0 
parameters = [C, W1, b1, W2, b2]
for p in parameters:
  p.requires_grad = True
ix = torch.randint(0, Xtr.shape[0], (batch_size,))
Xb, Yb = Xtr[ix], Ytr[ix] 
emb = C[Xb] 
embcat = emb.view(emb.shape[0], -1)
hpreact = embcat @ W1 + b1 
h = torch.tanh(hpreact) 
logits = h @ W2 + b2 
loss = F.cross_entropy(logits, Yb)
  
for p in parameters:
  p.grad = None
loss.backward()
plt.hist(h.view(-1).tolist(),50);
../_images/b85d5e3d1ce32df3dc45c1f53b46dbc362ba9084e4cfc3d929e79d59c2dd1b9a.png
plt.figure(figsize=(20,10))
plt.imshow(h.abs()>0.99,cmap='gray',interpolation='nearest')
<matplotlib.image.AxesImage at 0x7f025c538190>
../_images/84c12afa3ac7227e6d395ebef0559cb95007a9789b0fa4665129d62d0e7d813f.png

¡Todo funciona correctamente ahora!

Valores óptimos en la inicialización#

Dada la importancia de este problema, numerosas investigaciones se han enfocado en él. Una publicación destacada es Delving Deep into Rectifiers, que introduce la inicialización de Kaiming. Este trabajo propone valores de inicialización específicos para cada función de activación, garantizando una distribución centrada y reducida en toda la red.

Este método está implementado en PyTorch, y las capas que creemos en PyTorch se inicializan automáticamente de esta manera.

¿Por qué este curso está en los bonos cuando parece muy importante?#

Este problema es, efectivamente, un desafío clave. Sin embargo, al usar PyTorch, todos los parámetros ya se inicializan correctamente, por lo que no suele ser necesario modificarlos manualmente.

Además, se han propuesto diversas técnicas para mitigar este problema, entre las que destacan:

  • Batch Normalization (Batch Norm): Normaliza los valores antes de la activación en cada capa de la red (se abordará en el siguiente notebook).

  • Conexiones residuales (Residual Connections): Permiten propagar el gradiente a través de toda la red sin que este se vea afectado en exceso por las funciones de activación.

Aunque estas consideraciones son fundamentales, en la práctica no es estrictamente necesario dominarlas para entrenar una red neuronal, gracias a las optimizaciones ya integradas en frameworks como PyTorch.