Implementation Diffusion Model¶
Dans ce cours, nous allons implémenter pas à pas un modèle de diffusion sur le dataset MNIST.
Le cours est grandement inspiré du github minDiffusion.
from typing import Tuple
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as T
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
/home/aquilae/anaconda3/envs/dev/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
Dataset¶
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (1.0))])
dataset = MNIST("./../data",train=True,download=True,transform=transform)
print('Taille du dataset :', len(dataset))
print('Taille d\'une image :', dataset[0][0].numpy().shape)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=8)
Taille du dataset : 60000 Taille d'une image : (1, 28, 28)
Noise scheduling¶
Pour un modèle de diffusion, il faut générer du bruit pour chaque étape de diffusion. Pour cela, on va créer nos plages de valeurs de bruit comprises entre $\beta_1$ et $\beta_2$.
def ddpm_schedules(beta1: float, beta2: float, T: int):
# On vérifie que beta1 et beta2 sont bien dans l'intervalle (0, 1) et que beta1 < beta2
assert beta1 < beta2 < 1.0, "beta1 et beta2 doivent être dans l'intervalle (0, 1)"
# On crée un vecteur de taille T+1 allant de beta1 à beta2 qui échantillonne linéairement l'intervalle [beta1, beta2]
beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
# On calcule toutes les valeurs qui seront nécessaires pour les calculs de l'optimisation
sqrt_beta_t = torch.sqrt(beta_t)
alpha_t = 1 - beta_t
log_alpha_t = torch.log(alpha_t)
alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()
sqrtab = torch.sqrt(alphabar_t)
oneover_sqrta = 1 / torch.sqrt(alpha_t)
sqrtmab = torch.sqrt(1 - alphabar_t)
mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab
return {
"alpha_t": alpha_t, # \alpha_t
"oneover_sqrta": oneover_sqrta, # 1/\sqrt{\alpha_t}
"sqrt_beta_t": sqrt_beta_t, # \sqrt{\beta_t}
"alphabar_t": alphabar_t, # \bar{\alpha_t}
"sqrtab": sqrtab, # \sqrt{\bar{\alpha_t}}
"sqrtmab": sqrtmab, # \sqrt{1-\bar{\alpha_t}}
"mab_over_sqrtmab": mab_over_sqrtmab_inv, # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
}
Modèle¶
On peut maintenant créer notre modèle ! En général, on prend une architecture U-Net mais en pratique, on peut prendre un peu n'importe quoi. Pour plus de simplicité, on va prendre un modèle convolutif tout bête. Aussi, normalement le modèle prend l'étape $t$ en entrée mais dans notre modèle simplifié, nous n'allons pas le faire.
def conv_bn_relu(in_channels, out_channels,kernel_size=7, stride=1, padding=3):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size,stride, padding),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU())
class model(nn.Module):
def __init__(self, channels: int) -> None:
super(model, self).__init__()
# Très petit modèle
self.conv = nn.Sequential(
conv_bn_relu(channels, 64),
conv_bn_relu(64, 128),
conv_bn_relu(128, 256),
conv_bn_relu(256, 512),
conv_bn_relu(512, 256),
conv_bn_relu(256, 128),
conv_bn_relu(128, 64),
nn.Conv2d(64, channels, 3, padding=1),
)
def forward(self, x):
return self.conv(x)
Maintenant que l'on a notre modèle pour passer d'une étape à une autre, construisons le modèle global :
class DDPM(nn.Module):
def __init__(self,model: nn.Module,betas: Tuple[float, float],n_T: int,criterion: nn.Module = nn.MSELoss()) -> None:
super(DDPM, self).__init__()
self.model = model
# Permet de stocker les ddpm schedules dans le modèle pour accéder aux valeurs plus facilement
for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
self.register_buffer(k, v)
self.n_T = n_T
self.criterion = criterion
# Etape d'entrainement
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Makes forward diffusion x_t, and tries to guess epsilon value from x_t using eps_model.
This implements Algorithm 1 in the paper.
"""
# Génère un entier aléatoire entre 1 et n_T pour choisir un t aléatoire
_ts = torch.randint(1, self.n_T, (x.shape[0],)).to(x.device)
# Génère un bruit aléatoire de la même taille que x
eps = torch.randn_like(x) # eps ~ N(0, 1)
# On applique le bruit gaussien à x pour obtenir x_t
x_t = (self.sqrtab[_ts, None, None, None] * x+ self.sqrtmab[_ts, None, None, None] * eps)
# On va essayer de prédire le bruit epsilon à partir de x_t
pred_eps = self.model(x_t)
return self.criterion(eps, pred_eps)
# Génération d'un échantillon
def sample(self, n_sample: int, size, device) -> torch.Tensor:
# On génère un échantillon aléatoire de taille n_sample à partir d'une distribution normale centrée réduite
x_i = torch.randn(n_sample, *size).to(device) # x_T ~ N(0, 1)
# On va appliquer le processus de diffusion inverse pour générer un échantillon (ça prend du temps
# car on doit appliquer le processus de diffusion à chaque étape)
for i in range(self.n_T, 0, -1):
z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
eps = self.model(x_i)
x_i = (self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])+ self.sqrt_beta_t[i] * z)
return x_i
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ddpm = DDPM(model=model(channels=1), betas=(1e-4, 0.02), n_T=1000)
ddpm.to(device);
Et voilà, notre implémentation simple est terminée !
Entrainement¶
Passons maintenant à l'entrainement du modèle. On va diminuer grandement la taille du dataset (seulement 1000 éléments) mais l'entraînement sera quand même très long. Je ne vous conseille pas d'essayer sauf si vous avez un très bon GPU à disposition.
epoch=100
n_T=1000
optimizer = torch.optim.Adam(ddpm.parameters(), lr=2e-4)
generation=[]
for i in range(0,epoch+1):
ddpm.train()
loss_ema = None
for x, _ in dataloader:
optimizer.zero_grad()
x = x.to(device)
loss = ddpm(x)
loss.backward()
if loss_ema is None:
loss_ema = loss.item()
else:
loss_ema = 0.9 * loss_ema + 0.1 * loss.item()
optimizer.step()
print(f"epoch {i}, loss: {loss_ema:.4f}")
if i % 10 == 0:
ddpm.eval()
with torch.no_grad():
print('ici')
xh = ddpm.sample(4, (1, 28, 28), device)
grid = make_grid(xh, nrow=4)
generation.append(grid)
for grid in generation:
grid_image = grid.permute(1, 2, 0).cpu().numpy()
# Afficher l'image
plt.imshow(grid_image)
plt.axis('off') # Pour masquer les axes
plt.show()