Introduction aux autoencodeurs#

Apprentissage supervisé et non supervisé#

Apprentissage supervisé#

Dans les cours prĂ©cĂ©dents, nous avons uniquement traitĂ© des cas d’apprentissage supervisĂ©. En gros, il s’agit des situations oĂč les donnĂ©es d’entraĂźnement contiennent Ă  la fois une entrĂ©e x et une sortie y. Le modĂšle doit alors prendre x en entrĂ©e et prĂ©dire y. Par exemple, pour MNIST, on avait une image x et un label y reprĂ©sentant un chiffre entre 0 et 9. En segmentation, on utilisait une image x et un masque y en sortie.

Apprentissage non supervisé#

En apprentissage non supervisĂ©, les donnĂ©es ne sont pas Ă©tiquetĂ©es, ce qui signifie qu’on a seulement x sans y. Dans ce cas, on ne peut pas prĂ©dire une valeur prĂ©cise, mais on peut entraĂźner un modĂšle Ă  regrouper des Ă©lĂ©ments similaires (on parle de clustering). Dans ce cours, on se concentrera sur la dĂ©tection d’anomalies non supervisĂ©e. L’idĂ©e est d’entraĂźner un modĂšle sur un certain type de donnĂ©es, puis de l’utiliser pour dĂ©tecter des Ă©lĂ©ments qui diffĂšrent du jeu d’entraĂźnement.

Autoencodeur#

Architecture#

Le modĂšle de base pour ce type de tĂąches s’appelle “autoencodeur”. Son architecture ressemble Ă  celle du U-Net que nous avons vu prĂ©cĂ©demment. Voici l’architecture classique d’un autoencodeur :

Comme tu peux le voir, il a une forme de “sablier”. L’idĂ©e de l’autoencodeur est de crĂ©er une reprĂ©sentation compressĂ©e des donnĂ©es d’entrĂ©e et de les reconstruire Ă  partir de cette reprĂ©sentation. D’ailleurs, ce modĂšle peut aussi servir Ă  compresser des donnĂ©es.

Utilisation pour la dĂ©tection d’anomalies non supervisĂ©e#

Pour la dĂ©tection d’anomalies non supervisĂ©e, prenons un exemple. On entraĂźne l’autoencodeur Ă  reconstruire des images du chiffre 5. Une fois entraĂźnĂ©, il reconstruira parfaitement les images de 5. Si on veut dĂ©tecter si une image est un 5 ou un autre chiffre, il suffit de la donner Ă  l’autoencodeur. En analysant la qualitĂ© de la reconstruction (\(image_{base} - image_{recons}\)), on peut dĂ©terminer s’il s’agit d’un 5 ou non. L’image suivante illustre ce principe :

Application pratique sur MNIST#

Pour illustrer ce qui a été décrit, nous allons entraßner un autoencodeur pour reconstruire les 5 avec PyTorch.

import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Pour la reproducibilité
np.random.seed(1337)
random.seed(1337)

CrĂ©ation des datasets d’entraĂźnement et de test#

transform=T.ToTensor() # Pour convertir les éléments en tensor torch directement
dataset = datasets.MNIST(root='./../data', train=True, download=True,transform=transform)
test_dataset = datasets.MNIST(root='./../data', train=False,transform=transform)

Nous avons rĂ©cupĂ©rĂ© nos datasets d’entraĂźnement/validation et de test. On veut ne garder que les 5 dans le dataset d’entraĂźnement. Pour cela, supprimons les Ă©lĂ©ments qui ne contiennent pas le chiffre 5.

# On récupere les indices des images de 5
indices = [i for i, label in enumerate(dataset.targets) if label == 5]
# On créer un nouveau dataset avec uniquement les 5
filtered_dataset = torch.utils.data.Subset(dataset, indices)

On peut visualiser quelques images pour vĂ©rifier qu’on a bien que des 5.

fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for i in range(5):
  image, label = filtered_dataset[i]
  image = image.squeeze().numpy() 
  axes[i].imshow(image, cmap='gray')
  axes[i].set_title(f'Label: {label}')
  axes[i].axis('off')
plt.show()
../_images/5843f2b5282a70bf40e1c6d6ae1781a32c369e083f29e5c6dc33ae9335078ee8.png

Divisons maintenant le dataset en parties d’entraĂźnement et de validation, puis crĂ©ons nos dataloaders.

train_dataset, validation_dataset=torch.utils.data.random_split(filtered_dataset, [0.8,0.2])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader= DataLoader(validation_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Création du modÚle autoencodeur#

Pour le dataset MNIST, une architecture peu profonde suffit pour obtenir de bons résultats.

class ae(nn.Module):
  def __init__(self, *args, **kwargs) -> None:
    super().__init__(*args, **kwargs)

    self.encoder = nn.Sequential( # Sequential permet de groupe une série de transformation
      nn.Linear(28 * 28, 512), 
      nn.ReLU(),
      nn.Linear(512, 256),
      nn.ReLU(),
      nn.Linear(256, 128),
      nn.ReLU(),
    )
    self.decoder = nn.Sequential(
      nn.Linear(128, 256),
      nn.ReLU(),
      nn.Linear(256, 512),
      nn.ReLU(),
      nn.Linear(512, 28 * 28),
      nn.Sigmoid()
    )
  
  def forward(self,x):
    x=x.view(-1,28*28) 
    x = self.encoder(x)
    x = self.decoder(x)
    recons=x.view(-1,28,28)
    return recons
model = ae()
print(model)
print("Nombre de paramĂštres", sum(p.numel() for p in model.parameters()))
ae(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): ReLU()
  )
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=784, bias=True)
    (5): Sigmoid()
  )
)
Nombre de paramĂštres 1132944

EntraĂźnement du modĂšle#

Pour la fonction de perte, nous utilisons MSELoss, qui correspond Ă  l’erreur quadratique moyenne dĂ©finie par : \(\text{MSE} = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2\) oĂč \(N\) est le nombre total de pixels dans l’image, \(y_i\) est la valeur du pixel \(i\) dans l’image originale, et \(\hat{y}_i\) est la valeur du pixel \(i\) dans l’image reconstruite. C’est une fonction classique pour Ă©valuer la qualitĂ© d’une reconstruction.

criterion = nn.MSELoss()
epochs=10
learning_rate=0.001
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate)
for i in range(epochs):
  loss_train=0
  for images, _ in train_loader:
    recons=model(images)
    loss=criterion(recons,images)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_train+=loss   
  if i % 1 == 0:
    print(f"step {i} train loss {loss_train/len(train_loader)}")
  loss_val=0    
  for images, _ in val_loader:
    with torch.no_grad():
      recons=model(images)
      loss=criterion(recons,images)
      loss_val+=loss 
  if i % 1 == 0:
    print(f"step {i} val loss {loss_val/len(val_loader)}")
step 0 train loss 0.08228749781847
step 0 val loss 0.06261523813009262
step 1 train loss 0.06122465804219246
step 1 val loss 0.06214689463376999
step 2 train loss 0.06105153635144234
step 2 val loss 0.06189680099487305
step 3 train loss 0.06086035445332527
step 3 val loss 0.06180128455162048
step 4 train loss 0.0608210563659668
step 4 val loss 0.06169722229242325
step 5 train loss 0.06080913543701172
step 5 val loss 0.061976321041584015
step 6 train loss 0.060783520340919495
step 6 val loss 0.06190618872642517
step 7 train loss 0.06072703003883362
step 7 val loss 0.06161761283874512
step 8 train loss 0.06068740040063858
step 8 val loss 0.061624933034181595
step 9 train loss 0.060728199779987335
step 9 val loss 0.061608292162418365

Regardons maintenant la reconstruction des images du dataset de test.

images,_=next(iter(test_loader))

#Isolons un élément 
image=images[0].unsqueeze(0)
with torch.no_grad():
  recons=model(image)
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Image d'origine
axs[0].imshow(image[0].squeeze().cpu().numpy(), cmap='gray')
axs[0].set_title('Image d\'origine')
axs[0].axis('off')

# Image reconstruite
axs[1].imshow(recons[0].squeeze().cpu().numpy(), cmap='gray')
axs[1].set_title('Image reconstruite')
axs[1].axis('off')
plt.show()
print("difference : ", criterion(image,recons).item())
../_images/92c5423a04ea1f5b3208faa2ae8aca313f88d7f17512079a1aacdcc9e9c54bbd.png
difference :  0.0687035545706749

On remarque que la reconstruction du 7 est trĂšs mauvaise, ce qui permet de dĂ©duire qu’il s’agit d’une anomalie.