扩散模型的实现#
在本课程中,我们将逐步在 MNIST 数据集上实现一个扩散模型。本课程主要参考了 minDiffusion 项目的内容。
from typing import Tuple
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as T
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
/home/aquilae/anaconda3/envs/dev/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
数据集准备#
# Convert each image to a tensor, then subtract the mean 0.5 (std 1.0 leaves the scale untouched).
# Both statistics are written as 1-tuples (one entry per channel); the original `(1.0)` was a
# bare float because the trailing comma was missing.
transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (1.0,))])

# Training split of MNIST, downloaded into ./../data on first use.
dataset = MNIST("./../data", train=True, download=True, transform=transform)
print('Taille du dataset :', len(dataset))
print('Taille d\'une image :', dataset[0][0].numpy().shape)

# Shuffled mini-batches of 128 images, loaded by 8 worker processes.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=8)
Taille du dataset : 60000
Taille d'une image : (1, 28, 28)
噪声调度#
在扩散模型中,每一步都需要添加噪声。 我们将创建一个噪声调度:在 \(\beta_1\) 到 \(\beta_2\) 之间线性取 \(T+1\) 个值 \(\beta_t\),并预先计算训练与采样所需的各项系数。
def ddpm_schedules(beta1: float, beta2: float, T: int):
    """Precompute the linear noise schedule and the coefficients derived from it.

    Args:
        beta1: Noise level at index 0; must satisfy 0 < beta1 < beta2.
        beta2: Noise level at index T; must satisfy beta2 < 1.
        T: Number of diffusion steps; must be strictly positive.

    Returns:
        A dict of 1-D float32 tensors of length ``T + 1`` (index ``t`` holds the
        value for step ``t``), with keys ``alpha_t``, ``oneover_sqrta``,
        ``sqrt_beta_t``, ``alphabar_t``, ``sqrtab``, ``sqrtmab`` and
        ``mab_over_sqrtmab``.

    Raises:
        AssertionError: If the betas are not ordered inside (0, 1) or T <= 0.
    """
    # The lower bound matters: a negative beta1 would make sqrt(beta_t) NaN
    # without any error being raised.
    assert 0.0 < beta1 < beta2 < 1.0, "beta1 et beta2 doivent être dans l'intervalle (0, 1)"
    # T is used as a divisor below, so T == 0 would produce 0/0.
    assert T > 0, "T doit être strictement positif"

    # T + 1 values sampled linearly over [beta1, beta2].
    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1

    # Everything the training and sampling steps will need later on.
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    log_alpha_t = torch.log(alpha_t)
    # Cumulative product of alpha_t, computed as exp(cumsum(log(alpha_t))).
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)

    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }
模型架构#
现在我们将构建模型。通常采用 U-Net 架构,但实践中其他架构也可行。 为简化流程,我们将使用一个基础的卷积模型。 需要注意的是,标准模型会将步骤 \(t\) 作为输入,但在本简化版本中暂不考虑。
def conv_bn_relu(in_channels, out_channels, kernel_size=7, stride=1, padding=3):
    """Build one Conv2d -> BatchNorm2d -> LeakyReLU stage as an ``nn.Sequential``.

    With the default kernel size 7, stride 1 and padding 3 the spatial size of
    the input is preserved.
    """
    convolution = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    normalization = nn.BatchNorm2d(out_channels)
    activation = nn.LeakyReLU()
    return nn.Sequential(convolution, normalization, activation)
class model(nn.Module):
    """Small fully-convolutional network mapping an image batch to a same-channel output.

    Seven ``conv_bn_relu`` stages widen the feature maps from ``channels`` up to
    512 and back down to 64; a final 3x3 convolution projects back to
    ``channels``. The network takes only the image tensor as input.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        # Channel width at each stage boundary; consecutive pairs define one stage.
        widths = [channels, 64, 128, 256, 512, 256, 128, 64]
        stages = [
            conv_bn_relu(n_in, n_out)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        # Output projection back to the input channel count.
        head = nn.Conv2d(64, channels, 3, padding=1)
        self.conv = nn.Sequential(*stages, head)

    def forward(self, x):
        """Apply the convolutional stack to ``x``."""
        out = self.conv(x)
        return out
现在我们已构建了步骤间转换的模型,接下来将其集成为完整模型:
class DDPM(nn.Module):
    """Denoising diffusion wrapper around a noise-prediction network.

    Args:
        model: Network called as ``model(x_t)`` to predict the noise in ``x_t``.
            It receives only the noisy image, not the timestep.
        betas: ``(beta1, beta2)`` end points passed to ``ddpm_schedules``.
        n_T: Number of diffusion steps.
        criterion: Loss comparing the true and the predicted noise.
    """

    def __init__(self, model: nn.Module, betas: Tuple[float, float], n_T: int, criterion: nn.Module = nn.MSELoss()) -> None:
        super().__init__()
        self.model = model

        # Register every schedule tensor as a buffer so that it follows the
        # module across devices and is reachable as ``self.<name>``.
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.criterion = criterion

    # Training step
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Makes forward diffusion x_t, and tries to guess epsilon value from x_t using eps_model.
        This implements Algorithm 1 in the paper.

        ``x`` is indexed as a 4-D batch (the schedule values are broadcast over
        three trailing dimensions). Returns ``criterion(eps, pred_eps)``.
        """
        # Draw one timestep per sample, uniformly in {1, ..., n_T}. The upper
        # bound of torch.randint is exclusive, hence n_T + 1: without it step
        # n_T (where sampling starts) would never be seen during training.
        _ts = torch.randint(1, self.n_T + 1, (x.shape[0],), device=x.device)

        # Gaussian noise with the same shape as x.
        eps = torch.randn_like(x)  # eps ~ N(0, 1)

        # Mix the clean input and the noise to obtain x_t.
        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * eps
        )

        # Predict the noise from x_t and score the prediction.
        pred_eps = self.model(x_t)

        return self.criterion(eps, pred_eps)

    # Sample generation
    def sample(self, n_sample: int, size, device) -> torch.Tensor:
        """Generate ``n_sample`` samples of shape ``size`` by reverse diffusion.

        Starts from standard normal noise on ``device`` and applies ``n_T``
        denoising updates, so the cost is ``n_T`` forward passes of the model.
        """
        # Starting point, created directly on the target device.
        x_i = torch.randn(n_sample, *size, device=device)  # x_T ~ N(0, 1)

        # Walk the chain backwards from step n_T down to step 1.
        for i in range(self.n_T, 0, -1):
            # No fresh noise is injected on the very last update.
            z = torch.randn(n_sample, *size, device=device) if i > 1 else 0
            eps = self.model(x_i)
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )

        return x_i
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Single-channel denoiser wrapped in a 1000-step diffusion process whose betas
# run from 1e-4 to 0.02.
ddpm = DDPM(model=model(channels=1), betas=(1e-4, 0.02), n_T=1000)
ddpm = ddpm.to(device)
这样,我们的简化实现就完成了!
模型训练#
接下来进行模型训练。我们将大幅缩减数据集规模(仅使用 1000 个样本),但训练过程仍然非常耗时。 建议:除非配备高性能 GPU,否则不建议尝试。
epoch = 100
n_T = 1000
optimizer = torch.optim.Adam(ddpm.parameters(), lr=2e-4)

# Sample grids collected every 10 epochs, in chronological order.
generation = []

for i in range(epoch + 1):
    ddpm.train()

    # Exponential moving average of the batch loss, restarted at each epoch.
    loss_ema = None
    for x, _ in dataloader:
        optimizer.zero_grad()
        x = x.to(device)
        loss = ddpm(x)
        loss.backward()
        loss_ema = (
            loss.item()
            if loss_ema is None
            else 0.9 * loss_ema + 0.1 * loss.item()
        )
        optimizer.step()
    print(f"epoch {i}, loss: {loss_ema:.4f}")

    # Every tenth epoch (including epoch 0), draw four samples without
    # tracking gradients and keep them as a single-row image grid.
    if i % 10 == 0:
        ddpm.eval()
        with torch.no_grad():
            print('ici')
            xh = ddpm.sample(4, (1, 28, 28), device)
            grid = make_grid(xh, nrow=4)
        generation.append(grid)
# Show each stored sample grid in turn.
for grid in generation:
    # Training images were shifted by -0.5 (Normalize with mean 0.5, std 1.0),
    # so samples live around [-0.5, 0.5]. imshow clips float RGB data to
    # [0, 1]; undo the shift and clamp so the dark half is not flattened to
    # black by that clipping.
    displayable = (grid + 0.5).clamp(0.0, 1.0)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    grid_image = displayable.permute(1, 2, 0).cpu().numpy()

    # Display the image
    plt.imshow(grid_image)
    plt.axis('off')  # Hide the axes
    plt.show()