Implementing a Diffusion Model#
In this course, we will implement a diffusion model on the MNIST dataset, step by step. The course is largely inspired by the minDiffusion project.
from typing import Tuple
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as T
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
/home/aquilae/anaconda3/envs/dev/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Dataset Preparation#
# Convert images to tensors, then shift them by a mean of 0.5 (std 1.0) on the
# single MNIST channel. Both statistics are passed as one-element tuples; the
# original `(1.0)` was a plain float that only worked through broadcasting.
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (1.0,))])
# Download the MNIST training split (if needed) into ./../data.
dataset = MNIST("./../data", train=True, download=True, transform=transform)
print('Taille du dataset :', len(dataset))
print('Taille d\'une image :', dataset[0][0].numpy().shape)
# Shuffled mini-batches of 128 images, loaded by 8 worker processes.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=8)
Taille du dataset : 60000
Taille d'une image : (1, 28, 28)
Noise Scheduling#
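A diffusion model adds noise at every step, and the amount of noise added at step \(t\) is controlled by \(\beta_t\). We therefore build a noise schedule: \(T+1\) values of \(\beta_t\) spaced linearly from \(\beta_1\) to \(\beta_2\), together with the derived quantities needed for training and sampling.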
def ddpm_schedules(beta1: float, beta2: float, T: int):
    """Pre-compute the DDPM noise schedule and the quantities derived from it.

    ``beta_t`` is sampled linearly on ``[beta1, beta2]`` at ``T + 1`` points
    (indices ``0..T``); every returned tensor is 1-D float32 of length ``T + 1``.

    Args:
        beta1: First value of the schedule; must satisfy ``0 < beta1 < beta2``.
        beta2: Last value of the schedule; must be strictly below 1.
        T: Number of diffusion steps; must be positive.

    Returns:
        A dict mapping each schedule name to its tensor:
        ``alpha_t``, ``oneover_sqrta``, ``sqrt_beta_t``, ``alphabar_t``,
        ``sqrtab``, ``sqrtmab`` and ``mab_over_sqrtmab``.

    Raises:
        AssertionError: If the betas are not ordered inside (0, 1) or if
            ``T`` is not positive.
    """
    # The lower bound is checked too: beta1 <= 0 would produce NaNs below
    # (sqrt of a negative beta, or 0/0 in mab_over_sqrtmab at index 0).
    assert 0.0 < beta1 < beta2 < 1.0, "beta1 et beta2 doivent être dans l'intervalle (0, 1)"
    # T is used as a divisor when building the linear ramp.
    assert T > 0, "T must be a positive number of diffusion steps"

    # Vector of size T+1 sampling [beta1, beta2] linearly.
    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1

    # Quantities reused by the training objective and the sampling loop.
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    log_alpha_t = torch.log(alpha_t)
    # Cumulative product of alpha_t, computed as exp(cumsum(log(alpha_t))).
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)

    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab = (1 - alpha_t) / sqrtmab

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }
Model Architecture#
We will now create our model. Generally, a U-Net architecture is used, but in practice, other architectures can work. For simplicity, we will use a basic convolutional model. Note that normally, the model takes the step \(t\) as input, but in our simplified version, we will not do this.
def conv_bn_relu(in_channels, out_channels, kernel_size=7, stride=1, padding=3):
    """Build a Conv2d -> BatchNorm2d -> LeakyReLU stage.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size (default 7).
        stride: Convolution stride (default 1).
        padding: Convolution zero-padding (default 3).

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    convolution = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    normalisation = nn.BatchNorm2d(out_channels)
    activation = nn.LeakyReLU()
    return nn.Sequential(convolution, normalisation, activation)
class model(nn.Module):
    """Small fully-convolutional network used as the noise predictor.

    Seven ``conv_bn_relu`` stages widen the feature map from ``channels`` up
    to 512 channels and back down to 64, followed by a 3x3 convolution that
    maps back to ``channels``. The network takes only the image as input.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        # Very small model: channel widths of the successive stages.
        widths = [channels, 64, 128, 256, 512, 256, 128, 64]
        stages = [
            conv_bn_relu(n_in, n_out)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        # Final projection back to the input channel count.
        head = nn.Conv2d(64, channels, 3, padding=1)
        self.conv = nn.Sequential(*stages, head)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the result."""
        prediction = self.conv(x)
        return prediction
Now that we have the network that predicts the noise at a given step, let’s wrap it in the overall diffusion model, which handles both training and sampling:
class DDPM(nn.Module):
    """Denoising diffusion wrapper around a noise-prediction network.

    The schedules returned by ``ddpm_schedules`` are registered as buffers so
    they follow the module across devices and can be read as attributes
    (``self.sqrtab``, ``self.sqrtmab``, ...). ``forward`` computes the
    training loss; ``sample`` runs the reverse process to generate images.

    Args:
        model: Network called as ``model(x_t)`` to predict the noise. The
            timestep is not passed to it.
        betas: ``(beta1, beta2)`` endpoints of the noise schedule.
        n_T: Number of diffusion steps.
        criterion: Loss comparing the true and the predicted noise.
    """

    def __init__(self, model: nn.Module, betas: Tuple[float, float], n_T: int, criterion: nn.Module = nn.MSELoss()) -> None:
        super().__init__()
        self.model = model

        # Store the DDPM schedules on the module for easy access.
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.criterion = criterion

    # Training step
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Makes forward diffusion x_t, and tries to guess epsilon value from x_t using eps_model.
        This implements Algorithm 1 in the paper.

        Args:
            x: Batch of clean images; indexed as a 4-D tensor with the batch
                on the first axis.

        Returns:
            ``criterion(eps, pred_eps)`` for one randomly drawn timestep per
            batch element.
        """
        # Draw one timestep per sample, uniformly in {1, ..., n_T}.
        # torch.randint excludes its upper bound, hence n_T + 1; without it
        # t = n_T (the step sampling starts from) would never be trained on.
        _ts = torch.randint(1, self.n_T + 1, (x.shape[0],), device=x.device)
        # Gaussian noise with the same shape as x.
        eps = torch.randn_like(x)  # eps ~ N(0, 1)

        # Noise x to obtain x_t.
        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * eps
        )
        # Predict the noise epsilon from x_t.
        pred_eps = self.model(x_t)

        return self.criterion(eps, pred_eps)

    # Sample generation
    def sample(self, n_sample: int, size, device) -> torch.Tensor:
        """Generate ``n_sample`` samples by running the reverse process.

        Args:
            n_sample: Number of samples to generate.
            size: Shape of a single sample, unpacked after the batch axis.
            device: Device on which the samples are created.

        Returns:
            Tensor of shape ``(n_sample, *size)``.

        Note:
            Gradients are not disabled here; wrap the call in
            ``torch.no_grad()`` when sampling for inference.
        """
        # Start from pure Gaussian noise, created directly on the target device.
        x_i = torch.randn(n_sample, *size, device=device)  # x_T ~ N(0, 1)

        # Apply the reverse diffusion step n_T times (slow: one network call
        # per step).
        for i in range(self.n_T, 0, -1):
            # No noise is added on the very last step.
            z = torch.randn(n_sample, *size, device=device) if i > 1 else 0
            eps = self.model(x_i)
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )

        return x_i
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Wrap the small convolutional network in a 1000-step DDPM and move it
# (parameters and schedule buffers) to the selected device.
ddpm = DDPM(model=model(channels=1), betas=(1e-4, 0.02), n_T=1000)
ddpm = ddpm.to(device)
And that’s it, our simple implementation is complete!
Model Training#
Now let’s move on to training the model. We will significantly reduce the dataset size (only 1000 elements), but training will still be very long. I do not recommend trying this unless you have a very good GPU.
epoch = 100
# NOTE(review): n_T is not read in this loop; the DDPM instance carries its own n_T.
n_T = 1000
optimizer = torch.optim.Adam(ddpm.parameters(), lr=2e-4)

# One grid of generated samples is stored every 10 epochs.
generation = []

# Epochs 0..epoch inclusive, so samples are also drawn after the last epoch.
for i in range(epoch + 1):
    ddpm.train()

    # Exponential moving average of the batch loss, reset at each epoch.
    loss_ema = None
    for x, _ in dataloader:
        optimizer.zero_grad()
        x = x.to(device)
        loss = ddpm(x)
        loss.backward()
        # Read the scalar once per step instead of twice.
        batch_loss = loss.item()
        if loss_ema is None:
            loss_ema = batch_loss
        else:
            loss_ema = 0.9 * loss_ema + 0.1 * batch_loss
        optimizer.step()
    print(f"epoch {i}, loss: {loss_ema:.4f}")

    if i % 10 == 0:
        ddpm.eval()
        with torch.no_grad():
            xh = ddpm.sample(4, (1, 28, 28), device)
            grid = make_grid(xh, nrow=4)
            # Keep the grid on the CPU so stored samples do not hold GPU memory.
            generation.append(grid.cpu())
for grid in generation:
    # The training images were shifted by -0.5 (Normalize with mean 0.5,
    # std 1.0), so generated pixels live roughly in [-0.5, 0.5]. imshow clips
    # float RGB data to [0, 1]; shift back and clamp so the negative half of
    # the range is not rendered as black.
    displayable = (grid + 0.5).clamp(0.0, 1.0)
    # (C, H, W) -> (H, W, C), the channel-last layout that imshow expects.
    grid_image = displayable.permute(1, 2, 0).cpu().numpy()

    # Show the image.
    plt.imshow(grid_image)
    plt.axis('off')  # Hide the axes
    plt.show()