Implémentation d’un modèle de diffusion#
Dans ce cours, nous allons implémenter étape par étape un modèle de diffusion sur le dataset MNIST. Ce cours s’inspire largement du projet minDiffusion.
from typing import Tuple
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as T
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
/home/aquilae/anaconda3/envs/dev/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Préparation du dataset#
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (1.0))])
dataset = MNIST("./../data",train=True,download=True,transform=transform)
print('Taille du dataset :', len(dataset))
print('Taille d\'une image :', dataset[0][0].numpy().shape)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=8)
Taille du dataset : 60000
Taille d'une image : (1, 28, 28)
Planification du bruit#
Pour un modèle de diffusion, il est nécessaire de générer du bruit à chaque étape. Nous allons donc créer des plages de valeurs de bruit comprises entre \(\beta_1\) et \(\beta_2\).
def ddpm_schedules(beta1: float, beta2: float, T: int):
# On vérifie que beta1 et beta2 sont bien dans l'intervalle (0, 1) et que beta1 < beta2
assert beta1 < beta2 < 1.0, "beta1 et beta2 doivent être dans l'intervalle (0, 1)"
# On crée un vecteur de taille T+1 allant de beta1 à beta2 qui échantillonne linéairement l'intervalle [beta1, beta2]
beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
# On calcule toutes les valeurs qui seront nécessaires pour les calculs de l'optimisation
sqrt_beta_t = torch.sqrt(beta_t)
alpha_t = 1 - beta_t
log_alpha_t = torch.log(alpha_t)
alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()
sqrtab = torch.sqrt(alphabar_t)
oneover_sqrta = 1 / torch.sqrt(alpha_t)
sqrtmab = torch.sqrt(1 - alphabar_t)
mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab
return {
"alpha_t": alpha_t, # \alpha_t
"oneover_sqrta": oneover_sqrta, # 1/\sqrt{\alpha_t}
"sqrt_beta_t": sqrt_beta_t, # \sqrt{\beta_t}
"alphabar_t": alphabar_t, # \bar{\alpha_t}
"sqrtab": sqrtab, # \sqrt{\bar{\alpha_t}}
"sqrtmab": sqrtmab, # \sqrt{1-\bar{\alpha_t}}
"mab_over_sqrtmab": mab_over_sqrtmab_inv, # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
}
Architecture du modèle#
Nous allons maintenant créer notre modèle. En général, on utilise une architecture U-Net, mais en pratique, d’autres architectures peuvent convenir. Pour simplifier, nous allons utiliser un modèle convolutif basique. Notez que normalement, le modèle prend l’étape \(t\) en entrée, mais dans notre version simplifiée, nous ne le ferons pas.
def conv_bn_relu(in_channels, out_channels,kernel_size=7, stride=1, padding=3):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size,stride, padding),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU())
class model(nn.Module):
def __init__(self, channels: int) -> None:
super(model, self).__init__()
# Très petit modèle
self.conv = nn.Sequential(
conv_bn_relu(channels, 64),
conv_bn_relu(64, 128),
conv_bn_relu(128, 256),
conv_bn_relu(256, 512),
conv_bn_relu(512, 256),
conv_bn_relu(256, 128),
conv_bn_relu(128, 64),
nn.Conv2d(64, channels, 3, padding=1),
)
def forward(self, x):
return self.conv(x)
Maintenant que nous avons notre modèle pour passer d’une étape à une autre, construisons le modèle global :
class DDPM(nn.Module):
def __init__(self,model: nn.Module,betas: Tuple[float, float],n_T: int,criterion: nn.Module = nn.MSELoss()) -> None:
super(DDPM, self).__init__()
self.model = model
# Permet de stocker les ddpm schedules dans le modèle pour accéder aux valeurs plus facilement
for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
self.register_buffer(k, v)
self.n_T = n_T
self.criterion = criterion
# Etape d'entrainement
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Makes forward diffusion x_t, and tries to guess epsilon value from x_t using eps_model.
This implements Algorithm 1 in the paper.
"""
# Génère un entier aléatoire entre 1 et n_T pour choisir un t aléatoire
_ts = torch.randint(1, self.n_T, (x.shape[0],)).to(x.device)
# Génère un bruit aléatoire de la même taille que x
eps = torch.randn_like(x) # eps ~ N(0, 1)
# On applique le bruit gaussien à x pour obtenir x_t
x_t = (self.sqrtab[_ts, None, None, None] * x+ self.sqrtmab[_ts, None, None, None] * eps)
# On va essayer de prédire le bruit epsilon à partir de x_t
pred_eps = self.model(x_t)
return self.criterion(eps, pred_eps)
# Génération d'un échantillon
def sample(self, n_sample: int, size, device) -> torch.Tensor:
# On génère un échantillon aléatoire de taille n_sample à partir d'une distribution normale centrée réduite
x_i = torch.randn(n_sample, *size).to(device) # x_T ~ N(0, 1)
# On va appliquer le processus de diffusion inverse pour générer un échantillon (ça prend du temps
# car on doit appliquer le processus de diffusion à chaque étape)
for i in range(self.n_T, 0, -1):
z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
eps = self.model(x_i)
x_i = (self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])+ self.sqrt_beta_t[i] * z)
return x_i
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ddpm = DDPM(model=model(channels=1), betas=(1e-4, 0.02), n_T=1000)
ddpm.to(device);
Et voilà, notre implémentation simple est terminée !
Entraînement du modèle#
Passons maintenant à l’entraînement du modèle. Nous allons réduire considérablement la taille du dataset (seulement 1000 éléments), mais l’entraînement restera très long. Je ne vous conseille pas d’essayer sauf si vous disposez d’un très bon GPU.
epoch=100
n_T=1000
optimizer = torch.optim.Adam(ddpm.parameters(), lr=2e-4)
generation=[]
for i in range(0,epoch+1):
ddpm.train()
loss_ema = None
for x, _ in dataloader:
optimizer.zero_grad()
x = x.to(device)
loss = ddpm(x)
loss.backward()
if loss_ema is None:
loss_ema = loss.item()
else:
loss_ema = 0.9 * loss_ema + 0.1 * loss.item()
optimizer.step()
print(f"epoch {i}, loss: {loss_ema:.4f}")
if i % 10 == 0:
ddpm.eval()
with torch.no_grad():
print('ici')
xh = ddpm.sample(4, (1, 28, 28), device)
grid = make_grid(xh, nrow=4)
generation.append(grid)
for grid in generation:
grid_image = grid.permute(1, 2, 0).cpu().numpy()
# Afficher l'image
plt.imshow(grid_image)
plt.axis('off') # Pour masquer les axes
plt.show()