Stable_diffusion 2.1 源码详解——Unet部分


一、 Unet整体结构

Unet模型的输入包含三个部分

  • 大小为 [B, C, H, W] 的图像 image; 注意不用在意表示大小时所用的符号,应将它们视作接口,比如 UNetModel接收大小为 [B, Z, H/8, W/8] 的 noise latent image 作为输入时,这里的 C 就等于 Z, H 就等于H/8, W 就等于 W/8;
  • 大小为 [B,] 的 timesteps;
  • 大小为 [B, K, E] 的文本 embedding 表示 context, 其中 K 表示最大编码长度,E 表示 embedding大小。
    在这里插入图片描述

二、Unet源码详细解读

1.参数初始化

1.1默认参数赋值

  参数是根据模型的配置而定,路径位于configs/stable-diffusion/v2-inference-v.yaml,截图如下。
在这里插入图片描述  由于配置中没有提及的参数会用默认值,为了分析代码的便利,以下打印出默认配置下unet的参数,以便在分析代码时作为参考,至少可以知道代码会走哪些分支、不走哪些分支。

class UNetModel(nn.Module):
	def __init__(
		self = UNetModel()
		image_size = 32 # unused
		in_channels = 4
		model_channels = 320
		out_channels = 4
		num_res_blocks = 2
		attention_resolutions = [4, 2, 1]
		dropout = 0
		channel_mult = [1, 2, 4, 4]
		conv_resample = True
		dims = 2
		num_classes = None
		use_checkpoint = True
		use_fp16 = False
		use_bf16 = False
		num_heads = -1
		num_head_channels = 64
		num_heads_upsample = -1
		use_scale_shift_norm = False
		resblock_updown = False
		use_new_attention_order = False
		use_spatial_transformer = True
		transformer_depth = 1
		context_dim = 1024
		n_embed = None
		legacy = False
		disable_self_attentions = None
		num_attention_blocks = None
		disable_middle_self_attn = False
		use_linear_in_transformer = True
		adm_in_channels = None
		):

1.2 非法参数值的判断

例如:如果use_spatial_transformer是否为真,并且当context_dim不为空时,才会继续执行。若context_dim为空,则会抛出断言错误,提示用户需指定交叉注意力机制的条件维度。

if use_spatial_transformer: #True context_dim=1024
	assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
	assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
    from omegaconf.listconfig import ListConfig
    if type(context_dim) == ListConfig:
       context_dim = list(context_dim)

1.3 变量赋值

变量的赋值和输入参数基本一致。

self.attention_resolutions = attention_resolutions # [4, 2, 1]
self.dropout = dropout # 0
self.channel_mult = channel_mult # channel_mult = [1, 2, 4, 4]
self.conv_resample = conv_resample # True
self.num_classes =  num_classes # none
self.use_checkpoint = use_checkpoint # true
self.dtype = th.float16 if use_fp16 else th.float32
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.num_heads = num_heads # -1
self.num_head_channels = num_head_channels # 64
self.num_heads_upsample = num_heads_upsample # -1
self.predict_codebook_ids = n_embed is not None # false

2.生成timestep对应的embedding

2.1 时间编码的模块self.time_embed

在这里插入图片描述

  图中红框是源码中emb的生成的地方,emb是Unet中Resblock的输入,首先timesteps的初始化形状为(2,),经过timestep_embedding后生成timestep对应的embedding,形状为(2,320),在经过self.time_embed后,变成形状为(2,320)的输入。

t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)

(Pdb) timesteps.shape
torch.Size([2])
(Pdb) t_emb.shape
torch.Size([2, 320])

接下来是网络结构
在这里插入图片描述

  接下来是定义时间编码的模块self.time_embed,可见它是一个序列,包含[linear、nn.silu、linear]三个块,也就是MLP层、经过一个silu激活函数、再经过一个MLP线性层。看到第一个线性层的输入是model_channels(既320),输出为time_embed_dim,即为1280(time_embed_dim = 320*4 = 1280)。

time_embed_dim = model_channels * 4 # 1280
self.time_embed = nn.Sequential(
    linear(model_channels, time_embed_dim),
    nn.SiLU(),
    linear(time_embed_dim, time_embed_dim),
)

2.2 timestep_embedding

def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    # 根据是否仅重复,选择不同的嵌入方式
    if not repeat_only:
        # 计算维度的一半,用于后续频率的计算
        half = dim // 2
        # 创建一个频率序列freqs,它是由0到half之间的等差数列经过指数运算得到的。这些频率值用于后续生成周期信号。
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        # 将timesteps转换为浮点型张量,并增加一个新的维度以便广播,然后与freqs相乘,生成每个时间步对应的不同频率下的参数。
        args = timesteps[:, None].float() * freqs[None]
        # 使用三角函数生成嵌入向量,通过拼接cos和sin实现
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # 如果dim是奇数,则在嵌入向量的末尾追加一个全零向量,其大小与嵌入的第一个元素相同,以保证最终嵌入向量的维度为dim。
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        # 如果仅重复,直接在维度d上重复时间步,生成嵌入向量
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    # 返回生成的嵌入向量
    return embedding

3.Unet的构建过程

3.1 input_block添加ResBlock和transformer等构成Unet

  在unet的构建中,以input_block为例,程序会先构造一个名为layers的列表,给它添加ResBlock或SpatialTransformer模块,之后把这个列表通过append进self.input_blocks,在append之前先用TimestepEmbedSequential做包装,使得ResBlock或SpatialTransformer模块都能接受时间步emb或者文本条件context的参数。
  input_blocks部分,即对应如下图红框内的模块,下图中给标出了A、B1、B2、B3、B4几个部分。
在这里插入图片描述  首先,self.input_blocks定义为一个列表(nn.ModuleList),第一项内容是又TimestepEmbedSequential包装的conv_nd,对应于上图的A部分。

self.input_blocks = nn.ModuleList(
    [
        TimestepEmbedSequential(
            conv_nd(dims, in_channels, model_channels, 3, padding=1)
        )
    ]
)

  TimestepEmbedSequential,只是个包装函数,可见,它只是把参数emb或者context传入到内部每个模块的最后一个参数中。由于nn.Conv2d不继承自TimestepBlock或SpatialTransformer,因此这里它会走else分支,而conv_nd也不是数组(相当于长度为1),那么经过TimestepEmbedSequential包装的conv_nd即相当于直接调用conv_nd。
在这里插入图片描述  接着会看到如下的两层循环,这两层循环对应着示意图中B1、B2、B3、B4几个模块处理,也是比较复杂的部分。
在这里插入图片描述把channel_mult和self.num_res_blocks的默认值代入,得到如下的代码:

for level, mult in enumerate( [1, 2, 4, 4]):
     	for nr in range([2, 2, 2, 2][level]):

因此这里的循环的值nr即为:

1次外圈: [0,1]2次外圈: [0,1]3次外圈: [0,1]4次外圈: [0,1] 

那么外层循环(红框)总共也有4次,内层循环(蓝框)各两个,至于B4形状与B1、B2、B3不同,是因为最后一次外层循环会有特殊处理。
在这里插入图片描述每一次内循环结构的模块现在看看每一次内循环的实现,我们一边看下面的图一边对应代码,每一次内循环对应的是蓝框里的内容,每一次外循环对应的是红框里的内容。循环开始,先是定义个layers的数组,然后添加ResBlock模块。
在这里插入图片描述  紧接着是做一些判断,如果条件符合就添加(layers.append)一个AttentionBlock或SpatialTransformer,添加的模块取决于use_spatial_transformer的值,前面分析到这个值为True,因此添加的是SpatialTransformer。
在这里插入图片描述  接着我们看上述代码中添加SpatialTransformer模块的条件,有两个,一个是“if ds in attention_resolutions”,另一个是“if not exists(num_attention_blocks) or …”。第一个条件中,attention_resolutions的值为 [4, 2, 1],而ds的值会在每次外循环中做“if level !=len(channel_mult)-1”的判断,如果满足条件就乘以2,见如下代码。
在这里插入图片描述  在每次内循环结束,且外循环不是最后一次循环时,ds会乘以2,那么各个循环中ds的值分别为:

1次外圈: level=0 内圈nr->[0,1]    ds->12次外圈: level=1 内圈nr->[0,1]    ds->23次外圈: level=2 内圈nr->[0,1]    ds->44次外圈: level=3 内圈nr->[0,1]    ds->8   

  前3次外循环“if ds in attention_resolutions”这个条件都成立。第二条件是:if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
在默认配置下num_attention_blocks的值为none,因此第二个条件也成立,即前3次外循环的每次内循环都会加上ResBlock和SpatialTransformer,而第4次外循环只会加上ResBlock.

3.2 ResBlock

  ResBlock继承自TimestepBlock,那么在TimestepEmbedSequential的包装下,它会走到”if isinstance(layer, TimestepBlock)的分支“,对应的处理方法是x = layer(x,emb),即把时间步编码传入到ResBlock中。
在这里插入图片描述

self.in_layers = nn.Sequential(
    normalization(channels),
    nn.SiLU(),
    conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
    nn.SiLU(),
    linear(
        emb_channels,
        2 * self.out_channels if use_scale_shift_norm else self.out_channels,
    ),
)
self.out_layers = nn.Sequential(
    normalization(self.out_channels),
    nn.SiLU(),
    nn.Dropout(p=dropout),
    zero_module(
        conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
    ),
)

分析完init构建,再来看看ResBlock的前向过程forward,代码如下。它的组织结构与示意图相同,将输入x经过in_layer得到h,也将输入emb经过emb_layers得到emb_out,将两者相加得到新的变量h,而后经过out_layer和skip connection输出。
在这里插入图片描述

3.3 SpatialTransformer模块

SpatialTransformer的示意图如下所示,图中可见它有latent in和context embeddings两个输入,latent in经过PaddelConv得到一个新变量,而后就是正常的注意力机制流程。其中Lattent in这一支是作为查询Q,Context Embeddings这一支是作为键和值KV。
在这里插入图片描述
接下来看看初始化构建的代码,如下所示,与图中对应,代码中定义了self.proj_in、transformer_blocks、proj_out一些模块,其中proj_in是一个卷积网络、而transformer_blocks是attation实现的核心,内层是BasicTransformerBlock,由于默认参数下depth的值为1,因此transformer_blocks中只会有一个BasicTransformerBlock。
在这里插入图片描述再看看前向过程forward,可以发现对于输入的隐变量x,经过proj_in后会将它变形,成为符合attention模块的输入,而文本条件context直接传入到BasicTransformerBlock模块中。最后通过proj_out输出。
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值