Implementing a Diffusion Model#

In this lesson, we will implement a diffusion model step by step on the MNIST dataset. The lesson closely follows the minDiffusion project.

from typing import Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms as T
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

Dataset Preparation#

transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (1.0,))])

dataset = MNIST("./../data", train=True, download=True, transform=transform)
print('Dataset size:', len(dataset))
print('Image size:', dataset[0][0].numpy().shape)

dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=8)
Dataset size: 60000
Image size: (1, 28, 28)
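
As a side note, T.Normalize((0.5,), (1.0,)) shifts the pixel range from [0, 1] (after ToTensor) to [-0.5, 0.5]. A quick check, purely illustrative:

x0, _ = dataset[0]
print(x0.min().item(), x0.max().item())  # roughly -0.5 and 0.5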

Noise Schedule#

In a diffusion model, noise is added at every step. We will create a schedule of noise values sampled between \(\beta_1\) and \(\beta_2\).
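
For context (following the DDPM paper), the forward process has a closed form at any step \(t\), which is exactly what the schedule below precomputes:

\[
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),
\]

with \(\alpha_t = 1 - \beta_t\) and \(\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s\); hence the sqrtab (\(\sqrt{\bar{\alpha}_t}\)) and sqrtmab (\(\sqrt{1-\bar{\alpha}_t}\)) entries returned by the function.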

def ddpm_schedules(beta1: float, beta2: float, T: int):

  # Check that beta1 and beta2 lie in (0, 1) with beta1 < beta2
  assert 0.0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1) with beta1 < beta2"

  # Vector of size T+1 sampling the interval [beta1, beta2] linearly
  beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1

  # Precompute every quantity needed during training and sampling
  sqrt_beta_t = torch.sqrt(beta_t)
  alpha_t = 1 - beta_t
  log_alpha_t = torch.log(alpha_t)
  # Cumulative product of alpha_t, computed in log space for numerical stability
  alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

  sqrtab = torch.sqrt(alphabar_t)
  oneover_sqrta = 1 / torch.sqrt(alpha_t)

  sqrtmab = torch.sqrt(1 - alphabar_t)
  mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab

  return {
    "alpha_t": alpha_t,  # \alpha_t
    "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
    "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
    "alphabar_t": alphabar_t,  # \bar{\alpha_t}
    "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
    "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
    "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
  }
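
A quick sanity check (the printed values are illustrative): each returned tensor has T + 1 entries, indexed so that entry t corresponds to diffusion step t.

sched = ddpm_schedules(1e-4, 0.02, 1000)
print(sched["alphabar_t"].shape)                 # torch.Size([1001])
print(sched["sqrtab"][0], sched["sqrtmab"][-1])  # both close to 1.0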

Model Architecture#

We now build the model. A U-Net is the usual choice, but other architectures also work in practice. To keep things simple, we will use a basic convolutional model. Note that the standard model also takes the step \(t\) as input; this simplified version ignores it.

def conv_bn_relu(in_channels, out_channels, kernel_size=7, stride=1, padding=3):
  return nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
    nn.BatchNorm2d(out_channels),
    nn.LeakyReLU())

class model(nn.Module):

  def __init__(self, channels: int) -> None:
    super(model, self).__init__()
    # A very small model
    self.conv = nn.Sequential(  
      conv_bn_relu(channels, 64),
      conv_bn_relu(64, 128),
      conv_bn_relu(128, 256),
      conv_bn_relu(256, 512),
      conv_bn_relu(512, 256),
      conv_bn_relu(256, 128),
      conv_bn_relu(128, 64),
      nn.Conv2d(64, channels, 3, padding=1),
    )

  def forward(self, x):
    return self.conv(x)
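
Because every convolution uses "same" padding, the network preserves the input's spatial size; a quick check (illustrative):

net = model(channels=1)
print(net(torch.zeros(2, 1, 28, 28)).shape)  # torch.Size([2, 1, 28, 28])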

Now that we have the network that predicts the noise at a given step, we integrate it into the full diffusion model.
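
For reference, the DDPM class below implements two things: the training objective (Algorithm 1 in the DDPM paper), which predicts the noise added at a random step,

\[
\mathcal{L} = \left\| \epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon\right) \right\|^2,
\]

and the reverse update (Algorithm 2) used for sampling,

\[
x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon_\theta(x_t) \right) + \sqrt{\beta_t}\, z, \qquad z \sim \mathcal{N}(0, I) \text{ for } t > 1, \text{ else } z = 0,
\]

whose factors are exactly the precomputed buffers oneover_sqrta, mab_over_sqrtmab and sqrt_beta_t.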

class DDPM(nn.Module):
    def __init__(self, model: nn.Module, betas: Tuple[float, float], n_T: int, criterion: nn.Module = nn.MSELoss()) -> None:
        super(DDPM, self).__init__()
        self.model = model

        # Store the DDPM schedules as buffers so they move with the model (CPU/GPU)
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.criterion = criterion

    # Training step
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Makes forward diffusion x_t, and tries to guess epsilon value from x_t using the model.
        This implements Algorithm 1 in the paper.
        """

        # Draw a random step t between 1 and n_T (inclusive) for each sample
        _ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(x.device)
        # Draw Gaussian noise of the same shape as x
        eps = torch.randn_like(x)  # eps ~ N(0, 1)

        # Apply the Gaussian noise to x to obtain x_t
        x_t = (self.sqrtab[_ts, None, None, None] * x + self.sqrtmab[_ts, None, None, None] * eps)
        # Try to predict the noise epsilon from x_t
        pred_eps = self.model(x_t)
        return self.criterion(eps, pred_eps)

    # Sampling
    def sample(self, n_sample: int, size, device) -> torch.Tensor:

        # Start from pure noise drawn from a standard normal distribution
        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1)

        # Apply the reverse diffusion process step by step to generate a sample
        # (this is slow, since the model must be evaluated at every one of the n_T steps)
        for i in range(self.n_T, 0, -1):
            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
            eps = self.model(x_i)
            x_i = (self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z)

        return x_i

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ddpm = DDPM(model=model(channels=1), betas=(1e-4, 0.02), n_T=1000)
ddpm.to(device);

With that, our simplified implementation is complete!
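
As a quick smoke test before training (dummy data, illustrative only), the forward pass should return a scalar loss:

x = torch.randn(8, 1, 28, 28, device=device)  # dummy batch, not real digits
print(ddpm(x).item())  # a single scalar MSE value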

Model Training#

We now train the model. Even on a dataset as small as MNIST, training is very time-consuming (each sampling pass alone requires n_T = 1000 model evaluations), so we do not recommend running it unless you have a powerful GPU.

epochs = 100
optimizer = torch.optim.Adam(ddpm.parameters(), lr=2e-4)
generation = []
for i in range(epochs + 1):
  ddpm.train()
  loss_ema = None
  for x, _ in dataloader:
    optimizer.zero_grad()
    x = x.to(device)
    loss = ddpm(x)
    loss.backward()
    # Exponential moving average of the loss, for smoother logging
    if loss_ema is None:
      loss_ema = loss.item()
    else:
      loss_ema = 0.9 * loss_ema + 0.1 * loss.item()
    optimizer.step()
  print(f"epoch {i}, loss: {loss_ema:.4f}")
  # Every 10 epochs, sample a few images to monitor progress
  if i % 10 == 0:
    ddpm.eval()
    with torch.no_grad():
      xh = ddpm.sample(4, (1, 28, 28), device)
      grid = make_grid(xh, nrow=4)
      generation.append(grid)
for grid in generation:
  grid_image = grid.permute(1, 2, 0).cpu().numpy()
  # Display the image
  plt.imshow(grid_image)
  plt.axis('off')  # Hide the axes
  plt.show()
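
One caveat: since the training images were shifted to [-0.5, 0.5], the generated values do not fall in the [0, 1] range that imshow expects for float RGB data, and matplotlib clips anything outside it. A sketch of undoing the normalization before display:

# Shift back by +0.5 and clip to [0, 1] for a faithful rendering
grid_image = (generation[-1].permute(1, 2, 0).cpu().numpy() + 0.5).clip(0.0, 1.0)
plt.imshow(grid_image)
plt.axis('off')
plt.show()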