CogView3 & CogView-3Plus 微调代码源码解析(三)
.\cogview3-finetune\sat\sgm\modules\diffusionmodules\guiders.py
# 导入 logging 模块,用于记录日志信息
import logging
# 从 abc 模块导入 ABC 类和 abstractmethod 装饰器,用于定义抽象基类和抽象方法
from abc import ABC, abstractmethod
# 导入类型注解,方便在函数签名中定义复杂数据结构
from typing import Dict, List, Optional, Tuple, Union
# 从 functools 模块导入 partial 函数,用于部分应用函数
from functools import partial
# 导入数学模块,提供数学函数
import math
# 导入 PyTorch 库,提供张量计算功能
import torch
# 从 einops 模块导入 rearrange 和 repeat 函数,用于张量重排和重复
from einops import rearrange, repeat
# 从上层模块导入工具函数,提供一些默认值和实例化配置的功能
from ...util import append_dims, default, instantiate_from_config
# 定义一个抽象基类 Guider,继承自 ABC
class Guider(ABC):
# 定义一个抽象方法 __call__,接受一个张量和一个浮点数,返回一个张量
@abstractmethod
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
pass
# 定义准备输入的方法,接受多个参数并返回一个元组
def prepare_inputs(
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
) -> Tuple[torch.Tensor, float, Dict]:
pass
# 定义一个类 VanillaCFG,表示基本的条件生成模型
class VanillaCFG:
"""
implements parallelized CFG
"""
# 初始化方法,接受比例和动态阈值配置
def __init__(self, scale, dyn_thresh_config=None):
# 定义一个 lambda 函数,根据 sigma 返回 scale,保持独立于步数
scale_schedule = lambda scale, sigma: scale # independent of step
# 使用 partial 固定 scale 参数,创建 scale_schedule 方法
self.scale_schedule = partial(scale_schedule, scale)
# 实例化动态阈值对象,如果没有提供配置则使用默认配置
self.dyn_thresh = instantiate_from_config(
default(
dyn_thresh_config,
{
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
},
)
)
# 定义 __call__ 方法,使该类可以被调用,接受多个参数
def __call__(self, x, sigma, step = None, num_steps = None, **kwargs):
# 将输入张量 x 拆分为两个部分 x_u 和 x_c
x_u, x_c = x.chunk(2)
# 根据 sigma 计算 scale_value
scale_value = self.scale_schedule(sigma)
# 使用动态阈值处理函数进行预测,返回预测结果
x_pred = self.dyn_thresh(x_u, x_c, scale_value, step=step, num_steps=num_steps)
return x_pred
# 定义准备输入的方法,接受多个参数并返回一个元组
def prepare_inputs(self, x, s, c, uc):
# 初始化输出字典
c_out = dict()
# 遍历条件字典 c 的键
for k in c:
# 如果键是特定值,则将 uc 和 c 中的对应张量拼接
if k in ["vector", "crossattn", "concat"]:
c_out[k] = torch.cat((uc[k], c[k]), 0)
# 否则确保两个字典中对应的值相等,并直接赋值
else:
assert c[k] == uc[k]
c_out[k] = c[k]
# 返回拼接后的张量和条件字典
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
# 定义一个类 IdentityGuider,实现一个恒等引导器
class IdentityGuider:
# 定义 __call__ 方法,直接返回输入张量
def __call__(self, x, sigma, **kwargs):
return x
# 定义准备输入的方法,返回输入和条件字典
def prepare_inputs(self, x, s, c, uc):
# 初始化输出字典
c_out = dict()
# 遍历条件字典 c 的键
for k in c:
# 直接将条件字典 c 的值赋给输出字典
c_out[k] = c[k]
# 返回输入张量和条件字典
return x, s, c_out
# 定义一个类 LinearPredictionGuider,继承自 Guider
class LinearPredictionGuider(Guider):
# 初始化方法,接受多个参数
def __init__(
self,
max_scale: float,
num_frames: int,
min_scale: float = 1.0,
additional_cond_keys: Optional[Union[List[str], str]] = None,
):
# 初始化最小和最大比例
self.min_scale = min_scale
self.max_scale = max_scale
# 计算比例的线性变化,生成 num_frames 个值
self.num_frames = num_frames
self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
# 确保 additional_cond_keys 是一个列表,如果是字符串则转换为列表
additional_cond_keys = default(additional_cond_keys, [])
if isinstance(additional_cond_keys, str):
additional_cond_keys = [additional_cond_keys]
# 保存附加条件键
self.additional_cond_keys = additional_cond_keys
# 定义可调用对象的方法,接收输入张量 x 和 sigma,以及其他参数 kwargs,返回一个张量
def __call__(self, x: torch.Tensor, sigma: torch.Tensor, **kwargs) -> torch.Tensor:
# 将输入张量 x 拆分为两部分:x_u 和 x_c
x_u, x_c = x.chunk(2)
# 重排 x_u 的维度,使其形状为 (批量大小 b, 帧数 t, ...),t 由 num_frames 指定
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
# 重排 x_c 的维度,使其形状为 (批量大小 b, 帧数 t, ...),t 由 num_frames 指定
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
# 复制 scale 张量的维度,使其形状为 (批量大小 b, 帧数 t)
scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
# 将 scale 的维度扩展到与 x_u 的维度一致,并移动到 x_u 的设备上
scale = append_dims(scale, x_u.ndim).to(x_u.device)
# 将 scale 转换为与 x_u 相同的数据类型
scale = scale.to(x_u.dtype)
# 返回经过计算的结果,重排为 (批量大小 b * 帧数 t, ...)
return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
# 定义准备输入的函数,接收输入张量 x 和 s,以及条件字典 c 和 uc,返回一个元组
def prepare_inputs(
self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
# 初始化一个空字典 c_out 用于存放处理后的条件
c_out = dict()
# 遍历条件字典 c 的每一个键 k
for k in c:
# 如果 k 是指定的条件键之一,进行拼接
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
# 将 uc[k] 和 c[k] 沿第0维拼接,并存入 c_out
c_out[k] = torch.cat((uc[k], c[k]), 0)
else:
# 确保 c[k] 与 uc[k] 相等
assert c[k] == uc[k]
# 将 c[k] 直接存入 c_out
c_out[k] = c[k]
# 返回拼接后的 x 和 s 以及处理后的条件字典 c_out
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
.\cogview3-finetune\sat\sgm\modules\diffusionmodules\loss.py
# 导入所需的标准库和类型提示
import os
import copy
from typing import List, Optional, Union
# 导入 NumPy 和 PyTorch 库
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# 导入 OmegaConf 中的 ListConfig
from omegaconf import ListConfig
# 从自定义模块中导入所需的函数和类
from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
from ...modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ...util import get_obj_from_str, default
from ...modules.diffusionmodules.discretizer import generate_roughly_equally_spaced_steps, sub_generate_roughly_equally_spaced_steps
# 定义标准扩散损失类,继承自 nn.Module
class StandardDiffusionLoss(nn.Module):
# 初始化方法,设置损失类型和噪声级别等参数
def __init__(
self,
sigma_sampler_config,
type="l2",
offset_noise_level=0.0,
batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
):
super().__init__()
# 确保损失类型有效
assert type in ["l2", "l1", "lpips"]
# 根据配置实例化 sigma 采样器
self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
# 保存损失类型和噪声级别
self.type = type
self.offset_noise_level = offset_noise_level
# 如果损失类型为 lpips,则初始化 lpips 模块
if type == "lpips":
self.lpips = LPIPS().eval()
# 如果没有提供 batch2model_keys,则设置为空列表
if not batch2model_keys:
batch2model_keys = []
# 如果 batch2model_keys 是字符串,则转为列表
if isinstance(batch2model_keys, str):
batch2model_keys = [batch2model_keys]
# 将 batch2model_keys 转为集合以便于后续处理
self.batch2model_keys = set(batch2model_keys)
# 定义调用方法,计算损失
def __call__(self, network, denoiser, conditioner, input, batch):
# 使用条件器处理输入批次
cond = conditioner(batch)
# 从批次中提取附加模型输入
additional_model_inputs = {
key: batch[key] for key in self.batch2model_keys.intersection(batch)
}
# 生成 sigma 值
sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
# 生成与输入相同形状的随机噪声
noise = torch.randn_like(input)
# 如果设置了噪声级别,调整噪声
if self.offset_noise_level > 0.0:
noise = noise + append_dims(
torch.randn(input.shape[0]).to(input.device), input.ndim
) * self.offset_noise_level
# 确保噪声数据类型与输入一致
noise = noise.to(input.dtype)
# 将输入与噪声和 sigma 结合,生成有噪声的输入
noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
# 使用去噪网络处理有噪声的输入
model_output = denoiser(
network, noised_input, sigmas, cond, **additional_model_inputs
)
# 将去噪网络的权重调整为与输入相同的维度
w = append_dims(denoiser.w(sigmas), input.ndim)
# 返回损失值
return self.get_loss(model_output, input, w)
# 定义计算损失的方法
def get_loss(self, model_output, target, w):
# 根据损失类型计算 l2 损失
if self.type == "l2":
return torch.mean(
(w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
)
# 根据损失类型计算 l1 损失
elif self.type == "l1":
return torch.mean(
(w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
)
# 根据损失类型计算 lpips 损失
elif self.type == "lpips":
loss = self.lpips(model_output, target).reshape(-1)
return loss
# 定义线性中继扩散损失类,继承自 StandardDiffusionLoss
class LinearRelayDiffusionLoss(StandardDiffusionLoss):
# 初始化方法,设置相关参数
def __init__(
self,
sigma_sampler_config,
type="l2",
offset_noise_level=0.0,
partial_num_steps=500,
blurring_schedule='linear',
batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
):
# 调用父类构造函数,初始化基本参数
super().__init__(
sigma_sampler_config, # sigma 采样器的配置
type=type, # 类型参数
offset_noise_level=offset_noise_level, # 偏移噪声水平
batch2model_keys=batch2model_keys, # 批次到模型的键映射
)
# 设置模糊调度参数
self.blurring_schedule = blurring_schedule
# 设置部分步骤数量
self.partial_num_steps = partial_num_steps
def __call__(self, network, denoiser, conditioner, input, batch):
# 使用调节器处理批次数据,生成条件
cond = conditioner(batch)
# 生成额外的模型输入,筛选出与模型键对应的批次数据
additional_model_inputs = {
key: batch[key] for key in self.batch2model_keys.intersection(batch)
}
# 从批次中获取低分辨率输入
lr_input = batch["lr_input"]
# 生成随机整数,用于选择部分步骤
rand = torch.randint(0, self.partial_num_steps, (input.shape[0],))
# 从 sigma 采样器生成 sigma 值,并转换为输入数据类型和设备
sigmas = self.sigma_sampler(input.shape[0], rand).to(input.dtype).to(input.device)
# 生成与输入形状相同的随机噪声
noise = torch.randn_like(input)
# 如果偏移噪声水平大于0,则添加额外噪声
if self.offset_noise_level > 0.0:
# 生成额外随机噪声并调整其维度,乘以偏移噪声水平
noise = noise + append_dims(
torch.randn(input.shape[0]).to(input.device), input.ndim
) * self.offset_noise_level
# 转换噪声为输入数据类型
noise = noise.to(input.dtype)
# 调整 rand 的维度并转换为输入数据类型和设备
rand = append_dims(rand, input.ndim).to(input.dtype).to(input.device)
# 根据模糊调度的不同方式计算模糊输入
if self.blurring_schedule == 'linear':
# 线性模糊处理
blurred_input = input * (1 - rand / self.partial_num_steps) + lr_input * (rand / self.partial_num_steps)
elif self.blurring_schedule == 'sigma':
# 使用 sigma 最大值进行模糊处理
max_sigmas = self.sigma_sampler(input.shape[0], torch.ones(input.shape[0])*self.partial_num_steps).to(input.dtype).to(input.device)
blurred_input = input * (1 - sigmas / max_sigmas) + lr_input * (sigmas / max_sigmas)
elif self.blurring_schedule == 'exp':
# 指数模糊处理
rand_blurring = (1 - torch.exp(-(torch.sin((rand+1) / self.partial_num_steps * torch.pi / 2)**4))) / (1 - torch.exp(-torch.ones_like(rand)))
blurred_input = input * (1 - rand_blurring) + lr_input * rand_blurring
else:
# 如果模糊调度不被支持,抛出未实现错误
raise NotImplementedError
# 将噪声添加到模糊输入中
noised_input = blurred_input + noise * append_dims(sigmas, input.ndim)
# 调用去噪声器处理模糊输入,获取模型输出
model_output = denoiser(
network, noised_input, sigmas, cond, **additional_model_inputs
)
# 调整去噪声器权重的维度
w = append_dims(denoiser.w(sigmas), input.ndim)
# 返回模型输出的损失值
return self.get_loss(model_output, input, w)
# 定义一个名为 ZeroSNRDiffusionLoss 的类,继承自 StandardDiffusionLoss
class ZeroSNRDiffusionLoss(StandardDiffusionLoss):
# 重载调用方法,接受网络、去噪器、条件、输入和批次作为参数
def __call__(self, network, denoiser, conditioner, input, batch):
# 使用条件生成器处理批次,得到条件变量
cond = conditioner(batch)
# 从批次中提取与模型键相交的额外输入
additional_model_inputs = {
key: batch[key] for key in self.batch2model_keys.intersection(batch)
}
# 生成累积的 alpha 值并获取索引
alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
# 将 alpha 值移动到输入的设备上
alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
# 将索引移动到输入的数据类型和设备上
idx = idx.to(input.dtype).to(input.device)
# 将索引添加到额外模型输入中
additional_model_inputs['idx'] = idx
# 生成与输入形状相同的随机噪声
noise = torch.randn_like(input)
# 如果偏移噪声水平大于零,则添加额外噪声
if self.offset_noise_level > 0.0:
noise = noise + append_dims(
# 生成随机噪声并调整维度,乘以偏移噪声水平
torch.randn(input.shape[0]).to(input.device), input.ndim
) * self.offset_noise_level
# 计算加入噪声的输入
noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims((1-alphas_cumprod_sqrt**2)**0.5, input.ndim)
# 使用去噪器处理带噪声的输入
model_output = denoiser(
network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs
)
# 计算 v-pred 权重
w = append_dims(1/(1-alphas_cumprod_sqrt**2), input.ndim)
# 返回损失值
return self.get_loss(model_output, input, w)
# 定义一个获取损失的函数
def get_loss(self, model_output, target, w):
# 如果损失类型为 L2,计算 L2 损失
if self.type == "l2":
return torch.mean(
# 计算每个样本的 L2 损失并调整维度
(w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
)
# 如果损失类型为 L1,计算 L1 损失
elif self.type == "l1":
return torch.mean(
# 计算每个样本的 L1 损失并调整维度
(w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
)
# 如果损失类型为 LPIPS,计算 LPIPS 损失
elif self.type == "lpips":
loss = self.lpips(model_output, target).reshape(-1)
return loss
.\cogview3-finetune\sat\sgm\modules\diffusionmodules\model.py
# pytorch_diffusion + derived encoder decoder
# 导入数学库
import math
# 导入类型注解相关
from typing import Any, Callable, Optional
# 导入 numpy 库
import numpy as np
# 导入 pytorch 库
import torch
# 导入 pytorch 神经网络模块
import torch.nn as nn
# 导入 rearrange 函数以处理张量重排列
from einops import rearrange
# 导入版本管理库
from packaging import version
# 尝试导入 xformers 模块
try:
import xformers
import xformers.ops
# 如果成功导入,设置标志为 True
XFORMERS_IS_AVAILABLE = True
except:
# 如果导入失败,设置标志为 False,并打印提示信息
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
# 从其他模块导入 LinearAttention 和 MemoryEfficientCrossAttention
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
def get_timestep_embedding(timesteps, embedding_dim):
"""
此函数与 Denoising Diffusion Probabilistic Models 中的实现相匹配:
来自 Fairseq。
构建正弦嵌入。
此实现与 tensor2tensor 中的实现相匹配,但与 "Attention Is All You Need" 第 3.5 节中的描述略有不同。
"""
# 确保时间步长是一维的
assert len(timesteps.shape) == 1
# 计算嵌入维度的一半
half_dim = embedding_dim // 2
# 计算嵌入因子的对数
emb = math.log(10000) / (half_dim - 1)
# 计算并生成指数衰减的嵌入
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
# 将嵌入移动到与时间步相同的设备上
emb = emb.to(device=timesteps.device)
# 扩展时间步并与嵌入相乘
emb = timesteps.float()[:, None] * emb[None, :]
# 将正弦和余弦嵌入拼接在一起
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# 如果嵌入维度是奇数,则进行零填充
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
# 返回最终的嵌入
return emb
def nonlinearity(x):
# 使用 swish 激活函数
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
# 返回一个 GroupNorm 归一化层
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
# 初始化 Upsample 类
super().__init__()
# 记录是否使用卷积
self.with_conv = with_conv
# 如果使用卷积,则定义卷积层
if self.with_conv:
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
# 使用最近邻插值将输入张量上采样
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
# 如果使用卷积,则应用卷积层
if self.with_conv:
x = self.conv(x)
# 返回处理后的张量
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
# 初始化 Downsample 类
super().__init__()
# 记录是否使用卷积
self.with_conv = with_conv
# 如果使用卷积,则定义卷积层
if self.with_conv:
# 因为 pytorch 卷积不支持不对称填充,需手动处理
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
# 如果使用卷积,先进行填充再应用卷积层
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
# 否则使用平均池化进行下采样
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
# 返回处理后的张量
return x
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
):
# 调用父类的初始化方法
super().__init__()
# 保存输入通道数
self.in_channels = in_channels
# 如果未指定输出通道数,则设置为输入通道数
out_channels = in_channels if out_channels is None else out_channels
# 保存输出通道数
self.out_channels = out_channels
# 保存是否使用卷积捷径的标志
self.use_conv_shortcut = conv_shortcut
# 初始化输入通道数的归一化层
self.norm1 = Normalize(in_channels)
# 定义第一层卷积,输入输出通道及卷积核参数
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
# 如果有时间嵌入通道,则定义时间嵌入投影层
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
# 初始化输出通道数的归一化层
self.norm2 = Normalize(out_channels)
# 定义 dropout 层
self.dropout = torch.nn.Dropout(dropout)
# 定义第二层卷积,输入输出通道及卷积核参数
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
# 如果输入和输出通道数不相同
if self.in_channels != self.out_channels:
# 如果使用卷积捷径,则定义卷积捷径层
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
# 否则定义 1x1 卷积捷径层
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
# 前向传播函数
def forward(self, x, temb):
# 将输入赋值给 h 变量
h = x
# 对 h 进行归一化
h = self.norm1(h)
# 应用非线性激活函数
h = nonlinearity(h)
# 通过第一层卷积处理 h
h = self.conv1(h)
# 如果时间嵌入不为 None
if temb is not None:
# 将时间嵌入通过非线性激活函数处理后投影到输出通道,并与 h 相加
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
# 对 h 进行第二次归一化
h = self.norm2(h)
# 应用非线性激活函数
h = nonlinearity(h)
# 通过 dropout 层处理 h
h = self.dropout(h)
# 通过第二层卷积处理 h
h = self.conv2(h)
# 如果输入和输出通道数不相同
if self.in_channels != self.out_channels:
# 如果使用卷积捷径,则通过卷积捷径层处理 x
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
# 否则通过 1x1 卷积捷径层处理 x
else:
x = self.nin_shortcut(x)
# 返回 x 和 h 的相加结果
return x + h
# 定义 LinAttnBlock 类,继承自 LinearAttention
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage""" # 文档字符串,说明该类用于匹配 AttnBlock 的使用方式
# 初始化方法,接受输入通道数
def __init__(self, in_channels):
# 调用父类的初始化方法,设置维度和头数
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
# 定义 AttnBlock 类,继承自 nn.Module
class AttnBlock(nn.Module):
# 初始化方法,接受输入通道数
def __init__(self, in_channels):
# 调用父类的初始化方法
super().__init__()
# 保存输入通道数
self.in_channels = in_channels
# 初始化归一化层
self.norm = Normalize(in_channels)
# 初始化查询卷积层
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
# 初始化键卷积层
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
# 初始化值卷积层
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
# 初始化输出投影卷积层
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
# 定义注意力计算方法
def attention(self, h_: torch.Tensor) -> torch.Tensor:
# 对输入进行归一化
h_ = self.norm(h_)
# 计算查询、键和值
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# 获取查询的形状参数
b, c, h, w = q.shape
# 重新排列查询、键和值的形状
q, k, v = map(
lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
)
# 计算缩放的点积注意力
h_ = torch.nn.functional.scaled_dot_product_attention(
q, k, v
) # scale is dim ** -0.5 per default
# 计算注意力
# 返回重新排列后的注意力结果
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
# 定义前向传播方法
def forward(self, x, **kwargs):
# 将输入赋值给 h_
h_ = x
# 计算注意力
h_ = self.attention(h_)
# 应用输出投影
h_ = self.proj_out(h_)
# 返回输入与注意力结果的和
return x + h_
# 定义 MemoryEfficientAttnBlock 类,继承自 nn.Module
class MemoryEfficientAttnBlock(nn.Module):
"""
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
""" # 文档字符串,说明该类使用 xformers 高效实现的单头自注意力
# 初始化方法,接受输入通道数
def __init__(self, in_channels):
# 调用父类的初始化方法
super().__init__()
# 保存输入通道数
self.in_channels = in_channels
# 初始化归一化层
self.norm = Normalize(in_channels)
# 初始化查询卷积层
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
# 初始化键卷积层
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
# 初始化值卷积层
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
# 初始化输出投影卷积层
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
# 初始化注意力操作,类型为可选的任意类型
self.attention_op: Optional[Any] = None
# 定义注意力机制的函数,输入为一个张量,输出也是一个张量
def attention(self, h_: torch.Tensor) -> torch.Tensor:
# 先对输入进行归一化处理
h_ = self.norm(h_)
# 通过线性变换生成查询张量
q = self.q(h_)
# 通过线性变换生成键张量
k = self.k(h_)
# 通过线性变换生成值张量
v = self.v(h_)
# 计算注意力
# 获取查询张量的形状信息
B, C, H, W = q.shape
# 调整张量形状,将其从四维转为二维
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
# 对查询、键、值进行维度调整以便计算注意力
q, k, v = map(
lambda t: t.unsqueeze(3) # 在最后增加一个维度
.reshape(B, t.shape[1], 1, C) # 调整形状
.permute(0, 2, 1, 3) # 变换维度顺序
.reshape(B * 1, t.shape[1], C) # 重新调整形状
.contiguous(), # 保证内存连续性
(q, k, v),
)
# 使用内存高效的注意力操作
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
# 调整输出张量的形状
out = (
out.unsqueeze(0) # 增加一个维度
.reshape(B, 1, out.shape[1], C) # 调整形状
.permute(0, 2, 1, 3) # 变换维度顺序
.reshape(B, out.shape[1], C) # 重新调整形状
)
# 将输出张量的形状恢复为原来的格式
return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
# 定义前向传播函数
def forward(self, x, **kwargs):
# 输入数据赋值给 h_
h_ = x
# 通过注意力机制处理 h_
h_ = self.attention(h_)
# 通过输出投影处理 h_
h_ = self.proj_out(h_)
# 返回输入和处理后的 h_ 的和
return x + h_
# 定义一个内存高效的交叉注意力包装类,继承自 MemoryEfficientCrossAttention
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
# 前向传播方法,接受输入张量和可选的上下文、掩码
def forward(self, x, context=None, mask=None, **unused_kwargs):
# 解包输入张量的维度:批量大小、通道数、高度和宽度
b, c, h, w = x.shape
# 重新排列输入张量的维度,将 (b, c, h, w) 转换为 (b, h*w, c)
x = rearrange(x, "b c h w -> b (h w) c")
# 调用父类的 forward 方法,处理重新排列后的输入
out = super().forward(x, context=context, mask=mask)
# 将输出张量的维度重新排列回 (b, c, h, w)
out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
# 返回输入与输出的和,进行残差连接
return x + out
# 定义一个生成注意力模块的函数
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
# 检查传入的注意力类型是否在支持的类型列表中
assert attn_type in [
"vanilla",
"vanilla-xformers",
"memory-efficient-cross-attn",
"linear",
"none",
], f"attn_type {attn_type} unknown"
# 检查 PyTorch 版本,并且如果类型不是 "none",则验证是否可用 xformers
if (
version.parse(torch.__version__) < version.parse("2.0.0")
and attn_type != "none"
):
assert XFORMERS_IS_AVAILABLE, (
f"We do not support vanilla attention in {torch.__version__} anymore, "
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
)
# 将注意力类型设置为 "vanilla-xformers"
attn_type = "vanilla-xformers"
# 根据注意力类型生成相应的注意力块
if attn_type == "vanilla":
# 验证注意力参数不为 None
assert attn_kwargs is None
# 返回标准的注意力块
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
# 返回内存高效的注意力块
return MemoryEfficientAttnBlock(in_channels)
elif attn_type == "memory-efficient-cross-attn":
# 设置查询维度为输入通道数
attn_kwargs["query_dim"] = in_channels
# 返回内存高效的交叉注意力包装类
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
elif attn_type == "none":
# 返回一个身份映射层,不改变输入
return nn.Identity(in_channels)
else:
# 返回线性注意力块
return LinAttnBlock(in_channels)
# 定义一个模型类,继承自 nn.Module
class Model(nn.Module):
# 初始化方法,接受多个参数进行模型构建
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type="vanilla",
# 定义前向传播方法,接受输入 x、时间步 t 和上下文 context
def forward(self, x, t=None, context=None):
# 确保输入 x 的高度和宽度与设定的分辨率相等(被注释掉)
# assert x.shape[2] == x.shape[3] == self.resolution
# 如果上下文不为 None,沿通道维度连接输入 x 和上下文
if context is not None:
# 假设上下文对齐,沿通道轴拼接
x = torch.cat((x, context), dim=1)
# 如果使用时间步,进行时间步嵌入
if self.use_timestep:
# 确保时间步 t 不为 None
assert t is not None
# 获取时间步嵌入
temb = get_timestep_embedding(t, self.ch)
# 通过第一层密集层处理时间步嵌入
temb = self.temb.dense[0](temb)
# 应用非线性变换
temb = nonlinearity(temb)
# 通过第二层密集层处理
temb = self.temb.dense[1](temb)
else:
# 如果不使用时间步,设置时间步嵌入为 None
temb = None
# 下采样
hs = [self.conv_in(x)] # 初始卷积层的输出
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
# 通过当前下采样层和时间步嵌入处理前一层输出
h = self.down[i_level].block[i_block](hs[-1], temb)
# 如果存在注意力层,则对输出进行注意力处理
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
# 将处理后的输出添加到列表
hs.append(h)
# 如果不是最后一层分辨率,进行下采样
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# 中间处理
h = hs[-1] # 获取最后一层的输出
h = self.mid.block_1(h, temb) # 通过中间块处理
h = self.mid.attn_1(h) # 通过中间注意力层处理
h = self.mid.block_2(h, temb) # 再次通过中间块处理
# 上采样
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
# 拼接上层输出和当前层的输出,然后通过上采样块处理
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb
)
# 如果存在注意力层,则对输出进行注意力处理
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
# 如果不是第一层分辨率,进行上采样
if i_level != 0:
h = self.up[i_level].upsample(h)
# 结束处理
h = self.norm_out(h) # 最后的归一化处理
h = nonlinearity(h) # 应用非线性变换
h = self.conv_out(h) # 通过输出卷积层处理
return h # 返回最终输出
# 获取最后一层的卷积权重
def get_last_layer(self):
return self.conv_out.weight # 返回输出卷积层的权重
# 定义一个编码器类,继承自 nn.Module
class Encoder(nn.Module):
# 初始化方法,接收多个参数用于配置编码器
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type="vanilla",
mid_attn=True,
**ignore_kwargs,
):
# 调用父类构造方法
super().__init__()
# 如果使用线性注意力,设置注意力类型为线性
if use_linear_attn:
attn_type = "linear"
# 保存输入参数以供后续使用
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.attn_resolutions = attn_resolutions
self.mid_attn = mid_attn
# 下采样
# 定义输入卷积层
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
# 当前分辨率初始化
curr_res = resolution
# 定义输入通道的倍率
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
# 初始化下采样模块列表
self.down = nn.ModuleList()
# 遍历每个分辨率层级
for i_level in range(self.num_resolutions):
# 初始化块和注意力模块列表
block = nn.ModuleList()
attn = nn.ModuleList()
# 输入和输出通道数计算
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
# 遍历每个残差块
for i_block in range(self.num_res_blocks):
# 添加残差块到块列表中
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
# 更新输入通道数为当前块的输出通道数
block_in = block_out
# 如果当前分辨率在注意力分辨率列表中,添加注意力模块
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
# 创建下采样模块
down = nn.Module()
down.block = block
down.attn = attn
# 如果不是最后一个分辨率,添加下采样层
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
# 更新当前分辨率为一半
curr_res = curr_res // 2
# 将下采样模块添加到列表中
self.down.append(down)
# 中间层
self.mid = nn.Module()
# 添加第一个残差块
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# 如果使用中间注意力,添加注意力模块
if mid_attn:
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
# 添加第二个残差块
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# 结束层
# 定义归一化层
self.norm_out = Normalize(block_in)
# 定义输出卷积层,根据是否双 z 通道设置输出通道数
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1,
)
# 定义前向传播方法,接受输入数据 x
def forward(self, x):
# 时间步嵌入初始化为 None
temb = None
# 下采样过程
# 对输入 x 进行卷积操作,生成初始特征图 hs
hs = [self.conv_in(x)]
# 遍历每个分辨率层
for i_level in range(self.num_resolutions):
# 遍历当前分辨率层中的每个残差块
for i_block in range(self.num_res_blocks):
# 使用当前层的残差块处理上一个层的输出和时间步嵌入
h = self.down[i_level].block[i_block](hs[-1], temb)
# 如果当前层有注意力机制,则应用注意力
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
# 将当前层的输出添加到特征图列表中
hs.append(h)
# 如果当前层不是最后一个分辨率层,则进行下采样
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# 中间处理阶段
h = hs[-1] # 获取最后一层的输出
# 通过中间块1处理输入
h = self.mid.block_1(h, temb)
# 如果中间层有注意力机制,则应用注意力
if self.mid_attn:
h = self.mid.attn_1(h)
# 通过中间块2处理输出
h = self.mid.block_2(h, temb)
# 最终处理阶段
h = self.norm_out(h) # 应用输出归一化
h = nonlinearity(h) # 应用非线性激活函数
h = self.conv_out(h) # 通过输出卷积生成最终结果
return h # 返回最终输出
# 定义一个解码器类,继承自 PyTorch 的 nn.Module
class Decoder(nn.Module):
# 初始化方法,定义解码器的参数
def __init__(
self,
*,
ch, # 输入通道数
out_ch, # 输出通道数
ch_mult=(1, 2, 4, 8), # 通道数的倍增因子
num_res_blocks, # 残差块的数量
attn_resolutions, # 注意力机制应用的分辨率
dropout=0.0, # dropout 比例,默认值为 0
resamp_with_conv=True, # 是否使用卷积进行上采样
in_channels, # 输入的通道数
resolution, # 输入的分辨率
z_channels, # 潜在变量的通道数
give_pre_end=False, # 是否在前面给予额外的结束标志
tanh_out=False, # 输出是否经过 tanh 激活
use_linear_attn=False, # 是否使用线性注意力机制
attn_type="vanilla", # 注意力类型,默认为“vanilla”
mid_attn=True, # 是否在中间层使用注意力
**ignorekwargs, # 其他忽略的参数,采用关键字参数形式
):
# 初始化父类
super().__init__()
# 如果使用线性注意力机制,设置注意力类型为线性
if use_linear_attn:
attn_type = "linear"
# 设置通道数
self.ch = ch
# 初始化时间嵌入通道数为0
self.temb_ch = 0
# 计算分辨率数量
self.num_resolutions = len(ch_mult)
# 设置残差块数量
self.num_res_blocks = num_res_blocks
# 设置输入分辨率
self.resolution = resolution
# 设置输入通道数
self.in_channels = in_channels
# 设置是否给出前置结束标志
self.give_pre_end = give_pre_end
# 设置激活函数输出
self.tanh_out = tanh_out
# 设置注意力分辨率
self.attn_resolutions = attn_resolutions
# 设置中间注意力
self.mid_attn = mid_attn
# 计算输入通道倍数、块输入通道和当前最低分辨率
in_ch_mult = (1,) + tuple(ch_mult)
# 计算当前块的输入通道数
block_in = ch * ch_mult[self.num_resolutions - 1]
# 计算当前分辨率
curr_res = resolution // 2 ** (self.num_resolutions - 1)
# 设置潜在变量的形状
self.z_shape = (1, z_channels, curr_res, curr_res)
# print(
# "Working with z of shape {} = {} dimensions.".format(
# self.z_shape, np.prod(self.z_shape)
# )
# )
# 创建注意力和残差块类
make_attn_cls = self._make_attn()
make_resblock_cls = self._make_resblock()
make_conv_cls = self._make_conv()
# 将潜在变量映射到块输入通道
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# 中间层
self.mid = nn.Module()
# 创建第一个残差块
self.mid.block_1 = make_resblock_cls(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# 如果启用中间注意力,创建注意力层
if mid_attn:
self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
# 创建第二个残差块
self.mid.block_2 = make_resblock_cls(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# 上采样层
self.up = nn.ModuleList()
# 从高到低遍历每个分辨率级别
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList() # 残差块列表
attn = nn.ModuleList() # 注意力层列表
# 计算当前块的输出通道数
block_out = ch * ch_mult[i_level]
# 创建每个残差块
for i_block in range(self.num_res_blocks + 1):
block.append(
make_resblock_cls(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
# 更新块输入通道
block_in = block_out
# 如果当前分辨率在注意力分辨率中,添加注意力层
if curr_res in attn_resolutions:
attn.append(make_attn_cls(block_in, attn_type=attn_type))
up = nn.Module() # 上采样模块
up.block = block # 添加残差块
up.attn = attn # 添加注意力层
# 如果不是最低分辨率,添加上采样层
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
# 更新当前分辨率
curr_res = curr_res * 2
# 将上采样模块插入列表的开头
self.up.insert(0, up) # prepend to get consistent order
# 结束层
# 创建归一化层
self.norm_out = Normalize(block_in)
# 创建输出卷积层
self.conv_out = make_conv_cls(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
# 定义一个私有方法,用于返回注意力机制的构造函数
def _make_attn(self) -> Callable:
return make_attn
# 定义一个私有方法,用于返回残差块的构造函数
def _make_resblock(self) -> Callable:
return ResnetBlock
# 定义一个私有方法,用于返回二维卷积层的构造函数
def _make_conv(self) -> Callable:
return torch.nn.Conv2d
# 获取最后一层的权重
def get_last_layer(self, **kwargs):
return self.conv_out.weight
# 前向传播方法,接收输入 z 和可选参数
def forward(self, z, **kwargs):
# 确保输入 z 的形状与预期相同(被注释掉的检查)
# assert z.shape[1:] == self.z_shape[1:]
# 记录输入 z 的形状
self.last_z_shape = z.shape
# 初始化时间步嵌入
temb = None
# 将输入 z 传入卷积层
h = self.conv_in(z)
# 中间处理
h = self.mid.block_1(h, temb, **kwargs) # 通过第一块中间块处理
if self.mid_attn: # 如果启用了中间注意力
h = self.mid.attn_1(h, **kwargs) # 应用中间注意力层
h = self.mid.block_2(h, temb, **kwargs) # 通过第二块中间块处理
# 上采样过程
for i_level in reversed(range(self.num_resolutions)): # 从最高分辨率到最低分辨率
for i_block in range(self.num_res_blocks + 1): # 遍历每个残差块
h = self.up[i_level].block[i_block](h, temb, **kwargs) # 通过上采样块处理
if len(self.up[i_level].attn) > 0: # 如果存在注意力层
h = self.up[i_level].attn[i_block](h, **kwargs) # 应用注意力层
if i_level != 0: # 如果不是最低分辨率
h = self.up[i_level].upsample(h) # 执行上采样
# 结束处理
if self.give_pre_end: # 如果启用了预处理结束返回
return h
h = self.norm_out(h) # 对输出进行归一化
h = nonlinearity(h) # 应用非线性激活函数
h = self.conv_out(h, **kwargs) # 通过最终卷积层处理
if self.tanh_out: # 如果启用了 Tanh 输出
h = torch.tanh(h) # 应用 Tanh 激活函数
return h # 返回最终输出
.\cogview3-finetune\sat\sgm\modules\diffusionmodules\openaimodel.py
# 导入操作系统模块,用于处理文件和目录操作
import os
# 导入数学模块,提供数学函数和常量
import math
# 从 abc 模块导入抽象方法装饰器,用于定义抽象基类
from abc import abstractmethod
# 从 functools 模块导入 partial 函数,用于偏函数应用
from functools import partial
# 从 typing 模块导入类型注解,用于类型提示
from typing import Iterable, List, Optional, Tuple, Union
# 导入 numpy 库,通常用于数值计算
import numpy as np
# 导入 torch 库,通常用于深度学习
import torch as th
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 导入 PyTorch 的功能模块,提供激活函数等
import torch.nn.functional as F
# 从 einops 导入 rearrange 函数,用于重排张量
from einops import rearrange
# 导入自定义模块中的 SpatialTransformer 类
from ...modules.attention import SpatialTransformer
# 导入自定义模块中的实用函数
from ...modules.diffusionmodules.util import (
avg_pool_nd, # 平均池化函数
checkpoint, # 检查点函数
conv_nd, # 卷积函数
linear, # 线性变换函数
normalization, # 归一化函数
timestep_embedding, # 时间步嵌入函数
zero_module, # 零模块函数
)
# 导入自定义模块中的实用函数
from ...util import default, exists
# 定义一个空的占位函数,用于将模块转换为半精度浮点数
# dummy replace
def convert_module_to_f16(x):
pass
# 定义一个空的占位函数,用于将模块转换为单精度浮点数
def convert_module_to_f32(x):
pass
# 定义一个用于注意力池化的类,继承自 nn.Module
## go
class AttentionPool2d(nn.Module):
"""
从 CLIP 中改编: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
# 初始化方法,设置各类参数
def __init__(
self,
spacial_dim: int, # 空间维度
embed_dim: int, # 嵌入维度
num_heads_channels: int, # 头通道数量
output_dim: int = None, # 输出维度(可选)
):
# 调用父类初始化方法
super().__init__()
# 定义位置嵌入参数,初始化为正态分布
self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
)
# 定义查询、键、值的卷积投影
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
# 定义输出的卷积投影
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
# 计算头的数量
self.num_heads = embed_dim // num_heads_channels
# 初始化注意力机制
self.attention = QKVAttention(self.num_heads)
# 前向传播方法
def forward(self, x):
# 获取输入的批次大小和通道数
b, c, *_spatial = x.shape
# 将输入重塑为 (批次, 通道, 高*宽) 的形状
x = x.reshape(b, c, -1) # NC(HW)
# 在最后一维上添加均值作为额外的特征
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
# 将位置嵌入加到输入上
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
# 对输入进行查询、键、值投影
x = self.qkv_proj(x)
# 应用注意力机制
x = self.attention(x)
# 对结果进行输出投影
x = self.c_proj(x)
# 返回第一个通道的结果
return x[:, :, 0]
# 定义一个时间步模块的基类,继承自 nn.Module
class TimestepBlock(nn.Module):
"""
任何模块的 forward() 方法接受时间步嵌入作为第二个参数。
"""
# 定义抽象的前向传播方法
@abstractmethod
def forward(self, x, emb):
"""
将模块应用于 `x`,并给定 `emb` 时间步嵌入。
"""
# 定义一个时间步嵌入的顺序模块,继承自 nn.Sequential 和 TimestepBlock
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
一个顺序模块,将时间步嵌入作为额外输入传递给支持的子模块。
"""
# 重写前向传播方法
def forward(
self,
x: th.Tensor, # 输入张量
emb: th.Tensor, # 时间步嵌入张量
context: Optional[th.Tensor] = None, # 上下文张量(可选)
):
# 遍历所有子模块
for layer in self:
module = layer
# 如果子模块是 TimestepBlock,则使用时间步嵌入进行计算
if isinstance(module, TimestepBlock):
x = layer(x, emb)
# 如果子模块是 SpatialTransformer,则使用上下文进行计算
elif isinstance(module, SpatialTransformer):
x = layer(x, context)
# 否则,仅使用输入进行计算
else:
x = layer(x)
# 返回最终的输出
return x
# 定义一个上采样模块,继承自 nn.Module
class Upsample(nn.Module):
"""
一个可选卷积的上采样层。
:param channels: 输入和输出的通道数。
:param use_conv: 布尔值,确定是否应用卷积。
:param dims: 确定信号是 1D、2D 还是 3D。如果是 3D,则在内两个维度上进行上采样。
"""
# 初始化方法,设置类的基本属性
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False
):
# 调用父类初始化方法
super().__init__()
# 保存输入的通道数
self.channels = channels
# 如果没有指定输出通道数,则默认与输入通道数相同
self.out_channels = out_channels or channels
# 保存是否使用卷积的标志
self.use_conv = use_conv
# 保存维度信息
self.dims = dims
# 保存是否进行第三层上采样的标志
self.third_up = third_up
# 如果使用卷积,初始化卷积层
if use_conv:
self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=padding
)
# 前向传播方法,定义输入如何通过网络进行处理
def forward(self, x):
# 确保输入的通道数与初始化时指定的通道数一致
assert x.shape[1] == self.channels
# 如果输入为三维数据
if self.dims == 3:
# 根据是否需要第三层上采样确定时间因子
t_factor = 1 if not self.third_up else 2
# 对输入进行上采样
x = F.interpolate(
x,
(t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
mode="nearest",
)
else:
# 对输入进行上采样,比例因子为2
x = F.interpolate(x, scale_factor=2, mode="nearest")
# 如果使用卷积,则将输入通过卷积层处理
if self.use_conv:
x = self.conv(x)
# 返回处理后的输出
return x
# 定义一个转置上采样的类,继承自 nn.Module
class TransposedUpsample(nn.Module):
"Learned 2x upsampling without padding" # 文档字符串,描述该类的功能
# 初始化方法,设置输入通道、输出通道和卷积核大小
def __init__(self, channels, out_channels=None, ks=5):
super().__init__() # 调用父类的初始化方法
self.channels = channels # 保存输入通道数量
self.out_channels = out_channels or channels # 如果没有指定输出通道,则与输入通道相同
# 定义一个转置卷积层,用于上采样
self.up = nn.ConvTranspose2d(
self.channels, self.out_channels, kernel_size=ks, stride=2
)
# 前向传播方法,执行上采样操作
def forward(self, x):
return self.up(x) # 返回上采样后的结果
# 定义一个下采样层的类,继承自 nn.Module
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
# 初始化方法,设置输入通道、是否使用卷积、维度等参数
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False
):
super().__init__() # 调用父类的初始化方法
self.channels = channels # 保存输入通道数量
self.out_channels = out_channels or channels # 如果没有指定输出通道,则与输入通道相同
self.use_conv = use_conv # 保存是否使用卷积的标志
self.dims = dims # 保存信号的维度
stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) # 确定步幅
if use_conv: # 如果使用卷积
# print(f"Building a Downsample layer with {dims} dims.") # 打印信息,表示正在构建下采样层
# print(
# f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
# f"kernel-size: 3, stride: {stride}, padding: {padding}"
# ) # 打印卷积层的设置参数
# if dims == 3:
# print(f" --> Downsampling third axis (time): {third_down}") # 打印是否在第三维进行下采样
# 定义卷积操作
self.op = conv_nd(
dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding,
)
else: # 如果不使用卷积
assert self.channels == self.out_channels # 确保输入通道与输出通道相同
# 定义平均池化操作
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
# 前向传播方法,执行下采样操作
def forward(self, x):
assert x.shape[1] == self.channels # 确保输入的通道数匹配
return self.op(x) # 返回下采样后的结果
# 定义一个残差块的类,继承自 TimestepBlock
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
# 初始化方法,用于创建类的实例
def __init__(
self,
channels, # 输入通道数
emb_channels, # 嵌入通道数
dropout, # 丢弃率
out_channels=None, # 输出通道数,默认为 None
use_conv=False, # 是否使用卷积
use_scale_shift_norm=False, # 是否使用缩放位移归一化
dims=2, # 数据维度,默认为 2
use_checkpoint=False, # 是否使用检查点
up=False, # 是否进行上采样
down=False, # 是否进行下采样
kernel_size=3, # 卷积核大小,默认为 3
exchange_temb_dims=False, # 是否交换时间嵌入维度
skip_t_emb=False, # 是否跳过时间嵌入
):
# 调用父类初始化方法
super().__init__()
# 设置输入通道数
self.channels = channels
# 设置嵌入通道数
self.emb_channels = emb_channels
# 设置丢弃率
self.dropout = dropout
# 设置输出通道数,如果未提供则默认与输入通道数相同
self.out_channels = out_channels or channels
# 设置是否使用卷积
self.use_conv = use_conv
# 设置是否使用检查点
self.use_checkpoint = use_checkpoint
# 设置是否使用缩放位移归一化
self.use_scale_shift_norm = use_scale_shift_norm
# 设置是否交换时间嵌入维度
self.exchange_temb_dims = exchange_temb_dims
# 如果卷积核大小是可迭代的,计算每个维度的填充大小
if isinstance(kernel_size, Iterable):
padding = [k // 2 for k in kernel_size]
else:
# 否则直接计算单个卷积核的填充大小
padding = kernel_size // 2
# 创建输入层的序列,包括归一化、激活函数和卷积操作
self.in_layers = nn.Sequential(
normalization(channels), # 归一化
nn.SiLU(), # SiLU 激活函数
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), # 卷积层
)
# 判断是否进行上采样或下采样
self.updown = up or down
# 如果进行上采样,初始化上采样层
if up:
self.h_upd = Upsample(channels, False, dims) # 上采样层
self.x_upd = Upsample(channels, False, dims) # 上采样层
# 如果进行下采样,初始化下采样层
elif down:
self.h_upd = Downsample(channels, False, dims) # 下采样层
self.x_upd = Downsample(channels, False, dims) # 下采样层
# 否则使用身份映射
else:
self.h_upd = self.x_upd = nn.Identity() # 身份映射层
# 设置是否跳过时间嵌入
self.skip_t_emb = skip_t_emb
# 根据是否使用缩放位移归一化计算嵌入输出通道数
self.emb_out_channels = (
2 * self.out_channels if use_scale_shift_norm else self.out_channels
)
# 如果跳过时间嵌入,输出警告并设置嵌入层为 None
if self.skip_t_emb:
print(f"Skipping timestep embedding in {self.__class__.__name__}") # 警告信息
assert not self.use_scale_shift_norm # 确保不使用缩放位移归一化
self.emb_layers = None # 嵌入层设置为 None
self.exchange_temb_dims = False # 不交换时间嵌入维度
# 否则创建嵌入层的序列
else:
self.emb_layers = nn.Sequential(
nn.SiLU(), # SiLU 激活函数
linear(
emb_channels, # 嵌入通道数
self.emb_out_channels, # 嵌入输出通道数
),
)
# 创建输出层的序列,包括归一化、激活函数、丢弃层和卷积层
self.out_layers = nn.Sequential(
normalization(self.out_channels), # 归一化
nn.SiLU(), # SiLU 激活函数
nn.Dropout(p=dropout), # 丢弃层
zero_module(
conv_nd(
dims, # 数据维度
self.out_channels, # 输出通道数
self.out_channels, # 输出通道数
kernel_size, # 卷积核大小
padding=padding, # 填充
)
), # 卷积层
)
# 根据输入和输出通道数设置跳过连接
if self.out_channels == channels:
self.skip_connection = nn.Identity() # 身份映射层
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, kernel_size, padding=padding # 卷积层
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) # 卷积层,卷积核大小为 1
# 定义前向传播函数,接受输入张量和时间步嵌入
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
# 调用检查点函数以保存中间计算结果,减少内存使用
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
# 定义实际的前向传播逻辑
def _forward(self, x, emb):
# 如果设置了 updown,则进行上采样和下采样
if self.updown:
# 分离输入层的最后一层和其他层
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# 通过其他输入层处理输入 x
h = in_rest(x)
# 更新隐藏状态
h = self.h_upd(h)
# 更新输入 x
x = self.x_upd(x)
# 通过卷积层处理隐藏状态
h = in_conv(h)
else:
# 直接通过输入层处理输入 x
h = self.in_layers(x)
# 如果跳过时间嵌入,则初始化嵌入输出为零张量
if self.skip_t_emb:
emb_out = th.zeros_like(h)
else:
# 通过嵌入层处理时间嵌入,确保数据类型与 h 一致
emb_out = self.emb_layers(emb).type(h.dtype)
# 扩展 emb_out 的形状以匹配 h 的形状
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
# 如果使用缩放和偏移规范化
if self.use_scale_shift_norm:
# 分离输出层中的规范化层和其他层
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# 将嵌入输出分割为缩放和偏移
scale, shift = th.chunk(emb_out, 2, dim=1)
# 对隐藏状态进行规范化并应用缩放和偏移
h = out_norm(h) * (1 + scale) + shift
# 通过剩余的输出层处理隐藏状态
h = out_rest(h)
else:
# 如果交换时间嵌入的维度
if self.exchange_temb_dims:
# 重新排列嵌入输出的维度
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
# 将嵌入输出与隐藏状态相加
h = h + emb_out
# 通过输出层处理隐藏状态
h = self.out_layers(h)
# 返回输入 x 与处理后的隐藏状态的跳跃连接
return self.skip_connection(x) + h
# 定义一个注意力模块,允许空间位置相互关注
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
# 初始化方法,定义模块的基本参数
def __init__(
self,
channels, # 输入通道数
num_heads=1, # 注意力头的数量,默认为1
num_head_channels=-1, # 每个头的通道数,默认为-1
use_checkpoint=False, # 是否使用检查点
use_new_attention_order=False, # 是否使用新的注意力顺序
):
# 调用父类初始化方法
super().__init__()
self.channels = channels # 保存输入通道数
# 判断 num_head_channels 是否为 -1
if num_head_channels == -1:
self.num_heads = num_heads # 如果为 -1,直接使用 num_heads
else:
# 断言通道数可以被 num_head_channels 整除
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels # 计算头的数量
self.use_checkpoint = use_checkpoint # 保存检查点标志
self.norm = normalization(channels) # 初始化归一化层
self.qkv = conv_nd(1, channels, channels * 3, 1) # 创建卷积层用于计算 q, k, v
# 根据是否使用新注意力顺序选择相应的注意力类
if use_new_attention_order:
# 在分割头之前分割 qkv
self.attention = QKVAttention(self.num_heads)
else:
# 在分割 qkv 之前分割头
self.attention = QKVAttentionLegacy(self.num_heads)
# 初始化输出投影层
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
# 前向传播方法
def forward(self, x, **kwargs):
# TODO 添加跨帧注意力并使用混合检查点
# 使用检查点机制来调用内部前向传播函数
return checkpoint(
self._forward, (x,), self.parameters(), True
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
# return pt_checkpoint(self._forward, x) # pytorch
# 内部前向传播方法
def _forward(self, x):
b, c, *spatial = x.shape # 解包输入张量的形状
x = x.reshape(b, c, -1) # 将输入张量重塑为 (batch_size, channels, spatial_dim)
qkv = self.qkv(self.norm(x)) # 计算 q, k, v
h = self.attention(qkv) # 应用注意力机制
h = self.proj_out(h) # 对注意力结果进行投影
return (x + h).reshape(b, c, *spatial) # 返回重塑后的结果
# 计算注意力操作的 FLOPS
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape # 解包输入张量的形状
num_spatial = int(np.prod(spatial)) # 计算空间维度的总数
# 进行两个矩阵乘法,具有相同数量的操作。
# 第一个计算权重矩阵,第二个计算值向量的组合。
matmul_ops = 2 * b * (num_spatial**2) * c # 计算矩阵乘法的操作数
model.total_ops += th.DoubleTensor([matmul_ops]) # 将操作数累加到模型的总操作数中
# 旧版 QKV 注意力模块
class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
# 初始化方法,设置注意力头的数量
def __init__(self, n_heads):
super().__init__() # 调用父类初始化方法
self.n_heads = n_heads # 保存注意力头的数量
# 定义前向传播方法,接收 QKV 张量
def forward(self, qkv):
"""
应用 QKV 注意力机制。
:param qkv: 一个形状为 [N x (H * 3 * C) x T] 的张量,包含 Q、K 和 V。
:return: 一个形状为 [N x (H * C) x T] 的张量,经过注意力处理后输出。
"""
# 获取输入张量的批量大小、宽度和长度
bs, width, length = qkv.shape
# 确保宽度可以被 (3 * n_heads) 整除,以分割 Q、K 和 V
assert width % (3 * self.n_heads) == 0
# 计算每个头的通道数
ch = width // (3 * self.n_heads)
# 将 qkv 张量重塑并分割成 Q、K 和 V 三个部分
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
# 计算缩放因子,用于稳定性
scale = 1 / math.sqrt(math.sqrt(ch))
# 使用爱因斯坦求和约定计算注意力权重,乘以缩放因子
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # 使用 f16 比后续除法更稳定
# 对权重进行 softmax 归一化,并保持原始数据类型
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
# 根据权重和 V 计算输出张量
a = th.einsum("bts,bcs->bct", weight, v)
# 将输出张量重塑为原始批量大小和通道数
return a.reshape(bs, -1, length)
# 定义静态方法以计算模型的浮点运算数
@staticmethod
def count_flops(model, _x, y):
# 调用辅助函数计算注意力层的浮点运算数
return count_flops_attn(model, _x, y)
# 定义一个名为 QKVAttention 的类,继承自 nn.Module
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
# 初始化方法,接收注意力头的数量
def __init__(self, n_heads):
super().__init__() # 调用父类的初始化方法
self.n_heads = n_heads # 保存注意力头的数量
# 前向传播方法,接收 qkv 张量并执行注意力计算
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape # 解包 qkv 张量的维度
assert width % (3 * self.n_heads) == 0 # 确保宽度能够被注意力头数量整除
ch = width // (3 * self.n_heads) # 计算每个头的通道数
q, k, v = qkv.chunk(3, dim=1) # 将 qkv 张量分成 Q, K, V 三部分
scale = 1 / math.sqrt(math.sqrt(ch)) # 计算缩放因子
weight = th.einsum(
"bct,bcs->bts", # 定义爱因斯坦求和约定,计算权重
(q * scale).view(bs * self.n_heads, ch, length), # 缩放后的 Q 重塑形状
(k * scale).view(bs * self.n_heads, ch, length), # 缩放后的 K 重塑形状
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) # 计算权重的 softmax,确保其和为 1
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) # 计算最终的注意力输出
return a.reshape(bs, -1, length) # 将输出重塑回原始批量形状
@staticmethod
# 计算 FLOPs 的静态方法
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y) # 调用函数计算注意力层的 FLOPs
# 定义一个名为 Timestep 的类,继承自 nn.Module
class Timestep(nn.Module):
def __init__(self, dim):
super().__init__() # 调用父类的初始化方法
self.dim = dim # 保存时间步的维度
# 前向传播方法,接收时间步张量
def forward(self, t):
return timestep_embedding(t, self.dim) # 调用时间步嵌入函数
# 定义一个字典,将字符串类型映射到对应的 PyTorch 数据类型
str_to_dtype = {
"fp32": th.float32, # fp32 对应 float32
"fp16": th.float16, # fp16 对应 float16
"bf16": th.bfloat16 # bf16 对应 bfloat16
}
# 定义一个名为 UNetModel 的类,继承自 nn.Module
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
"""
# 参数 resblock_updown:是否在上采样/下采样过程中使用残差块
# 参数 use_new_attention_order:是否使用不同的注意力模式以提高效率
"""
# 初始化方法
def __init__(
# 输入通道数
self,
in_channels,
# 模型通道数
model_channels,
# 输出通道数
out_channels,
# 残差块的数量
num_res_blocks,
# 注意力分辨率
attention_resolutions,
# dropout 比例,默认为 0
dropout=0,
# 通道的倍增因子,默认值为 (1, 2, 4, 8)
channel_mult=(1, 2, 4, 8),
# 是否使用卷积重采样,默认为 True
conv_resample=True,
# 数据维度,默认为 2
dims=2,
# 类别数,默认为 None
num_classes=None,
# 是否使用检查点,默认为 False
use_checkpoint=False,
# 是否使用 fp16 精度,默认为 False
use_fp16=False,
# 注意力头数,默认为 -1
num_heads=-1,
# 每个头的通道数,默认为 -1
num_head_channels=-1,
# 上采样时的头数,默认为 -1
num_heads_upsample=-1,
# 是否使用尺度偏移归一化,默认为 False
use_scale_shift_norm=False,
# 是否使用残差块进行上采样/下采样,默认为 False
resblock_updown=False,
# 是否使用新的注意力顺序,默认为 False
use_new_attention_order=False,
# 是否使用空间变换器,支持自定义变换器
use_spatial_transformer=False, # custom transformer support
# 变换器的深度,默认为 1
transformer_depth=1, # custom transformer support
# 上下文维度,默认为 None
context_dim=None, # custom transformer support
# 嵌入数,默认为 None
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
# 是否使用传统模式,默认为 True
legacy=True,
# 是否禁用自注意力,默认为 None
disable_self_attentions=None,
# 注意力块的数量,默认为 None
num_attention_blocks=None,
# 是否禁用中间自注意力,默认为 False
disable_middle_self_attn=False,
# 是否在变换器中使用线性输入,默认为 False
use_linear_in_transformer=False,
# 空间变换器的注意力类型,默认为 "softmax"
spatial_transformer_attn_type="softmax",
# 输入通道数,默认为 None
adm_in_channels=None,
# 是否使用 Fairscale 检查点,默认为 False
use_fairscale_checkpoint=False,
# 是否将计算卸载到 CPU,默认为 False
offload_to_cpu=False,
# 中间变换器的深度,默认为 None
transformer_depth_middle=None,
# 配置条件嵌入维度,默认为 None
cfg_cond_embed_dim=None,
# 数据类型,默认为 "fp32"
dtype="fp32",
# 将模型的主体转换为 float16
def convert_to_fp16(self):
"""
将模型的主体转换为 float16。
"""
# 对输入块应用转换模块,将其转换为 float16
self.input_blocks.apply(convert_module_to_f16)
# 对中间块应用转换模块,将其转换为 float16
self.middle_block.apply(convert_module_to_f16)
# 对输出块应用转换模块,将其转换为 float16
self.output_blocks.apply(convert_module_to_f16)
# 将模型的主体转换为 float32
def convert_to_fp32(self):
"""
将模型的主体转换为 float32。
"""
# 对输入块应用转换模块,将其转换为 float32
self.input_blocks.apply(convert_module_to_f32)
# 对中间块应用转换模块,将其转换为 float32
self.middle_block.apply(convert_module_to_f32)
# 对输出块应用转换模块,将其转换为 float32
self.output_blocks.apply(convert_module_to_f32)
# 定义前向传播函数,接收输入数据和其他参数
def forward(self, x, timesteps=None, context=None, y=None, scale_emb=None, **kwargs):
"""
应用模型于输入批次。
:param x: 输入张量,形状为 [N x C x ...]。
:param timesteps: 一维时间步批次。
:param context: 通过 crossattn 插入的条件信息。
:param y: 标签张量,形状为 [N],如果是类条件。
:return: 输出张量,形状为 [N x C x ...]。
"""
# 如果输入数据类型不匹配,则转换为模型所需的数据类型
if x.dtype != self.dtype:
x = x.to(self.dtype)
# 确保 y 的存在性与类数设置一致
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
# 初始化存储中间结果的列表
hs = []
# 生成时间步嵌入
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
# 如果提供了缩放嵌入,则进行相应处理
if scale_emb is not None:
assert hasattr(self, "w_proj"), "w_proj not found in the model"
t_emb = t_emb + self.w_proj(scale_emb.to(self.dtype))
# 通过时间嵌入生成最终嵌入
emb = self.time_embed(t_emb)
# 如果模型是类条件,则将标签嵌入加入到最终嵌入中
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
# 将输入数据赋值给 h
# h = x.type(self.dtype)
h = x
# 通过输入模块处理 h,并保存中间结果
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
# 通过中间模块进一步处理 h
h = self.middle_block(h, emb, context)
# 通过输出模块处理 h,并逐层合并中间结果
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
# 将 h 转换回原输入数据类型
h = h.type(x.dtype)
# 检查是否支持预测码本 ID
if self.predict_codebook_ids:
assert False, "not supported anymore. what the f*** are you doing?"
else:
# 返回最终输出结果
return self.out(h)
.\cogview3-finetune\sat\sgm\modules\diffusionmodules\sampling.py
# 部分代码移植自 https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""
# 从 typing 模块导入字典和联合类型
from typing import Dict, Union
# 导入 PyTorch 库
import torch
# 从 omegaconf 模块导入配置相关的类
from omegaconf import ListConfig, OmegaConf
# 导入 tqdm 库用于显示进度条
from tqdm import tqdm
# 从相对路径模块导入采样相关的工具函数
from ...modules.diffusionmodules.sampling_utils import (
get_ancestral_step, # 获取祖先步骤
linear_multistep_coeff, # 线性多步骤系数
to_d, # 转换为 d
to_neg_log_sigma, # 转换为负对数sigma
to_sigma, # 转换为 sigma
)
# 从相对路径模块导入离散化工具
from ...modules.diffusionmodules.discretizer import generate_roughly_equally_spaced_steps
# 从相对路径模块导入工具函数
from ...util import append_dims, default, instantiate_from_config
# 定义默认引导器配置
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
# 定义用于生成引导嵌入的函数
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
"""
参考文献: https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
timesteps (`torch.Tensor`):
在这些时间步生成嵌入向量
embedding_dim (`int`, *可选*, 默认为 512):
生成的嵌入的维度
dtype:
生成嵌入的数据类型
Returns:
`torch.FloatTensor`: 形状为 `(len(timesteps), embedding_dim)` 的嵌入向量
"""
# 确保输入张量是一个一维张量
assert len(w.shape) == 1
# 将输入乘以 1000.0
w = w * 1000.0
# 计算嵌入维度的一半
half_dim = embedding_dim // 2
# 计算基础嵌入的系数
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
# 生成嵌入基础,转换为指数形式并调整为目标设备和数据类型
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb).to(w.device).to(w.dtype)
# 生成最终的嵌入向量
emb = w.to(dtype)[:, None] * emb[None, :]
# 将正弦和余弦值连接在一起
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
# 如果嵌入维度为奇数,进行零填充
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
# 确保生成的嵌入形状与预期一致
assert emb.shape == (w.shape[0], embedding_dim)
# 返回生成的嵌入向量
return emb
# 定义基础扩散采样器类
class BaseDiffusionSampler:
# 初始化采样器
def __init__(
self,
discretization_config: Union[Dict, ListConfig, OmegaConf], # 离散化配置
num_steps: Union[int, None] = None, # 采样步数,默认为 None
guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, # 引导器配置,默认为 None
cfg_cond_scale: Union[int, None] = None, # 条件缩放参数,默认为 None
cfg_cond_embed_dim: Union[int, None] = 256, # 条件嵌入维度,默认为 256
verbose: bool = False, # 是否显示详细信息
device: str = "cuda", # 设备类型,默认为 CUDA
):
# 设置采样步数
self.num_steps = num_steps
# 实例化离散化配置
self.discretization = instantiate_from_config(discretization_config)
# 实例化引导器配置
self.guider = instantiate_from_config(
default(
guider_config,
DEFAULT_GUIDER,
)
)
# 设置条件参数
self.cfg_cond_scale = cfg_cond_scale
self.cfg_cond_embed_dim = cfg_cond_embed_dim
# 设置详细模式和设备
self.verbose = verbose
self.device = device
# 准备采样循环的函数
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
# 生成 sigma 值
sigmas = self.discretization(
self.num_steps if num_steps is None else num_steps, device=self.device
)
# 默认使用条件
uc = default(uc, cond)
# 根据 sigma 计算 x 的调整
x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
# 获取 sigma 的数量
num_sigmas = len(sigmas)
# 创建新的一维张量 s_in,初始值为 1
s_in = x.new_ones([x.shape[0]]).float()
# 返回调整后的 x 和其他参数
return x, s_in, sigmas, num_sigmas, cond, uc
# 定义去噪函数,接受输入x、去噪器denoiser、噪声水平sigma、条件cond和无条件uc
def denoise(self, x, denoiser, sigma, cond, uc):
# 检查条件缩放系数是否不为None
if self.cfg_cond_scale is not None:
# 获取输入批次的大小
batch_size = x.shape[0]
# 创建与批次大小相同的全1张量,并乘以条件缩放系数,生成缩放嵌入
scale_emb = guidance_scale_embedding(torch.ones(batch_size, device=x.device) * self.cfg_cond_scale, embedding_dim=self.cfg_cond_embed_dim, dtype=x.dtype)
# 使用去噪器处理输入,传入缩放嵌入
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), scale_emb=scale_emb)
else:
# 若无条件缩放系数,直接使用去噪器处理输入
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
# 对去噪后的结果进行进一步引导处理
denoised = self.guider(denoised, sigma)
# 返回最终去噪结果
return denoised
# 定义生成sigma的函数,接受sigma数量num_sigmas
def get_sigma_gen(self, num_sigmas):
# 创建一个范围生成器,从0到num_sigmas-1
sigma_generator = range(num_sigmas - 1)
# 如果启用了详细输出
if self.verbose:
# 打印分隔线和采样设置信息
print("#" * 30, " Sampling setting ", "#" * 30)
print(f"Sampler: {self.__class__.__name__}")
print(f"Discretization: {self.discretization.__class__.__name__}")
print(f"Guider: {self.guider.__class__.__name__}")
# 使用tqdm包装生成器以显示进度条
sigma_generator = tqdm(
sigma_generator,
total=num_sigmas,
desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
)
# 返回sigma生成器
return sigma_generator
# 定义一个单步扩散采样器类,继承自基本扩散采样器
class SingleStepDiffusionSampler(BaseDiffusionSampler):
# 定义采样步骤方法,未实现
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
# 抛出未实现错误,表明该方法需在子类中实现
raise NotImplementedError
# 定义欧拉步骤方法,用于计算下一个状态
def euler_step(self, x, d, dt):
# 返回更新后的状态,基于当前状态、导数和时间增量
return x + dt * d
# 定义 EDM 采样器类,继承自单步扩散采样器
class EDMSampler(SingleStepDiffusionSampler):
# 初始化 EDM 采样器的参数
def __init__(
self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
):
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 设置采样器的参数
self.s_churn = s_churn # 变化率
self.s_tmin = s_tmin # 最小时间
self.s_tmax = s_tmax # 最大时间
self.s_noise = s_noise # 噪声强度
# 定义采样步骤方法
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
# 计算调整后的 sigma 值
sigma_hat = sigma * (gamma + 1.0)
# 如果 gamma 大于 0,加入噪声
if gamma > 0:
# 生成与 x 形状相同的随机噪声
eps = torch.randn_like(x) * self.s_noise
# 更新 x 的值,加入噪声
x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
# 去噪,得到去噪后的结果
denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
# 计算导数
d = to_d(x, sigma_hat, denoised)
# 计算时间增量
dt = append_dims(next_sigma - sigma_hat, x.ndim)
# 执行欧拉步骤,更新 x
euler_step = self.euler_step(x, d, dt)
# 进行可能的修正步骤,得到最终的 x
x = self.possible_correction_step(
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
)
# 返回更新后的 x
return x
# 定义调用方法
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
# 准备采样循环所需的参数
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
# 遍历 sigma 值
for i in self.get_sigma_gen(num_sigmas):
# 计算 gamma 值
gamma = (
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
if self.s_tmin <= sigmas[i] <= self.s_tmax
else 0.0
)
# 执行采样步骤,更新 x
x = self.sampler_step(
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc,
gamma,
)
# 返回最终的 x
return x
# 定义 DDIM 采样器类,继承自单步扩散采样器
class DDIMSampler(SingleStepDiffusionSampler):
# 初始化 DDIM 采样器的参数
def __init__(
self, s_noise=0.1, *args, **kwargs
):
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 设置噪声强度
self.s_noise = s_noise
# 定义采样步骤方法
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
# 去噪,得到去噪后的结果
denoised = self.denoise(x, denoiser, sigma, cond, uc)
# 计算导数
d = to_d(x, sigma, denoised)
# 计算时间增量
dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim)
# 计算欧拉步骤,加入噪声
euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)
# 进行可能的修正步骤,得到最终的 x
x = self.possible_correction_step(
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
)
# 返回更新后的 x
return x
# 定义一个可调用的类方法,接收去噪器、输入数据、条件及其他参数
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
# 准备采样循环,返回处理后的数据和相关参数
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
# 遍历生成的 sigma 值
for i in self.get_sigma_gen(num_sigmas):
# 执行采样步骤,更新输入数据 x
x = self.sampler_step(
s_in * sigmas[i], # 当前 sigma 乘以输入信号
s_in * sigmas[i + 1],# 下一个 sigma 乘以输入信号
denoiser, # 传递去噪器
x, # 当前数据
cond, # 条件信息
uc, # 可选的额外条件
self.s_noise, # 传递噪声信息
)
# 返回最终处理后的数据
return x
# 定义一个继承自 SingleStepDiffusionSampler 的类 AncestralSampler
class AncestralSampler(SingleStepDiffusionSampler):
# 初始化方法,设定默认参数 eta 和 s_noise
def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 设置 eta 属性
self.eta = eta
# 设置 s_noise 属性
self.s_noise = s_noise
# 定义噪声采样器,生成与输入形状相同的随机噪声
self.noise_sampler = lambda x: torch.randn_like(x)
# 定义 ancestral_euler_step 方法,用于执行欧拉步长
def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
# 计算偏导数 d
d = to_d(x, sigma, denoised)
# 将 sigma_down 和 sigma 的差值扩展到 x 的维度
dt = append_dims(sigma_down - sigma, x.ndim)
# 返回欧拉步长的结果
return self.euler_step(x, d, dt)
# 定义 ancestral_step 方法,执行采样步骤
def ancestral_step(self, x, sigma, next_sigma, sigma_up):
# 根据条件选择更新 x 的值
x = torch.where(
append_dims(next_sigma, x.ndim) > 0.0, # 检查 next_sigma 是否大于 0
x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), # 更新 x 的值
x, # 保持原值
)
# 返回更新后的 x
return x
# 定义调用方法,使得类可以被调用
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
# 准备采样循环,获取必要的输入
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
# 遍历 sigma 生成器,进行采样步骤
for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step(
s_in * sigmas[i], # 当前 sigma 值
s_in * sigmas[i + 1], # 下一个 sigma 值
denoiser, # 去噪器
x, # 当前 x 值
cond, # 条件
uc, # 额外条件
)
# 返回最终的 x 值
return x
# 定义一个继承自 BaseDiffusionSampler 的类 LinearMultistepSampler
class LinearMultistepSampler(BaseDiffusionSampler):
# 初始化方法,设定默认的 order 参数
def __init__(
self,
order=4,
*args,
**kwargs,
):
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 设置 order 属性
self.order = order
# 定义调用方法,使得类可以被调用
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
# 准备采样循环,获取必要的输入
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
# 初始化一个列表 ds 用于存储导数
ds = []
# 将 sigmas 从 GPU 移到 CPU,并转换为 numpy 数组
sigmas_cpu = sigmas.detach().cpu().numpy()
# 遍历 sigma 生成器
for i in self.get_sigma_gen(num_sigmas):
# 计算当前的 sigma
sigma = s_in * sigmas[i]
# 使用去噪器处理当前输入
denoised = denoiser(
*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
)
# 使用引导函数对去噪结果进行处理
denoised = self.guider(denoised, sigma)
# 计算导数 d
d = to_d(x, sigma, denoised)
# 将导数添加到列表 ds
ds.append(d)
# 如果 ds 的长度超过 order,移除最早的元素
if len(ds) > self.order:
ds.pop(0)
# 计算当前的阶数
cur_order = min(i + 1, self.order)
# 计算当前阶数的线性多步系数
coeffs = [
linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
for j in range(cur_order)
]
# 更新 x 值
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
# 返回最终的 x 值
return x
# 定义一个继承自 EDMSampler 的类 EulerEDMSampler
class EulerEDMSampler(EDMSampler):
# 定义可能的校正步骤方法
def possible_correction_step(
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
):
# 返回 euler_step,表示不进行额外的校正
return euler_step
# 定义一个继承自 EDMSampler 的类 HeunEDMSampler
class HeunEDMSampler(EDMSampler):
# 定义可能的校正步骤方法
def possible_correction_step(
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
):
):
# 如果下一个噪声水平的总和小于一个非常小的阈值
if torch.sum(next_sigma) < 1e-14:
# 如果所有噪声水平为0,保存网络评估的结果
return euler_step
else:
# 使用去噪器对当前步进行去噪处理
denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
# 将去噪后的结果转换为新数据
d_new = to_d(euler_step, next_sigma, denoised)
# 计算当前数据与新数据的平均值
d_prime = (d + d_new) / 2.0
# 如果噪声水平不为0,则应用修正
x = torch.where(
# 检查噪声水平是否大于0,决定是否修正
append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
)
# 返回修正后的结果
return x
# 定义一个 Euler 祖先采样器类,继承自 AncestralSampler
class EulerAncestralSampler(AncestralSampler):
# 定义采样步骤的方法,接受多个参数
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
# 获取下一个采样步的 sigma 值
sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
# 使用去噪器对当前输入进行去噪
denoised = self.denoise(x, denoiser, sigma, cond, uc)
# 使用 Euler 方法更新 x 的值
x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
# 应用祖先步骤更新 x 的值
x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
# 返回更新后的 x
return x
# 定义一个 DPMPP2S 祖先采样器类,继承自 AncestralSampler
class DPMPP2SAncestralSampler(AncestralSampler):
# 获取变量的方法,计算相关参数
def get_variables(self, sigma, sigma_down):
# 将 sigma 和 sigma_down 转换为负对数形式
t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
# 计算时间间隔 h
h = t_next - t
# 计算 s 值
s = t + 0.5 * h
# 返回计算的参数
return h, s, t, t_next
# 获取乘法因子的方法
def get_mult(self, h, s, t, t_next):
# 计算各个乘法因子
mult1 = to_sigma(s) / to_sigma(t)
mult2 = (-0.5 * h).expm1()
mult3 = to_sigma(t_next) / to_sigma(t)
mult4 = (-h).expm1()
# 返回所有乘法因子
return mult1, mult2, mult3, mult4
# 采样步骤的方法,执行多个计算步骤
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
# 获取下一个采样步的 sigma 值
sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
# 对输入进行去噪
denoised = self.denoise(x, denoiser, sigma, cond, uc)
# 使用 Euler 方法更新 x 的值
x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
# 检查 sigma_down 是否接近于零
if torch.sum(sigma_down) < 1e-14:
# 如果噪声级别为 0,则保存网络评估
x = x_euler
else:
# 获取变量 h, s, t, t_next
h, s, t, t_next = self.get_variables(sigma, sigma_down)
# 获取乘法因子,并调整维度
mult = [
append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
]
# 更新 x 的值
x2 = mult[0] * x - mult[1] * denoised
# 对 x2 进行去噪
denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
# 计算最终的 x 值
x_dpmpp2s = mult[2] * x - mult[3] * denoised2
# 如果噪声级别不为 0,则应用校正
x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
# 最终应用祖先步骤更新 x
x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
# 返回更新后的 x
return x
# 定义一个 DPMPP2M 采样器类,继承自 BaseDiffusionSampler
class DPMPP2MSampler(BaseDiffusionSampler):
# 获取变量的方法,计算相关参数
def get_variables(self, sigma, next_sigma, previous_sigma=None):
# 将 sigma 和 next_sigma 转换为负对数形式
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
# 计算时间间隔 h
h = t_next - t
# 如果提供了 previous_sigma,则进行额外计算
if previous_sigma is not None:
h_last = t - to_neg_log_sigma(previous_sigma)
r = h_last / h
return h, r, t, t_next
else:
# 如果没有提供,则返回 h 和 t 值
return h, None, t, t_next
# 获取乘法因子的方法
def get_mult(self, h, r, t, t_next, previous_sigma):
# 计算基础乘法因子
mult1 = to_sigma(t_next) / to_sigma(t)
mult2 = (-h).expm1()
# 如果提供了 previous_sigma,则计算额外的乘法因子
if previous_sigma is not None:
mult3 = 1 + 1 / (2 * r)
mult4 = 1 / (2 * r)
return mult1, mult2, mult3, mult4
else:
# 返回基本的乘法因子
return mult1, mult2
# 采样步骤的方法,执行多个计算步骤
def sampler_step(
self,
old_denoised,
previous_sigma,
sigma,
next_sigma,
denoiser,
x,
cond,
uc=None,
):
# 使用去噪器对输入数据进行去噪,返回去噪后的结果
denoised = self.denoise(x, denoiser, sigma, cond, uc)
# 获取当前和下一个噪声级别相关的变量
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
# 计算多重系数,扩展维度以匹配输入数据的维度
mult = [
append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, t, t_next, previous_sigma)
]
# 计算标准化后的输出
x_standard = mult[0] * x - mult[1] * denoised
# 检查之前的去噪结果是否存在或下一噪声级别是否接近零
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
# 如果噪声级别为零或处于第一步,返回标准化结果和去噪结果
return x_standard, denoised
else:
# 计算去噪后的数据修正值
denoised_d = mult[2] * denoised - mult[3] * old_denoised
# 计算高级输出
x_advanced = mult[0] * x - mult[1] * denoised_d
# 如果噪声级别不为零且不是第一步,应用修正
x = torch.where(
append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
)
# 返回最终输出和去噪结果
return x, denoised
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
# 准备采样循环,包括输入数据和条件信息的处理
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
old_denoised = None
# 遍历噪声级别生成器
for i in self.get_sigma_gen(num_sigmas):
# 在每个步骤中执行采样,更新去噪结果
x, old_denoised = self.sampler_step(
old_denoised,
None if i == 0 else s_in * sigmas[i - 1],
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc=uc,
)
# 返回最终的去噪结果
return x
# 定义一个将输入信号传递到去噪器的函数
def relay_to_d(x, sigma, denoised, image, step, total_step):
# 计算模糊度的变化量
blurring_d = (denoised - image) / total_step
# 根据模糊度和当前步长更新去噪图像
blurring_denoised = image + blurring_d * step
# 计算当前信号与去噪信号的差异,标准化为 sigma 的维度
d = (x - blurring_denoised) / append_dims(sigma, x.ndim)
# 返回计算得到的差异和模糊度变化
return d, blurring_d
# 定义一个线性中继EDM采样器,继承自EulerEDMSampler
class LinearRelayEDMSampler(EulerEDMSampler):
# 初始化函数,设定部分步数
def __init__(self, partial_num_steps=20, *args, **kwargs):
# 调用父类初始化方法
super().__init__(*args, **kwargs)
# 设置部分步数
self.partial_num_steps = partial_num_steps
# 定义采样调用方法
def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None):
# 克隆随机数以保持不变
randn_unit = randn.clone()
# 准备采样循环,获取相关参数
randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
randn, cond, uc, num_steps
)
# 初始化 x 为 None
x = None
# 遍历生成的 sigma 值
for i in self.get_sigma_gen(num_sigmas):
# 如果当前步数小于总步数减去部分步数,继续下一次循环
if i < self.num_steps - self.partial_num_steps:
continue
# 如果 x 还未初始化,则根据图像和随机数计算初始值
if x is None:
x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))
# 计算 gamma 值,控制采样过程中的噪声
gamma = (
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
if self.s_tmin <= sigmas[i] <= self.s_tmax
else 0.0
)
# 进行一次采样步骤
x = self.sampler_step(
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc,
gamma,
step=i - self.num_steps + self.partial_num_steps,
image=image,
index=self.num_steps - i,
)
# 返回最终的图像
return x
# 定义欧拉步骤的计算方法
def euler_step(self, x, d, dt, blurring_d):
# 更新 x 的值
return x + dt * d + blurring_d
# 定义采样步骤的计算方法
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0, step=None, image=None, index=None):
# 计算 sigma_hat,考虑 gamma 的影响
sigma_hat = sigma * (gamma + 1.0)
# 如果 gamma 大于 0,添加噪声
if gamma > 0:
eps = torch.randn_like(x) * self.s_noise
x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
# 使用去噪器去噪当前图像
denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
# 计算 beta_t,控制去噪过程
beta_t = next_sigma / sigma_hat * index / self.partial_num_steps - (index - 1) / self.partial_num_steps
# 更新 x 的值,结合去噪结果
x = x * append_dims(next_sigma / sigma_hat, x.ndim) + denoised * append_dims(1 - next_sigma / sigma_hat + beta_t, x.ndim) - image * append_dims(beta_t, x.ndim)
# 返回更新后的图像
return x
# 定义零信噪比DDIM采样器,继承自SingleStepDiffusionSampler
class ZeroSNRDDIMSampler(SingleStepDiffusionSampler):
# 初始化函数,设定是否使用条件生成
def __init__(
self,
do_cfg=True,
*args,
**kwargs,
):
# 调用父类初始化方法
super().__init__(*args, **kwargs)
# 设置条件生成标志
self.do_cfg = do_cfg
# 准备采样循环的参数
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
# 计算累积的 alpha 值,并获取对应的索引
alpha_cumprod_sqrt, indices = self.discretization(
self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True
)
# 如果 uc 为 None,则使用 cond
uc = default(uc, cond)
# 获取 sigma 的数量
num_sigmas = len(alpha_cumprod_sqrt)
# 初始化 s_in 为全 1 向量
s_in = x.new_ones([x.shape[0]])
# 返回准备好的参数
return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, indices
# 定义去噪函数,接受输入数据和其他参数
def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, i=None, idx=None):
# 初始化额外的模型输入字典
additional_model_inputs = {}
# 如果启用 CFG,准备包含索引的输入
if self.do_cfg:
additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * idx] * 2)
# 否则只准备单个索引输入
else:
additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * idx])
# 使用去噪器处理准备好的输入和额外参数,得到去噪后的结果
denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs)
# 使用引导器进一步处理去噪后的结果
denoised = self.guider(denoised, alpha_cumprod_sqrt, step=i, num_steps=self.num_steps)
# 返回去噪后的结果
return denoised
# 定义采样步骤函数,执行去噪和更新过程
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, i=None, idx=None, return_denoised=False):
# 调用去噪函数,并转换结果为浮点型
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, i, idx).to(torch.float32)
# 如果达到最后一步,返回去噪结果
if i == self.num_steps - 1:
if return_denoised:
return denoised, denoised
return denoised
# 计算当前步骤的 a_t 值
a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5
# 计算当前步骤的 b_t 值
b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
# 更新 x 的值,结合去噪后的结果
x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
# 根据需要返回去噪结果
if return_denoised:
return x, denoised
return x
# 定义可调用函数,用于处理采样和去噪流程
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
# 准备采样循环所需的输入数据
x, s_in, alpha_cumprod_sqrts, num_sigmas, cond, uc, indices = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
# 根据 sigma 生成器逐步执行采样
for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step(
s_in * alpha_cumprod_sqrts[i],
s_in * alpha_cumprod_sqrts[i + 1],
denoiser,
x,
cond,
uc,
i=i,
idx=indices[self.num_steps-i-1],
)
# 返回最终的结果
return x