文章目录
一、 Unet整体结构
Unet模型的输入包含三个部分:
- 大小为 [B, C, H, W] 的图像 image; 注意不用在意表示大小时所用的符号,应将它们视作接口,比如 UNetModel接收大小为 [B, Z, H/8, W/8] 的 noise latent image 作为输入时,这里的 C 就等于 Z, H 就等于H/8, W 就等于 W/8;
- 大小为 [B,] 的 timesteps;
- 大小为 [B, K, E] 的文本 embedding 表示 context, 其中 K 表示最大编码长度,E 表示 embedding大小。
二、Unet源码详细解读
1.参数初始化
1.1默认参数赋值
参数是根据模型的配置而定,路径位于configs/stable-diffusion/v2-inference-v.yaml,截图如下。
由于配置中没有提及的参数会用默认值,为了分析代码的便利,以下打印出默认配置下unet的参数,以便在分析代码时作为参考,至少可以知道代码会走哪些分支、不走哪些分支。
class UNetModel(nn.Module):
def __init__(
self = UNetModel()
image_size = 32 # unused
in_channels = 4
model_channels = 320
out_channels = 4
num_res_blocks = 2
attention_resolutions = [4, 2, 1]
dropout = 0
channel_mult = [1, 2, 4, 4]
conv_resample = True
dims = 2
num_classes = None
use_checkpoint = True
use_fp16 = False
use_bf16 = False
num_heads = -1
num_head_channels = 64
num_heads_upsample = -1
use_scale_shift_norm = False
resblock_updown = False
use_new_attention_order = False
use_spatial_transformer = True
transformer_depth = 1
context_dim = 1024
n_embed = None
legacy = False
disable_self_attentions = None
num_attention_blocks = None
disable_middle_self_attn = False
use_linear_in_transformer = True
adm_in_channels = None
):
1.2 非法参数值的判断
例如:如果use_spatial_transformer是否为真,并且当context_dim不为空时,才会继续执行。若context_dim为空,则会抛出断言错误,提示用户需指定交叉注意力机制的条件维度。
if use_spatial_transformer: #True context_dim=1024
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
1.3 变量赋值
变量的赋值和输入参数基本一致。
self.attention_resolutions = attention_resolutions # [4, 2, 1]
self.dropout = dropout # 0
self.channel_mult = channel_mult # channel_mult = [1, 2, 4, 4]
self.conv_resample = conv_resample # True
self.num_classes = num_classes # none
self.use_checkpoint = use_checkpoint # true
self.dtype = th.float16 if use_fp16 else th.float32
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.num_heads = num_heads # -1
self.num_head_channels = num_head_channels # 64
self.num_heads_upsample = num_heads_upsample # -1
self.predict_codebook_ids = n_embed is not None # false
2.生成timestep对应的embedding
2.1 时间编码的模块self.time_embed
图中红框是源码中emb的生成的地方,emb是Unet中Resblock的输入,首先timesteps的初始化形状为(2,),经过timestep_embedding后生成timestep对应的embedding,形状为(2,320),在经过self.time_embed后,变成形状为(2,320)的输入。
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
(Pdb) timesteps.shape
torch.Size([2])
(Pdb) t_emb.shape
torch.Size([2, 320])
接下来是网络结构
接下来是定义时间编码的模块self.time_embed,可见它是一个序列,包含[linear、nn.silu、linear]三个块,也就是MLP层、经过一个silu激活函数、再经过一个MLP线性层。看到第一个线性层的输入是model_channels(既320),输出为time_embed_dim,即为1280(time_embed_dim = 320*4 = 1280)。
time_embed_dim = model_channels * 4 # 1280
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
2.2 timestep_embedding
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
# 根据是否仅重复,选择不同的嵌入方式
if not repeat_only:
# 计算维度的一半,用于后续频率的计算
half = dim // 2
# 创建一个频率序列freqs,它是由0到half之间的等差数列经过指数运算得到的。这些频率值用于后续生成周期信号。
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
# 将timesteps转换为浮点型张量,并增加一个新的维度以便广播,然后与freqs相乘,生成每个时间步对应的不同频率下的参数。
args = timesteps[:, None].float() * freqs[None]
# 使用三角函数生成嵌入向量,通过拼接cos和sin实现
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# 如果dim是奇数,则在嵌入向量的末尾追加一个全零向量,其大小与嵌入的第一个元素相同,以保证最终嵌入向量的维度为dim。
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
# 如果仅重复,直接在维度d上重复时间步,生成嵌入向量
embedding = repeat(timesteps, 'b -> b d', d=dim)
# 返回生成的嵌入向量
return embedding
3.Unet的构建过程
3.1 input_block添加ResBlock和transformer等构成Unet
在unet的构建中,以input_block为例,程序会先构造一个名为layers的列表,给它添加ResBlock或SpatialTransformer模块,之后把这个列表通过append进self.input_blocks,在append之前先用TimestepEmbedSequential做包装,使得ResBlock或SpatialTransformer模块都能接受时间步emb或者文本条件context的参数。
input_blocks部分,即对应如下图红框内的模块,下图中给标出了A、B1、B2、B3、B4几个部分。
首先,self.input_blocks定义为一个列表(nn.ModuleList),第一项内容是又TimestepEmbedSequential包装的conv_nd,对应于上图的A部分。
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
TimestepEmbedSequential,只是个包装函数,可见,它只是把参数emb或者context传入到内部每个模块的最后一个参数中。由于nn.Conv2d不继承自TimestepBlock或SpatialTransformer,因此这里它会走else分支,而conv_nd也不是数组(相当于长度为1),那么经过TimestepEmbedSequential包装的conv_nd即相当于直接调用conv_nd。
接着会看到如下的两层循环,这两层循环对应着示意图中B1、B2、B3、B4几个模块处理,也是比较复杂的部分。
把channel_mult和self.num_res_blocks的默认值代入,得到如下的代码:
for level, mult in enumerate( [1, 2, 4, 4]):
for nr in range([2, 2, 2, 2][level]):
因此这里的循环的值nr即为:
第1次外圈: [0,1]
第2次外圈: [0,1]
第3次外圈: [0,1]
第4次外圈: [0,1]
那么外层循环(红框)总共也有4次,内层循环(蓝框)各两个,至于B4形状与B1、B2、B3不同,是因为最后一次外层循环会有特殊处理。
每一次内循环结构的模块现在看看每一次内循环的实现,我们一边看下面的图一边对应代码,每一次内循环对应的是蓝框里的内容,每一次外循环对应的是红框里的内容。循环开始,先是定义个layers的数组,然后添加ResBlock模块。
紧接着是做一些判断,如果条件符合就添加(layers.append)一个AttentionBlock或SpatialTransformer,添加的模块取决于use_spatial_transformer的值,前面分析到这个值为True,因此添加的是SpatialTransformer。
接着我们看上述代码中添加SpatialTransformer模块的条件,有两个,一个是“if ds in attention_resolutions”,另一个是“if not exists(num_attention_blocks) or …”。第一个条件中,attention_resolutions的值为 [4, 2, 1],而ds的值会在每次外循环中做“if level !=len(channel_mult)-1”的判断,如果满足条件就乘以2,见如下代码。
在每次内循环结束,且外循环不是最后一次循环时,ds会乘以2,那么各个循环中ds的值分别为:
第1次外圈: level=0 内圈nr->[0,1] ds->1
第2次外圈: level=1 内圈nr->[0,1] ds->2
第3次外圈: level=2 内圈nr->[0,1] ds->4
第4次外圈: level=3 内圈nr->[0,1] ds->8
前3次外循环“if ds in attention_resolutions”这个条件都成立。第二条件是:if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
在默认配置下num_attention_blocks的值为none,因此第二个条件也成立,即前3次外循环的每次内循环都会加上ResBlock和SpatialTransformer,而第4次外循环只会加上ResBlock.
3.2 ResBlock
ResBlock继承自TimestepBlock,那么在TimestepEmbedSequential的包装下,它会走到”if isinstance(layer, TimestepBlock)的分支“,对应的处理方法是x = layer(x,emb),即把时间步编码传入到ResBlock中。
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
分析完init构建,再来看看ResBlock的前向过程forward,代码如下。它的组织结构与示意图相同,将输入x经过in_layer得到h,也将输入emb经过emb_layers得到emb_out,将两者相加得到新的变量h,而后经过out_layer和skip connection输出。
3.3 SpatialTransformer模块
SpatialTransformer的示意图如下所示,图中可见它有latent in和context embeddings两个输入,latent in经过PaddelConv得到一个新变量,而后就是正常的注意力机制流程。其中Lattent in这一支是作为查询Q,Context Embeddings这一支是作为键和值KV。
接下来看看初始化构建的代码,如下所示,与图中对应,代码中定义了self.proj_in、transformer_blocks、proj_out一些模块,其中proj_in是一个卷积网络、而transformer_blocks是attation实现的核心,内层是BasicTransformerBlock,由于默认参数下depth的值为1,因此transformer_blocks中只会有一个BasicTransformerBlock。
再看看前向过程forward,可以发现对于输入的隐变量x,经过proj_in后会将它变形,成为符合attention模块的输入,而文本条件context直接传入到BasicTransformerBlock模块中。最后通过proj_out输出。