Implémentation d’un modèle de diffusion#

Dans ce cours, nous allons implémenter étape par étape un modèle de diffusion sur le dataset MNIST. Ce cours s’inspire largement du projet minDiffusion.

from typing import Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms as T
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
/home/aquilae/anaconda3/envs/dev/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Préparation du dataset#

transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (1.0))])

dataset = MNIST("./../data",train=True,download=True,transform=transform)
print('Taille du dataset :', len(dataset))
print('Taille d\'une image :', dataset[0][0].numpy().shape)

dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=8)
Taille du dataset : 60000
Taille d'une image : (1, 28, 28)

Planification du bruit#

Pour un modèle de diffusion, il est nécessaire de générer du bruit à chaque étape. Nous allons donc créer des plages de valeurs de bruit comprises entre \(\beta_1\) et \(\beta_2\).

def ddpm_schedules(beta1: float, beta2: float, T: int):

  # On vérifie que beta1 et beta2 sont bien dans l'intervalle (0, 1) et que beta1 < beta2
  assert beta1 < beta2 < 1.0, "beta1 et beta2 doivent être dans l'intervalle (0, 1)"

  # On crée un vecteur de taille T+1 allant de beta1 à beta2 qui échantillonne linéairement l'intervalle [beta1, beta2]
  beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
  
  # On calcule toutes les valeurs qui seront nécessaires pour les calculs de l'optimisation
  sqrt_beta_t = torch.sqrt(beta_t)
  alpha_t = 1 - beta_t
  log_alpha_t = torch.log(alpha_t)
  alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

  sqrtab = torch.sqrt(alphabar_t)
  oneover_sqrta = 1 / torch.sqrt(alpha_t)

  sqrtmab = torch.sqrt(1 - alphabar_t)
  mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab

  return {
    "alpha_t": alpha_t,  # \alpha_t
    "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
    "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
    "alphabar_t": alphabar_t,  # \bar{\alpha_t}
    "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
    "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
    "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
  }

Architecture du modèle#

Nous allons maintenant créer notre modèle. En général, on utilise une architecture U-Net, mais en pratique, d’autres architectures peuvent convenir. Pour simplifier, nous allons utiliser un modèle convolutif basique. Notez que normalement, le modèle prend l’étape \(t\) en entrée, mais dans notre version simplifiée, nous ne le ferons pas.

def conv_bn_relu(in_channels, out_channels,kernel_size=7, stride=1, padding=3):
  return nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel_size,stride, padding),
    nn.BatchNorm2d(out_channels),
    nn.LeakyReLU())

class model(nn.Module):

  def __init__(self, channels: int) -> None:
    super(model, self).__init__()
    # Très petit modèle
    self.conv = nn.Sequential(  
      conv_bn_relu(channels, 64),
      conv_bn_relu(64, 128),
      conv_bn_relu(128, 256),
      conv_bn_relu(256, 512),
      conv_bn_relu(512, 256),
      conv_bn_relu(256, 128),
      conv_bn_relu(128, 64),
      nn.Conv2d(64, channels, 3, padding=1),
    )

  def forward(self, x):
    return self.conv(x)

Maintenant que nous avons notre modèle pour passer d’une étape à une autre, construisons le modèle global :

class DDPM(nn.Module):
    def __init__(self,model: nn.Module,betas: Tuple[float, float],n_T: int,criterion: nn.Module = nn.MSELoss()) -> None:
        super(DDPM, self).__init__()
        self.model = model

        # Permet de stocker les ddpm schedules dans le modèle pour accéder aux valeurs plus facilement
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.criterion = criterion

    # Etape d'entrainement
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Makes forward diffusion x_t, and tries to guess epsilon value from x_t using eps_model.
        This implements Algorithm 1 in the paper.
        """
        
        # Génère un entier aléatoire entre 1 et n_T pour choisir un t aléatoire
        _ts = torch.randint(1, self.n_T, (x.shape[0],)).to(x.device) 
        # Génère un bruit aléatoire de la même taille que x
        eps = torch.randn_like(x)  # eps ~ N(0, 1)

            # On applique le bruit gaussien à x pour obtenir x_t
        x_t = (self.sqrtab[_ts, None, None, None] * x+ self.sqrtmab[_ts, None, None, None] * eps)  
        # On va essayer de prédire le bruit epsilon à partir de x_t
        pred_eps = self.model(x_t)
        return self.criterion(eps, pred_eps)

    # Génération d'un échantillon
    def sample(self, n_sample: int, size, device) -> torch.Tensor:
        
        # On génère un échantillon aléatoire de taille n_sample à partir d'une distribution normale centrée réduite
        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1)

        # On va appliquer le processus de diffusion inverse pour générer un échantillon (ça prend du temps 
        # car on doit appliquer le processus de diffusion à chaque étape)
        for i in range(self.n_T, 0, -1):
            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
            eps = self.model(x_i)
            x_i = (self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])+ self.sqrt_beta_t[i] * z)

        return x_i

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ddpm = DDPM(model=model(channels=1), betas=(1e-4, 0.02), n_T=1000)
ddpm.to(device);

Et voilà, notre implémentation simple est terminée !

Entraînement du modèle#

Passons maintenant à l’entraînement du modèle. Nous allons réduire considérablement la taille du dataset (seulement 1000 éléments), mais l’entraînement restera très long. Je ne vous conseille pas d’essayer sauf si vous disposez d’un très bon GPU.

epoch=100
n_T=1000
optimizer = torch.optim.Adam(ddpm.parameters(), lr=2e-4)
generation=[]
for i in range(0,epoch+1):
  ddpm.train()
  loss_ema = None
  for x, _ in dataloader:
    optimizer.zero_grad()
    x = x.to(device)
    loss = ddpm(x)
    loss.backward()
    if loss_ema is None:
      loss_ema = loss.item()
    else:
      loss_ema = 0.9 * loss_ema + 0.1 * loss.item()
    optimizer.step()
  print(f"epoch {i}, loss: {loss_ema:.4f}")
  if i % 10 == 0:
    ddpm.eval()
    with torch.no_grad():
      print('ici')
      xh = ddpm.sample(4, (1, 28, 28), device)
      grid = make_grid(xh, nrow=4)
      generation.append(grid)      
for grid in generation:
  grid_image = grid.permute(1, 2, 0).cpu().numpy()
  # Afficher l'image
  plt.imshow(grid_image)
  plt.axis('off')  # Pour masquer les axes
  plt.show()