ℹ️ 写在前面

本文尚未完成,但近期不会补充完整。

这里记录一下还需要写的部分:

  • diffusion.py 的代码解析
  • DiT 模型小实验的整体运行效果
  • 参考文献也需要补充

本文内容基于论文 Scalable Diffusion Models with Transformers

但文章内容并不是论文阅读笔记,而是对 DiT 的入门介绍。未来有机会可以写一个详细的阅读笔记。

扩散模型介绍

扩散模型(Denoising Diffusion Probabilistic Models,DDPMs)在图像/音频/视频生成方面取得了显著的成果。

本文中,采用离散时间(潜变量模型)(discrete-time (lantent variable model))的视角,事实上有多种关于扩散模型的观点,可以都去了解一下。

随机微分方程(SDE)、得分匹配(Score Matching)、朗之万动力学(Langevin Dynamics)、变分推断视角(Variational Inference)……

我们常说的机器学习,或者更准确一些,监督学习中,我们做的是“判别”:

  • 输入:一张猫的图片

  • 输出:标签”cat“

  • 本质:学习 p(y|x) (给定图片,预测类别)

而扩散模型做的是”生成“,可以理解为与监督学习相反的问题:

  • 输入:随机噪声

  • 输出:一张猫的图片

  • 本质:学习 p(x) (数据本身的分布)

下面,让我们详细看看扩散模型是如何完成它的工作的:

前向过程:数据 → 噪声

我们定义一个固定的、不需要学习的前向过程Forward Process),把真实图像 $x_0$ 逐步变成纯噪声 $x_T$ 。

为什么要做这个?

我们在 前向过程 中人为构造一条从数据到噪声的渐进式退化路径。

扩散模型可以理解为一个去噪网络,它的工作流程就是对着一张图片(初始是随机噪声),一步一步去掉噪声,最后得到我们要生成的目标图片。

那么,我们现在手上有原始图片,我们希望模型学会如何去噪,需要从这一张图片构造出足够多的训练数据供扩散模型学习。

我们把原始图片一步步加噪,得到一系列图片(每个图片即对应一个时间步 $t$ ),每个图片都比前一张有更多噪声,直到最后成为纯噪声。

然后我们把这一系列图片反过来看,就是一个从纯噪声一步步变成目标图片的过程,这就是扩散模型需要的训练数据。它会从这一系列数据中学会去噪的技巧,最后能够按需求生成图片。

总而言之,这就是前向过程可以理解为训练数据的构造方式

这个数据到噪声的过程是一步步进行的。我们看到其中一步:

单步转移

从第 $t-1$ 步到 第 $t$ 步,我们做的就是依赖当前状态,加一点点高斯噪声:

$$ q(x_t|x_{t-1})=\mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1},\beta_t \mathbf{I}) $$
  • $\beta_t$ 是一个很小的数(比如 0.0001 ~ 0.02),它表示这一步“加多少噪声”。

    一般来说

    • $t$ 越小,$\beta_t$ 越小,$x_t$ 越接近 $x_{t-1}$ (添加的噪声少)
    • $t$ 越大,$\beta_t$ 越大,$x_t$ 越接近纯噪声
  • $\sqrt{1-\beta_t}$ 是信号保留比例 ,$\sqrt{\beta_t}$ 是噪声加入比例

  • 这个式子本身表示:

    $x_t$ 是从以 $\sqrt{1-\beta_t}x_{t-1}$ 为均值,$\beta_t \mathbf{I}$ 为方差的高斯分布中采样得到的。

    等价于:$x_t=\sqrt{1-\beta_t}x_{t-1}+\sqrt{\beta_t}\epsilon$ ,其中 $\epsilon \sim \mathcal{N}(0, \mathbf{I})$


这里涉及好多数学概念,我们来解释一下:

  1. 马尔可夫性质

    对于一个随机过程,如果未来状态只依赖于当前状态,而与过去状态无关,我们说它具有马尔可夫性质Markov Property

    显然,前向过程是具有马尔可夫性质的。

  2. $x_t$ 是一个张量(Tensor),可以看作是展平后的向量或保持结构的多维数组

    • 对于MNIST手写数字数据集,每个图片是28×28的单通道灰度图,在数学上,$x_t \in \mathbb{R}^{1 \times 28 \times 28}$

    • 对于 ImageNet 等彩色图像数据集,每个图片有3通道RGB,256×256分辨率,在数学上,$x_t \in \mathbb{R}^{3 \times 256 \times 256}$

  3. $\mathbf{I}$ 是单位矩阵Identity Matrix),对角线元素全为 1 ,其余元素全为 0 的方阵。

    • $\beta_t \mathbf{I}$ 表示协方差矩阵是对角阵,意味着图像的每个像素独立地添加高斯噪声,像素之间没有相关性

    • 各向同性:所有维度的方差都是 $\beta_t$ ,没有某个像素被特别对待

  4. $q(\cdot|\cdot)$ 是条件概率分布,与反向过程 $p_\theta$

    $q(x_t | x_{t-1})$ 表示:给定 $x_{t-1}$ 时,$x_t$ 的条件分布

  5. 高斯分布 = 正态分布,$\mathcal{N}(0, \mathbf{I})$ 是标准高斯分布,均值为0,方差为1,各维度独立


重参数化技巧(Reparameterization Trick)

我们拿着 $x_0$ ,要通过上面的式子算到 $x_T$ ,需要循环迭代 $T$ 次。

这时候,就要数学工具登场了:多次叠加的高斯噪声,等价于一次更加大的高斯噪声。这一点可以通过数学推导得到。

由此,我们可以直接从 $x_0$ 跳到任意第 $t$ 步:

$$ q(x_t|x_0)=\mathcal{N}(x_t;\sqrt{\overline{\alpha}_t}x_0,(1-\overline{\alpha}_t)\mathbf{I}) $$

其中,$\overline{a}_t$ 是累计信号比例:

$$ \overline{a}_t = \prod_{i=1}^{t}(1-\beta_i) $$

$\sqrt{\overline{\alpha}_t}$ 从 1 衰减到 0,意味着随着 $t$ 增大,$x_0$ 的信号越来越少,噪声越来越多

最后我们得到:

$$ x_t=\sqrt{\overline{\alpha}_t}x_0 + \sqrt{1-\overline{\alpha}_t} \epsilon $$

我们会在代码中看到这个式子的应用

至此,我们得到了每个 $t$ 对应的一个 $x_t$ ,分别对应一个时间步时的加噪图片

反向过程:学习“去噪”的神经网络

与固定的前向过程不同,反向过程 $q(x_{t-1}|x_t)$ 是未知的。我们用神经网络来近似:

$$ p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1};\mu_\theta(x_t,t),\Sigma_\theta(x_t,t)) $$

该式子表示:给定当前时刻 $t$ 的带噪图像 $x_t$ ,上一时刻(更清晰)的图像 $x_{t-1}$ 服从一个由神经网络参数化的高斯分布。

其中 $\theta$ 表示神经网络的权重,我们看到有两个由神经网络训练得到的变量:

  1. $\mu_\theta(x_t,t)$ :神经网络预测的均值

  2. $\Sigma_\theta(x_t,t)$:神经网络预测的方差

也就是说,我们给定 $x_t$ ,神经网络需要学习计算 $\mu_\theta$ 和 $\Sigma_\theta$ ,然后得到对应高斯分布,并采样得到 $x_{t-1} \sim \mathcal{N}(\mu_\theta,\Sigma_\theta)$ 。

然后一步一步迭代,最后可以通过神经网络得到 $x_0$ ,即清晰的目标图像

简化假设

DDPM 原论文做了两个核心简化,让问题变得可解

  1. 方差固定:$\Sigma_\theta$ 不学习,直接用 $\beta_t$ 或 $\tilde{\beta}_t$ (后验方差)

    原论文发现,学习方差对生成质量提升有限,同时又会显著增加训练难度

    因此,作者给出两个方案:

    1. $\Sigma_\theta(x_t,t)=\beta_t\mathbf{I}$ ,使用前向过程的噪声强度

      • $\beta_t \in (0,1)$ 是前向过程中第 $t$ 步人为设定的噪声方差

      • 它表示从 $x_{t-1}$ 到 $x_t$ 时,新注入的高斯噪声的强度

      • 本身与数据无关,由调度方案决定

    2. $\Sigma_\theta(x_t, t) = \tilde{\beta}_t \mathbf{I}$ ,使用真实后验的方差

      • $\tilde{\beta_t} = \frac{1-\overline{\alpha}_{t-1}}{1-\overline{\alpha}_t} \cdot \beta_t$ ,即真实后验分布 $q(x_{t-1} | x_t,x_0)$ 的方差

      • 表示:如果我们知道 $x_0$ ,从 $x_t$ 反推 $x_{t-1}$ 时的不确定性

      • 依赖累计信号 $\overline{\alpha}_t$ ,与数据分布有关

      后验是一个概率分布,而不是单一数值。

      • 均值:最可能的估计值

      • 方差:这个估计的不确定性有多大

      结合这里的场景,反向过程中,我们已知 $x_t$ ,需要推断 $x_{t-1}$。

      对于后验分布 $q(x_{t-1} | x_t,x_0)$ ,表示此时我们在上述情况的基础上,还知道 $x_0$ 。由于前向过程是高斯分布,这个后验有解析解,可以得到方差 $\tilde{\beta_t} = \frac{1-\overline{\alpha}_{t-1}}{1-\overline{\alpha}_t} \cdot \beta_t$ 。

      这个结论可以通过数学推导得到,未来可以补充

      但实际上,反向过程的真实后验应该是 $q(x_{t-1}|x_t)$ ,这个分布就没有解析解,因为反向过程中 $x_0$ 是未知的。

      对于这个方案,我们就是拿已知 $x_0$ 时的方差 $\tilde{\beta}_t$ 来近似未知 $x_0$ 时的真实方差,计算简单且效果足够好。

    实验表明,两种方案效果接近,且都不比学习方差差太多。

    这种简化方案还有更多有力的理论基础在,未来可以补充。

  2. 只学习均值 $\mu_\theta$ ,而且通过预测噪声来间接学习均值

从噪声预测到均值

真实的后验均值如下所示:

$$ \tilde{\mu}_t(x_t,x_0)=\frac{\sqrt{\overline{\alpha}_{t-1}}\beta_t}{1-\overline{\alpha}_t}x_0+\frac{\sqrt{\alpha_t}(1-\overline{\alpha}_{t-1})}{1-\overline{\alpha}_{t-1}} x_t $$

这个公式描述的分布是:$q(x_{t-1}|x_t,x_0)$ ,

这里的说的后验均值指的是:理想状态下已知 $x_0$ 和 $x_t$ 时 $x_{t-1}$ 的期望值。

也就是说,这里也是一个近似,并不是真实的反向过程。

具体推导用到了贝叶斯公式,未来有机会补充一下。

但在反向过程中, $x_0$ 未知。我们可以用前向公式的逆运算得到 $x_0$ :

$$ x_0=\frac{x_t-\sqrt{1-\overline{\alpha}_t}\epsilon}{\sqrt{\overline{\alpha}_t}} $$

这里有一个会困惑的点,这个 $x_0$ 的式子是恒等式,而不是近似

用大白话讲,$x_0$ 就是这么算出来的。

由此,均值可以用 $\epsilon$ 表示,神经网络需要训练的变量也就从均值转变为了噪声(见下)

代入得到:

$$ \tilde{\mu}_t=\frac{1}{\sqrt{\alpha_t}}\big( x_t - \frac{1-\alpha_t}{\sqrt{1-\overline{\alpha_t}}} \epsilon \big) $$

至此,我们只要训练一个神经网络 $\epsilon_\theta$ 来预测噪声 $\epsilon$ ,就能算出均值$\mu_\theta$,进而采样 $x_{t-1}$ :

$$ \mu_\theta=\frac{1}{\sqrt{\alpha_t}}\big( x_t - \frac{1-\alpha_t}{\sqrt{1-\overline{\alpha_t}}} \epsilon_\theta(x_t,t) \big) $$

对于这个式子有几点说明一下:

  • $\alpha_t$ 和 $\tilde{\alpha}_t$ 的区别:

    • $\alpha_t=1-\beta_t$ ,单步信号保留系数(第 $t$ 步这一时刻)

    • $\overline{\alpha}_t=\prod^t_{i=1} \alpha_i$ ,累积信号保留系数(从第 1 步到第 $t$ 步的连乘)

  • $\epsilon_\theta(x_t,t)$ 表示我们的神经网络输入是 $x_t$ 和 $t$ 。这点会在代码中体现。

训练目标

经过前文一系列数学变换,我们最后得到简洁的神经网络训练目标

$$ \mathcal{L}=\mathbb{E}_{x_0,\epsilon,t}[||\epsilon-\epsilon_\theta(x_t,t)||^2] $$
  • $\mathbb{E}$ 是期望符号,这里表示对三个随机变量($x_0,\epsilon,t$)求期望

    即在所有可能的训练图像、所有可能的噪声、所有可能的时间步上,求平均损失

  • $||\cdot||^2$ 表示均方误差 MSE

这里实际上涉及一些数学概念。作为生成模型,扩散模型直观的训练目标是生成图像尽可能像真实目标图像,这在数学上是需要比较两个复杂分布的相似度。

但经过上述的简化,我们把训练目标变成了预测噪声,可以理解为向量与向量之间比较。这无疑大大减轻了问题的解决难度。

我们称这个数学变换过程为变分下界(Variational Lower Bound, ELBO)简化

我们来直观理解一下这个神经网络的训练流程:

  1. 随机采样一个真实图像 $x_0$

  2. 随机采样一个时间步 $t \in [1,T]$

  3. 随机采样噪声 $\epsilon \sim \mathcal{N}(0,\mathbf{I})$

  4. 生成 $x_t=\sqrt{\overline{\alpha}_t}x_0 + \sqrt{1-\overline{\alpha}_t}\epsilon$

  5. 模型预测 $\epsilon_\theta(x_t,t)$

  6. 计算损失 = 真实噪声与预测噪声的均方误差(MSE)

至此,我们训练得到了一个会去噪的神经网络。

采样过程——如何生成图像

训练好模型后,生成图像的过程是这样的:

  1. 从纯噪声 $x_T \sim \mathcal{N}(0,\mathbf{I})$ 开始

  2. 从 $t=T$ 到 $t=1$,循环:

    1. 模型预测噪声 $\epsilon_\theta(x_t,t)$

    2. 计算 $x_0$ 的预测

    3. 计算均值 $\mu_\theta$

    4. 采样 $x_{t-1}=\mu_\theta+\sqrt{\beta_t}z$

      $z$ 是标准高斯噪声,$t=1$时不加噪声

    这里,根据前面提到的后验分布,$\mu_\theta$ 即为网络认为最可能的 $x_{t-1}$

    但这里我们并不直接输出,而是加上了一个新采样的高斯噪声 $\sqrt{\beta_t}z$

    因为生成模型的生成目标并不是确定性的,我们希望生成多样化的目标,所以加上噪声能够避免生成过程变成确定性映射。

    至于 $t=1$ 时不加噪声,是希望最后的结果是清晰的

  3. 返回 $x_0$

DiT 架构详解

代码实践

这里,我们用 MNIST 数据集(手写数字)来做一个简单的小案例,用扩散模型生成指定数字的单通道灰度图


下面的 “代码详细解释” 中,我会把我自己在代码中不懂或困惑的地方进行详细说明,可能会有些冗余。

读者可按需阅读或跳过。


config.py

集中管理所有超参数,避免魔法数字散落在代码中

📁 config.py 完整代码
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from dataclasses import dataclass

@dataclass
class DiTConfig:
    """DiT 模型架构配置"""
    image_size: int = 28    # MNIST 尺寸
    patch_size: int = 4     # 4×4 patches -> 7×7=49 tokens
    in_channels: int = 1    # MNIST 灰度图
    hidden_size: int = 256  # Transformer 隐藏维度
    depth: int = 4          # Transformer 层数
    num_heads: int = 4      # 注意力头数
    mlp_ratio: float = 4.0  # MLP 隐藏层倍数
    num_classes: int = 10   # MNIST 类别数
    dropout: float = 0.1    # Dropout 概率

    @property
    def num_patches(self) -> int:
        """计算 patch 数量"""
        return (self.image_size // self.patch_size) ** 2
    
@dataclass
class DiffusionConfig:
    """扩散过程配置"""
    timesteps: int = 1000       # 总扩散步数
    beta_start: float = 1e-4    # 起始噪声强度
    beta_end: float = 0.02      # 终止噪声强度
    schedule: str = "linear"    # beta 调度方式:linear/cosine

    # 采样配置
    sample_steps: int = 1000    # 采样步数(可小于 timesteps 用于加速)
    cfg_scale: float = 2.0      # Classifier-free guidance 强度
    cfg_dropout: float = 0.1    # 训练时条件丢弃概率

    device: str = "cpu"

@dataclass
class TrainConfig:
    """训练配置"""
    # 数据
    batch_size: int = 64
    num_workers: int = 4    # 数据加载线程

    # 优化器
    learning_rate: float = 1e-4 
    weight_decay: float = 0.03
    epochs: int = 50

    # 日志与保存
    log_every: int = 100    # 每 N batch 打印日志
    sample_every: int = 5   # 每 N epoch 生成样本
    save_every: int = 10    # 每 N epoch 保存

    # 路径
    data_dir: str = "./data"
    checkpoint_dir: str = "./checkpoints"
    sample_dir: str = "./samples"

    # 设备与恢复
    device: str = "cpu"
    resume: bool = False        # 是否从 checkpoint 恢复
    checkpoint_path: str = ""   # 指定恢复路径

# 组合配置(方便传递)
@dataclass
class Config:
    model: DiTConfig = DiTConfig()
    diffusion: DiffusionConfig = DiffusionConfig()
    train: TrainConfig = TrainConfig()

代码细节解释

  • Patch 与 Batch 区分

    • Patch:一次性送进模型的一组样本

    • Batch:将一张大图切成小方块,每个方块是一个 token

    • 对于我们的例子:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      
      class DiTConfig:
          """DiT 模型架构配置"""
          ……
          patch_size: int = 4     # 4×4 patches -> 7×7=49 tokens
      
      class TrainConfig:
          """训练配置"""
          # 数据
          batch_size: int = 64
      
      • 一张图:28×28;Patch 大小:4×4;切成 7×7=49 个patches

      • Batch=64,一次处理 64 张图

  • Beta 控制每一步加多少噪声

    1
    2
    3
    4
    
    class DiffusionConfig:
        """扩散过程配置"""
        ……
        schedule: str = "linear"    # beta 调度方式:linear/cosine
    
    • Linear:直线增长,均匀加噪;简单,但后期可能加噪过快

    • Cosine:余弦曲线,先慢后快;更平滑,训练更稳定

  • 采样 可以理解为 生成图片的过程

    • 训练时:从真实图像出发,前向加噪学习

    • 采样时:从纯噪声出发,反向去噪生成

    • 因为每步都从概率分布(高斯分布)中随机抽取数值,不是确定性计算

  • 权重衰减Weight Decay)是防止模型过拟合的正则化手段

    • 普通梯度下降:参数=参数-学习率×梯度

    • 带上权重衰减:参数=参数-学习率×梯度 - 学习率×weight_decay×参数

      1
      2
      3
      4
      5
      6
      
      class TrainConfig:
          """训练配置"""
          ……
          # 优化器
          ……
          weight_decay: float = 0.03
      
    • 强迫模型参数保持小数值,避免过于依赖“一家独大”的参数,提升模型泛化能力

  • @dataclass 自动为类生成样板代码(__init____repr____eq__),减少重复代码

utils.py

纯工具函数,无状态,可被任何模块导入

📁 utils.py 完整代码
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import math
import torch
from typing import Optional
import matplotlib.pyplot as plt
import os

def timestep_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    """
    正弦位置编码,将标量时间步嵌入到 dim 维向量。

    基于 Transformer 原始论文的位置编码,但用于连续时间步。

    Args:
        timesteps: [B] 或 [B, 1] 的时间步张量
        dim: 输入嵌入维度
        max_period: 正弦/余弦的最大周期

    Returns:
        embedding: [B, dim] 时间嵌入 
    """

    # 确保是 1D
    if timesteps.ndim == 0:
        timesteps = timesteps.unsqueeze(0)
    if timesteps.ndim > 1:
        timesteps = timesteps.view(-1)

    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half
    ).to(timesteps.device)

    args = timesteps[:, None].float() * freqs[None, :]                  # [B, half]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)   # [B, dim]

    # 奇数维度时补零
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)

    return embedding

def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """
    将图像分割为 patches 并展平为序列。

    Args:
        x: [B, C, H, W] 图像
        patch_size: patch 边长

    Returns: 
        patches: [B, N, C*patch_size^2] 其中 N = (H//p)*(W//p)
    """
    B, C, H, W = x.shape
    assert H % patch_size == 0 and W % patch_size == 0, "图像尺寸必须能被 patch_size 整除"

    # 使用 unfold 提取滑动窗口,然后 reshape
    # [B, C, H//p, p, W//p, p] -> [B, H//p, W//p, C, p, p]
    x = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    x = x.permute(0, 2, 3, 1, 4, 5).contiguous()

    # [B, N, C*p*p]
    patches = x.view(B, -1, C * patch_size * patch_size)
    return patches

def unpatchify(patches: torch.Tensor, patch_size: int, channels: int, img_size: int) -> torch.Tensor:
    """
    将 patch 序列还原为图像。

    Args: 
        patches: [B, N, C*patch_size^2]
        patch_size: patch 边长
        channels: 图像通道数
        img_size: 图像边长

    Returns:
        x: [B, C, H, W]
    """
    B = patches.shape[0]
    H = W = img_size // patch_size
    C = channels

    # [B, N, C*p*p] -> [B, H, W, C, p, p]
    x = patches.view(B, H, W, C, patch_size, patch_size)

    # [B, C, H, p, W, p] -> [B, C, H*p, W*p]
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
    x = x.view(B, C, H * patch_size, W * patch_size)

    return x

def normalize_image(x: torch.Tensor, 
                    mean: Optional[tuple] = None, 
                    std: Optional[tuple] = None) -> torch.Tensor:
    """
    图像归一化到 [-1, 1] 或标准化。

    Args:
        x: [B, C, H, W], 值域 [0, 1]
        mean/std: 若为 None,直接缩放到 [-1, 1]

    Returns:
        归一化后的图像
    """

    if mean is None:
        return x * 2.0 - 1.0
    else:
        # 标准标准化
        mean = torch.tensor(mean, device=x.device).view(1, -1, 1, 1)
        std = torch.tensor(std, device=x.device).view(1, -1, 1, 1)
        return (x - mean) / std

def denormalize_image(x: torch.Tensor, 
                      mean: Optional[tuple] = None, 
                      std: Optional[tuple] = None) -> torch.Tensor:
    """
    反归一化,将图像还原到 [0, 1] 用于可视化
    """
    if mean is None:
        return (x + 1.0) / 2.0
    else:
        mean = torch.tensor(mean, device=x.device).view(1, -1, 1, 1)
        std = torch.tensor(std, device=x.device).view(1, -1, 1, 1)
        return x * std + mean
    
def save_image_grid(images: torch.Tensor, 
                    path: str, 
                    nrow: int = 10, 
                    normalize: bool = True):
    """
    保存图像网格。

    Args:
        images: [N, C, H, W] 或 [N, H, W] 图像张量
        path: 保存路径
        nrow: 每行图像数
    """

    os.makedirs(os.path.dirname(path), exist_ok=True)

    if images.ndim == 3:
        images = images.unsqueeze(1)    # 添加通道维

    if normalize: 
        images = denormalize_image(images)
    
    images = images.clamp(0, 1).cpu().numpy()

    N = images.shape[0]
    ncol = math.ceil(N / nrow)

    fig, axes = plt.subplot(ncol, nrow, figsize=(nrow * 1.5, ncol * 1.5))
    if ncol == 1:
        axes = axes.reshape(1, -1)
    if nrow == 1:
        axes = axes.reshape(-1, 1)

    for idx in range(N):
        i, j = idx // nrow, idx % nrow
        ax = axes[i, j]
        img = images[idx]

        # 单通道灰度图
        if img.shape[0] == 1:
            ax.imshow(img[0], cmap='gray')
        else:
            ax.imshow(img.transpose(1, 2, 0))
        ax.axis('off')
    
    for idx in range(N, ncol * nrow):
        i, j = idx // nrow, idx % nrow
        axes[i, j].axis('off')

    plt.tight_layout()
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Save image grid to {path}")

def count_parameters(model: torch.nn.Module) -> int:
    """统计模型可训练参数量"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def get_beta_schedule(schedule: str, timesteps: int, beta_start: float, beta_end: float) -> torch.Tensor:
    """
    获取 beta 噪声调度

    Args:
        schedule: "linear" 或 "cosine"
        timesteps: 总步数
        beta_start/end: 起始/终止值

    Returns:
        betas: [timesteps]
    """
    if schedule == "linear":
        return torch.linspace(beta_start, beta_end, timesteps)
    elif schedule == "cosine":
        # Improved DDPM 的余弦调度
        s = 0.008
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999)
    else:
        raise ValueError(f"Unknown schedule: {schedule}")

代码细节解释

这里挑选一部分我感到困惑的内容来详细说明。

正弦位置编码 timestep_embedding()

把整数时间步(如 0, 1, 2, …, 999)编码为高维向量(如 256 维),让神经网络“感知”当前处于去噪的哪个阶段。

比如第100步与第900步的去噪策略完全不同,因此我们需要把时间步变成一个“特征向量”来告诉模型“现在处于扩散过程的哪个位置”。

这里我们编码的思想参考了 Transformer 的原论文,但用于连续时间而不是序列位置。

对于每个时间步,我们会根据一定的规则把它映射到一个向量上。

向量中各个维度映射规则如下所示:

dim=D,则 half=D//2

  1. 频率公式

    $\omega_i=\exp(-\ln(\text{max\_period})\cdot\frac{i}{\text{half}})=\frac{1}{\text{max\_period}^{i/\text{half}}}$

    其中,$i \in [0, 1, \dots,\text{half}-1]$

    1
    2
    3
    4
    
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half
        ).to(timesteps.device)
    

代码中为什么要用 exp-log 形式而不是分数?

  1. 数值更加稳定,分数形式会有边界问题;exp-log 形式全程连续可导,无边界问题。

  2. 计算效率更高,幂运算通常比 exp 更慢。

  1. 维度映射通式

    对于输出向量的第 $d$ 维($d \in [0, 1, \dots, D-1]$):

    $\text { embedding }[d]=\left\{\begin{array}{ll} \cos \left(t \cdot \omega_{d}\right) & \text { if } d<\text { half } \\ \sin \left(t \cdot \omega_{d-\text { half }}\right) & \text { if } d \geq \text { half } \end{array}\right.$

    1
    2
    
        args = timesteps[:, None].float() * freqs[None, :]                  # [B, half]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)   # [B, dim]
    

dim=4 为例:

维度 计算公式
0 $\cos(t \cdot \omega_0)$
1 $\cos(t \cdot \omega_1)$
2 $\sin(t \cdot \omega_0)$
3 $\sin(t \cdot \omega_1)$

值得一提的是,当 dim 为奇数时,我们让最后一个维度不参与编码,即补零。

dim=7 为例:

维度 计算公式
0 $\cos(t \cdot \omega_0)$
1 $\cos(t \cdot \omega_1)$
2 $\cos(t \cdot \omega_2)$
3 $\sin(t \cdot \omega_0)$
4 $\sin(t \cdot \omega_1)$
5 $\sin(t \cdot \omega_2)$
6 0
1
2
3
    # 奇数维度时补零
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
patchify()unpatchify()

这两个函数实现了像素网络Token序列之间的转换。

具体细节没什么好说的,我们借此总结一些 Pytorch 的常见方法:

  • unflod() - 滑动窗口提取

    在指定维度上以固定步长提取滑动窗口

    1
    
    x.unfold(dimension, size, step)
    
    1
    
        x = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    

    在 H, W 维度上展开,窗口大小=步长= patch_size

    举个例子,把 28×28 图像切成 4×4 的小块,每块独立成维。

  • reshape() vs view()

    reshape() = view() + contiguous()

    • view() 速度更快,但要求内存连续(确定连续时推荐使用)

    • reshape() 自动处理非连续,必要时拷贝(不确定连续性时)

    view() 用于改变张量形状,不改变数据内容

    1
    2
    
        # [B, N, C*p*p]
        patches = x.view(B, -1, C * patch_size * patch_size)
    

    -1 让 Pytorch 自动计算该维度大小,避免手动算 N。

  • permute() - 维度重排

    按指定顺序重新排列张量维度。

    1
    2
    
        # [B, C, H//p, p, W//p, p] -> [B, H//p, W//p, C, p, p]
        x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
    
  • contiguous() - 内存连续性保证

    确保张量在内存中连续存储,使 view() 能正常工作。

    • permute()unflod() 等操作会改变内存布局,导致张量不连续

    • view() 要求连续内存,否则报错

统计模型可训练参数量 count_parameters()

还是一样,我们来看看这个函数中 Pytorch 的相关方法与属性

  • numel() - 元素计数

    返回张量中元素的总个数(number of elements)

  • parameters() - 参数迭代器

    返回模型所有可学习参数的生成器,用于遍历

  • requires_grad - 梯度追踪标志

    标记张量是否需要计算梯度(是否参与训练)

    含义 场景
    True 需要梯度,反向传播时更新 可学习参数
    False 冻结,不更新 预训练模型冻结、推理模式
Beta 噪声调度 get_beta_schedule()

这个函数生成扩散过程中每一步的噪声强度 $\beta_t$ ,控制“每一步加多少噪声”

函数返回 betas[timesteps] 形状的一维张量,每个元素是 0~1 之间的浮点数

betas[t] = 第 t 步的噪声方差

这里我们着重看看余弦调度的实现:

余弦调度

余弦调度的目标是:

让信号衰减速度遵循余弦曲线的平方,前期慢,后期快,形成平滑过渡

  1. 定义累计信号比例

    定义从原始图像到第 $t$ 步保留的信号比例为:

    $$ \overline{\alpha}_t=\cos^2(\frac{t}{T}\cdot\frac{\pi}{2}) $$
  • $t=0$ :$\cos^2(0)=1$ ,完全保留信号

  • $t=T$:$\cos^2(1)=0$,完全变成噪声

  • 中间平缓过渡,曲线形状先缓后陡

  1. 引入偏移量防止数值问题

    为了防止 $t=0$ 时导数过大,我们引入偏移量 $s$(代码中取s=0.008):

    $$ \overline{\alpha}_t=\cos^2(\frac{\frac{t}{T}+s}{1+s}\cdot\frac{\pi}{2})/C $$

    其中 $C$ 是归一化常数,确保 $\overline{\alpha}_0=1

1
2
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
  1. 从累计信号反推单步噪声

    已知

    $$ \overline{\alpha}_t=\prod^t_{i=1}(1-\beta_i) $$

    这是扩散模型的核心,是 $x_0$ 无需迭代,可以直接得出任意 $x_t$ 的理论来源。具体推导见前文(待补充)

    则:

    $$ \beta_i=1-\frac{\overline{\alpha}_t}{\overline{\alpha}_{t-1}} $$

    直观理解一下,每一步的噪声强度 = 1- 这一步相对于上一步保留了多少额外信号

    1
    
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    
  2. 裁剪保护

    1
    
    return torch.clip(betas, 0.0001, 0.9999)
    

    防止极端值导致数值不稳定

diffusion.py

📁 diffusion.py 完整代码
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
import torch
import torch.nn.functional as F
import math
from typing import Optional, Tuple
from tqdm import tqdm

from utils import get_beta_schedule, timestep_embedding
from config import DiffusionConfig, TrainConfig

class Diffusion:
    """
    DDPM 扩散过程:封装前向加噪、训练损失计算、反向采样
    """

    def __init__(
            self,
            config: DiffusionConfig,
            device: Optional[str] = None
    ):
        self.config = config
        self.device = device if device else config.device

        timesteps = config.timesteps
        beta_start = config.beta_start
        beta_end = config.beta_end
        schedule = config.schedule

        # 获取 beta 调度
        self.betas = get_beta_schedule(schedule, timesteps, beta_start, beta_end).to(self.device)

        # 计算 alpha 相关变量
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # 预计算平方根等中间量,避免重复计算
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # 后验方差
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.sqrt_posterior_variance = torch.sqrt(self.posterior_variance)

        # 数值稳定性裁剪
        self.posterior_variance = torch.clip(self.posterior_variance, min=1e-20)

    def q_sample(
            self,
            x_start: torch.Tensor,
            t: torch.Tensor,
            noise: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        前向扩散:从 x_0 采样得到 x_t(重参数化技巧)
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        # 提取对应时间步的系数
        sqrt_alpha_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        # x_t = √ᾱ_t * x_0 + √(1-ᾱ_t) * ε
        return sqrt_alpha_cumprod_t * x_start + sqrt_one_minus_alpha_cumprod_t * noise
    
    def predict_start_from_noise(
            self, 
            x_t: torch.Tensor,
            t: torch.Tensor,
            noise: torch.Tensor
    ) -> torch.Tensor:
        """
        从预测的噪声反推 x_0
        """
        sqrt_alpha_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        # x_0 = (x_t - √(1-ᾱ_t)*ε) / √ᾱ_t
        return (x_t - sqrt_one_minus_alpha_cumprod_t * noise) / sqrt_alpha_cumprod_t
    
    def q_posterior_mean_variance(
            self,
            x_start: torch.Tensor,
            x_t: torch.Tensor,
            t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        计算真实后验分布 q(x_{t-1} | x_t, x_0) 的均值和方差
        """
        alpha_cumprod_prev_t = self.alphas_cumprod_prev[t].view(-1, 1, 1, 1)
        alpha_cumprod_t = self.alphas_cumprod[t].view(-1, 1, 1, 1)
        beta_t = self.betas[t].view(-1, 1, 1, 1)

        # 后验均值系数
        coef_x0 = torch.sqrt(alpha_cumprod_prev_t) * beta_t / (1 - alpha_cumprod_t)
        coef_xt = torch.sqrt(self.alphas[t].view(-1, 1, 1, 1)) * (1 - alpha_cumprod_prev_t) / (1 - alpha_cumprod_t)

        posterior_mean = coef_x0 * x_start + coef_xt * x_t
        posterior_variance = self.posterior_variance[t].view(-1, 1, 1, 1)

        return posterior_mean, posterior_variance
    
    def p_losses(
            self,
            model: torch.nn.Module,
            x_start: torch.Tensor,
            t: torch.Tensor,
            y: Optional[torch.Tensor] = None,
            noise: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        计算训练损失(预测噪声的 MSE)
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        # 前向加噪得到 x_t
        x_t = self.q_sample(x_start, t, noise)

        # 模型预测噪声
        noise_pred = model(x_t, t, y)

        # MSE 损失
        loss = F.mse_loss(noise_pred, noise)

        return loss
    
    @torch.no_grad()
    def p_sample(
        self,
        model: torch.nn.Module,
        x_t: torch.Tensor,
        t: int,
        y: Optional[torch.Tensor] = None,
        cfg_scale: Optional[float] = None
    ) -> torch.Tensor:
        """
        单步去噪:从 x_t 采样 x_{t-1}
        """
        # 使用配置默认值或传入值
        cfg_scale = cfg_scale if cfg_scale is not None else self.config.cfg_scale

        batch_size = x_t.shape[0]
        t_batch = torch.full((batch_size, ), t, device=self.device, dtype=torch.long)

        # Classifier-free guidance
        if cfg_scale > 1.0 and y is not None:
            # 无条件预测
            noise_uncond = model(x_t, t_batch, None)
            # 有条件预测
            noise_cond = model(x_t, t_batch, y)
            # 外推
            noise_pred = noise_uncond + cfg_scale * (noise_cond - noise_uncond)
        else:
            noise_pred = model(x_t, t_batch, y)

        # 计算预测的 x_0
        x_0_pred = self.predict_start_from_noise(x_t, t_batch, noise_pred)

        # 使用配置中的裁剪范围(可扩展)
        x_0_pred = x_0_pred.clamp(-1.0, 1.0)

        # 后验均值
        alpha_cumprod_t = self.alphas_cumprod[t].item()
        alpha_cumprod_prev_t = self.alphas_cumprod_prev[t].item()
        beta_t = self.betas[t].item()

        coef_x0 = math.sqrt(alpha_cumprod_prev_t) * beta_t / (1 - alpha_cumprod_t)
        coef_xt = math.sqrt(self.alphas[t].item()) * (1 - alpha_cumprod_prev_t) / (1 - alpha_cumprod_t)

        mean = coef_x0 * x_0_pred + coef_xt * x_t

        # 采样方差
        if t > 0:
            variance = self.posterior_variance[t].item()
            noise = torch.randn_like(x_t)
            x_prev = mean + math.sqrt(variance) * noise
        else:
            x_prev = mean    # t=0 时不加噪声

        return x_prev
    
    @torch.no_grad()
    def sample(
            self,
            model: torch.nn.Module,
            shape: Tuple[int, int, int, int],
            y: Optional[torch.Tensor] = None,
            cfg_scale: Optional[float] = None,
            progress: bool = True
    ) -> torch.Tensor:
        """
        完整采样:从纯噪声生成图像
        """
        cfg_scale = cfg_scale if cfg_scale is not None else self.config.cfg_scale
        batch_size, C, H, W = shape

        # 从纯噪声开始
        x = torch.randn(shape, device=self.device)

        # 设置条件(随机类别如果未提供)
        if y is None:
            # 从配置获取类别数,避免硬编码
            num_classes = getattr(model.config, 'num_classes', 10)
            y = torch.randint(0, num_classes, (batch_size,), device=self.device)
        
        # 迭代去噪
        timesteps = range(self.config.timesteps -1, -1, -1)
        if progress: 
            timesteps = tqdm(timesteps, desc="Sampling")
        
        for t in timesteps:
            x = self.p_sample(model, x, t, y, cfg_scale)
        
        return x
    
    @torch.no_grad()
    def ddim_sample(
       self,
       model: torch.nn.Module,
       shape: Tuple[int, int, int, int],
       y: Optional[torch.Tensor] = None,
       steps: Optional[int] = None,
       eta: Optional[float] = None,
       cfg_scale: Optional[float] = None,
       progress: bool = True 
    ) -> torch.Tensor:
        """
        DDIM 加速采样(确定性,步数更少)
        """
        # 使用配置或传入值
        steps = steps if steps is not None else self.config.sample_steps
        eta = eta if eta is not None else 0.0   # 默认确定性
        cfg_scale = cfg_scale if cfg_scale is not None else self.config.cfg_scale

        batch_size = shape[0]

        # 均匀选取时间步
        c = self.config.timesteps // steps
        timesteps = list(range(0, self.config.timesteps, c))[:steps]
        timesteps = list(reversed(timesteps)) + [0]

        x = torch.randn(shape, device=self.device)
        if y is None:
            num_classes = getattr(model.config, 'num_classes', 10)
            y = torch.randint(0, num_classes, (batch_size,), device=self.device)

        iterator = tqdm(timesteps[:-1], desc="DDIM Sampling") if progress else timesteps[:-1]

        for i, t in enumerate(iterator):
            t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)

            # CFG
            if cfg_scale > 1.0:
                noise_uncond = model(x, t_batch, None)
                noise_cond = model(x, t_batch, y)
                noise_pred = noise_uncond + cfg_scale * (noise_cond - noise_uncond)
            else:
                noise_pred = model(x, t_batch, y)
            
            # 预测 x_0
            alpha_cumprod_t = self.alphas_cumprod[t]

            if i + 1 < len(timesteps) - 1:  # 还有下一步
                alpha_cumprod_prev = self.alphas_cumprod[timesteps[i+1]]
            else:
                alpha_cumprod_prev = torch.tensor(1.0, device=self.device) 

            x_0_pred = (x - torch.sqrt(1 - alpha_cumprod_t) * noise_pred) / torch.sqrt(alpha_cumprod_t)
            x_0_pred = x_0_pred.clamp(-1.0, 1.0)

            # DDIM 方向
            sigma_t = eta * torch.sqrt((1 - alpha_cumprod_prev) / (1 - alpha_cumprod_t) * (1 - alpha_cumprod_t / alpha_cumprod_prev))

            dir_xt = torch.sqrt(1 - alpha_cumprod_prev - sigma_t**2) * noise_pred

            x = torch.sqrt(alpha_cumprod_prev) * x_0_pred + dir_xt            
            
            if eta > 0 and i < len(timesteps) - 1:
                x = x + sigma_t * torch.randn_like(x)
            
        return x

参考