教程来自Diffusion models代码解读:入门与实战
# 1. Network helpers
# 定义辅助函数
#函数exists接受参数x并检查是不是None, 使用is操作符来检查x是否和None是同一个对象, is操作符比较的是对象的身份(即它们是否指向内存中的同一个位置)
def exists(x):
return x is not None
#在 val 存在(即不是 None)时返回 val,否则返回 d 的值。但是,这里 d 的处理有些特殊:如果 d 是一个可调用的对象(比如函数或类实例的 __call__ 方法),则调用它并返回结果;如果 d 不是可调用的,则直接返回 d 的值。
def default(val, d):
if exists(val):
return val
return d() if isinstance(d) else d
# 定义残差,添加到特定函数的残差连接
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x
# 定义上采样
def Upsample(dim):
return nn.ConvTranspose2d(dim, dim, 4, stride=2, padding=1)
#定义下采样
def Downsample(dim):
return nn.ConvTranspose2d(dim, dim, 4, stride=2, padding=1)
# 2. Position embeddings
# 神经网络的参数跨时间共享采用sinusoidal position embeddings编码时间time。在批量处理图像时,使得神经网络知道在特定时间的步长操作。
class SinusoidalPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device #获取输入张量time所在的设备(CPU或GPU),以便在后续操作中确保所有张量都在相同的设备上。
half_dim = self.dim // 2 #嵌入的维度一半正弦一半余弦
embeddings = math.log(10000) / (half_dim - 1) #计算缩放因子,保证嵌入的值在合理范围内
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) #使用torch.arange生成一个从0到half_dim-1的整数序列,并将该序列乘以缩放因子的负值,然后exp取指数,得到一个指数递减的序列。
embeddings = time[:, None] * embeddings[:, None] #上一步的序列与time相乘得到基于位置的嵌入
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) #生成正弦和余弦的嵌入
# 为什么基于时间做位置嵌入之后还要生成正弦和余弦的嵌入?正弦和余弦可以表征不同位置间的依赖关系,捕捉绝对位置和相对位置,提高长序处理能力
return embeddings
运用U-net预测噪声更好的学习到数据的增广分布(数据增广:对原始数据集进行一系列变换丰富数据集,提高模型训练效果),缺点是训练复杂模型复杂度增加
# 3. ResNet/ConvNeXT block
# 构造U-Net的核心模块
class Block(nn.Module)
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
"""https://arxiv.org/abs/1512.03385 """
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
super().__init__()
self.mlp(
nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
if exists(time_emb_dim)
else None
)
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=grpups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None):
h = self.block1(x)
if exists(self, mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
h = rearrange(time_emb, "b c -> b c 1 1") + h
h = self.block2(h)
return h + self.res_conv(x)
#4. Attention Model 添加到卷积模块当中
class Attention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32, dropout=0.):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, kernel_size=1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, kernel_size=1, bias=False)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
)#qkv是包含了q、k和v信息的张量,map将lambda t应用于qkv中的q、k和v,lambda定义了对t的重新排列操作,b=batchsize, h=head number, c=channel number, x、y空间维度
q = q * self.scale
sim = einsum("b h d i, b h d j -> b h i j", q, k) #einsum用于计算多维数组之间的元素及乘法和求和。q 和 k 的形状分别是 (b, h, d, i) 和 (b, h, d, j),对于每一个 (b, h, d) 的组合,einsum 会对 q 的 i 维度和 k 的 j 维度执行点积(即元素级乘法后的求和)。结果是一个形状为 (b, h, i, j) 的张量。
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
out = einsum("b h i j, b h d j -> b h i d", attn, v)
out = arrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
return self.to_out(out)
#考虑到复杂度问题,这里建议用linear attention
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32, dropout=0.):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, kernel_size=1, bias=False)
self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
)
q = q * self.scale
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
return self.to_out(out)
#5. Group normalization 放在attention之前
class Prenorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(1, dim)
def forward(self, x):
x = self.norm(x)
return self.fn(x)
Conditional U-Net
(1)卷积计算噪声position embeddings
(2)下采样:ResNet block
∗
2
*2
∗2 + groupnorm + attention + residual connection + downsample
(3)中间应用ResNet block和attention
(4)上采样:ResNet block
∗
2
*2
∗2 + groupnorm + attention + residual connection + upsample
(5)ResNet block
#6. Conditional U-Net
class Unet(nn.Module):
def __init__(
self,
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
with_time_emb=True,
resnet_block_groups=8,
use_convnext=True,
convnext_mult=2,
):
super().__init__()
# determine dimensions
self.channels = channels
init_dim = default(init_dim, dim // 3 * 2)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
if use_convnext:
block_klass = partial(ConvNextBlock, mult=convnext_mult)
else:
block_klass = partial(ResnetBlock, groups=resnet_block_groups)
# time embeddings
if with_time_emb:
time_dim = dim * 4
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(dim),
nn.Linear(dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim),
)
else:
time_dim = None
self.time_mlp = None
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(
nn.ModuleList(
[
block_klass(dim_in, dim_out, time_emb_dim=time_dim),
block_klass(dim_out, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else nn.Identity(),
]
)
)
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(
nn.ModuleList(
[
block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Upsample(dim_in) if not is_last else nn.Identity(),
]
)
)
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
)
def forward(self, x, time):
x = self.init_conv(x)
t = self.time_mlp(time) if exists(self.time_mlp) else None
h = []
# downsample
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
# bottleneck
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
# upsample
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = block1(x, t)
x = block2(x, t)
x = attn(x)
x = upsample(x)
return self.final_conv(x)