Implementing a Diffusion Model

Implementing a Diffusion Model

In this course, we will implement a diffusion model (DDPM) step by step on the MNIST dataset. This course is largely inspired by the minDiffusion project.

from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms as T
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
/home/aquilae/anaconda3/envs/dev/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Dataset Preparation

# ToTensor maps pixels to [0, 1]; Normalize then subtracts 0.5 with std 1.0,
# so training images lie in [-0.5, 0.5]. std is written as a 1-tuple, like mean.
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (1.0,))])

dataset = MNIST("./../data",train=True,download=True,transform=transform)
print('Taille du dataset :', len(dataset))
print('Taille d\'une image :', dataset[0][0].numpy().shape)

# Shuffled mini-batches of 128 images, loaded by 8 worker processes.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=8)
Taille du dataset : 60000
Taille d'une image : (1, 28, 28)

Noise Scheduling

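A diffusion model adds Gaussian noise to the data at each step of the forward process. We therefore define a noise schedule: the variances \(\beta_t\) increase linearly from \(\beta_1\) to \(\beta_2\) over \(T\) steps. We also precompute the derived quantities used for training and sampling, such as \(\alpha_t = 1 - \beta_t\) and \(\bar{\alpha}_t = \prod_{s \le t} \alpha_s\).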

def ddpm_schedules(beta1: float, beta2: float, T: int):
  """Precompute the linear DDPM noise schedule and its derived quantities.

  beta_t is spaced linearly from beta1 (index 0) to beta2 (index T), so every
  returned tensor has T + 1 entries and timesteps 1..T can be indexed directly.

  Args:
    beta1: First variance of the schedule; must satisfy 0 < beta1.
    beta2: Last variance of the schedule; must satisfy beta1 < beta2 < 1.
    T: Number of diffusion steps; must be >= 1.

  Returns:
    Dict of float32 tensors of shape (T + 1,), keyed as documented inline.

  Raises:
    AssertionError: if the betas are not ordered inside (0, 1) or T < 1.
  """
  # beta1 must be strictly positive: with beta1 == 0, sqrt(1 - alphabar_0) == 0
  # and mab_over_sqrtmab[0] becomes 0/0 (NaN).
  assert 0.0 < beta1 < beta2 < 1.0, "beta1 et beta2 doivent être dans l'intervalle (0, 1)"
  # T is used as a divisor below; T == 0 would silently yield NaNs.
  assert T >= 1, "T doit être un entier >= 1"

  # T + 1 linearly spaced values: beta_t[0] == beta1, beta_t[T] == beta2.
  beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1

  sqrt_beta_t = torch.sqrt(beta_t)
  alpha_t = 1 - beta_t
  # Cumulative product of alpha_t, computed as exp(cumsum(log(alpha_t))).
  log_alpha_t = torch.log(alpha_t)
  alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

  sqrtab = torch.sqrt(alphabar_t)
  oneover_sqrta = 1 / torch.sqrt(alpha_t)

  sqrtmab = torch.sqrt(1 - alphabar_t)
  mab_over_sqrtmab = (1 - alpha_t) / sqrtmab

  return {
    "alpha_t": alpha_t,  # \alpha_t
    "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
    "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
    "alphabar_t": alphabar_t,  # \bar{\alpha_t}
    "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
    "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
    "mab_over_sqrtmab": mab_over_sqrtmab,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
  }

Model Architecture

We will now create the noise-prediction network. A U-Net architecture is usually used, but other architectures can also work; for simplicity, we use a plain stack of convolutional layers. Normally the network also receives the timestep \(t\) as input. Our simplified version omits it, so the same network must predict the noise at every step.

def conv_bn_relu(in_channels, out_channels, kernel_size=7, stride=1, padding=3):
  """Return a Conv2d -> BatchNorm2d -> LeakyReLU stage.

  With the defaults (kernel 7, stride 1, padding 3) the spatial size is preserved.
  Note that the activation is a LeakyReLU despite the function name.
  """
  stage = [
    nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
    nn.BatchNorm2d(num_features=out_channels),
    nn.LeakyReLU(),
  ]
  return nn.Sequential(*stage)

class model(nn.Module):
  """Plain convolutional noise predictor (not conditioned on the timestep).

  Seven conv_bn_relu stages widen the channels 64 -> 128 -> 256 -> 512 and then
  narrow them back 256 -> 128 -> 64. A final 3x3 convolution projects back to
  `channels`. Every convolution keeps the spatial size, so the output has the
  same shape as the input.
  """

  def __init__(self, channels: int) -> None:
    super().__init__()
    widths = [channels, 64, 128, 256, 512, 256, 128, 64]
    # Modules are created in the same order as the layer listing, so the
    # Sequential indices (0..7) and parameter names are unchanged.
    layers = [conv_bn_relu(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
    layers.append(nn.Conv2d(widths[-1], channels, 3, padding=1))
    self.conv = nn.Sequential(*layers)

  def forward(self, x):
    return self.conv(x)

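Now that we have a network that predicts the noise \(\epsilon\) contained in a noisy image, let’s build the overall DDPM model. It handles the forward noising process during training and the reverse (denoising) process during sampling: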

class DDPM(nn.Module):
    """Denoising diffusion probabilistic model around a noise-prediction network.

    The tensors returned by ``ddpm_schedules`` are registered as buffers, so they
    follow ``.to(device)``, and are indexed by timestep t in 1..n_T.

    Args:
        model: Network mapping a noisy batch x_t to a prediction of the noise eps.
            In this implementation it is NOT given the timestep t.
        betas: (beta1, beta2) bounds of the linear noise schedule.
        n_T: Number of diffusion steps T.
        criterion: Loss called as ``criterion(prediction, target)``. Defaults to a
            fresh ``nn.MSELoss()`` per instance.
    """

    def __init__(
        self,
        model: nn.Module,
        betas: Tuple[float, float],
        n_T: int,
        criterion: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.model = model

        # Store every schedule tensor as a buffer so it moves with the module.
        for name, tensor in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(name, tensor)

        self.n_T = n_T
        # Build the default loss per instance instead of sharing one default object.
        self.criterion = criterion if criterion is not None else nn.MSELoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the training loss for a batch (Algorithm 1 of the DDPM paper).

        A timestep t is drawn uniformly from {1, ..., n_T} per example. The batch is
        noised as x_t = sqrt(alphabar_t) * x + sqrt(1 - alphabar_t) * eps, and the
        loss between model(x_t) and eps is returned. The `[_ts, None, None, None]`
        indexing assumes x is a 4-D (batch, channels, height, width) tensor.
        """
        # randint's upper bound is exclusive: n_T + 1 makes t = n_T reachable,
        # matching the timesteps visited by `sample`.
        _ts = torch.randint(1, self.n_T + 1, (x.shape[0],), device=x.device)
        eps = torch.randn_like(x)  # eps ~ N(0, I)

        x_t = self.sqrtab[_ts, None, None, None] * x + self.sqrtmab[_ts, None, None, None] * eps
        pred_eps = self.model(x_t)
        # (prediction, target) order, as expected by PyTorch loss modules.
        return self.criterion(pred_eps, eps)

    def sample(self, n_sample: int, size, device) -> torch.Tensor:
        """Generate `n_sample` images by running the reverse process from t = n_T to 1.

        Args:
            n_sample: Number of images to generate.
            size: Per-sample shape, e.g. (1, 28, 28).
            device: Device on which the noise is allocated.

        Returns:
            Tensor of shape (n_sample, *size).

        Gradient tracking is not disabled here; call it under torch.no_grad()
        for inference, since every one of the n_T steps runs the network.
        """
        x_i = torch.randn(n_sample, *size, device=device)  # x_T ~ N(0, I)

        for i in range(self.n_T, 0, -1):
            # No fresh noise is injected on the final step (t = 1).
            z = torch.randn(n_sample, *size, device=device) if i > 1 else 0
            eps = self.model(x_i)
            x_i = self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z

        return x_i

# Prefer the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Single-channel (MNIST) noise predictor inside a 1000-step linear-schedule DDPM.
# nn.Module.to returns the module itself, so the assignment keeps the same object.
ddpm = DDPM(model=model(channels=1), betas=(1e-4, 0.02), n_T=1000).to(device)

And that’s it, our simple implementation is complete!

Model Training

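Now let’s move on to training the model. Training iterates over the full MNIST training set (60,000 images) for many epochs (`epoch = 100`), and each sampling pass runs all 1000 denoising steps, so it is very long. I do not recommend running it unless you have a very good GPU. To speed things up, you can train on a subset of the data, e.g. `torch.utils.data.Subset(dataset, range(1000))`.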

epoch=100
n_T=1000
optimizer = torch.optim.Adam(ddpm.parameters(), lr=2e-4)
generation=[]
# range(0, epoch + 1) runs epoch + 1 passes (0..epoch), so a final sample is
# taken at epoch `epoch`; a 4-image grid is generated every 10 epochs.
for i in range(0,epoch+1):
  ddpm.train()
  loss_ema = None
  for x, _ in dataloader:
    optimizer.zero_grad()
    x = x.to(device)
    loss = ddpm(x)
    loss.backward()
    # Exponential moving average of the loss, used only for logging.
    if loss_ema is None:
      loss_ema = loss.item()
    else:
      loss_ema = 0.9 * loss_ema + 0.1 * loss.item()
    optimizer.step()
  print(f"epoch {i}, loss: {loss_ema:.4f}")
  if i % 10 == 0:
    ddpm.eval()
    with torch.no_grad():
      xh = ddpm.sample(4, (1, 28, 28), device)
      # The model is trained on images normalized to [-0.5, 0.5]; rescale the
      # samples to [0, 1] for display instead of letting imshow clip negatives.
      grid = make_grid(xh, nrow=4, normalize=True, value_range=(-0.5, 0.5))
      # Keep stored grids on the CPU so they do not hold GPU memory.
      generation.append(grid.cpu())
for k, grid in enumerate(generation):
  # (C, H, W) -> (H, W, C) as expected by matplotlib.
  grid_image = grid.permute(1, 2, 0).numpy()
  # Display the image
  plt.imshow(grid_image)
  plt.title(f"epoch {10 * k}")
  plt.axis('off')  # Hide the axes
  plt.show()