Reading the UNet2DConditionModel code in Stable Diffusion

Overall structure of UNet2DConditionModel
[Figure: overall architecture of UNet2DConditionModel, image from https://zhuanlan.zhihu.com/p/635204519]

The code in the Stable Diffusion pipeline that calls the UNet:

noise_pred = self.unet(
    sample=latent_model_input,  # (2,4,64,64) the latent being denoised
    timestep=t,  # timestep t
    encoder_hidden_states=prompt_embeds,  # (2,77,768) embeddings generated from the prompt and negative prompt
    timestep_cond=timestep_cond,  # None by default
    cross_attention_kwargs=self.cross_attention_kwargs,  # None by default
    added_cond_kwargs=added_cond_kwargs,  # None by default
    return_dict=False,
)[0]

1.time

get_time_embed uses sinusoidal timestep embeddings.
time_embedding then maps the 320-dim embedding to 1280 dims with two linear layers and an activation.
If class_labels, added_cond_kwargs, or similar conditioning is provided, it is also converted to an embedding and added to the result.

t_emb = self.get_time_embed(sample=sample, timestep=timestep)  #(2,320)
emb = self.time_embedding(t_emb, timestep_cond)  #(2,1280)
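
As a minimal, self-contained sketch (not the diffusers implementation; sinusoidal_embedding and time_mlp are illustrative names), the sinusoidal embedding and the 320 -> 1280 projection can be reproduced like this:

import math
import torch
import torch.nn as nn

def sinusoidal_embedding(timesteps, dim=320):
    # pair each timestep with sin/cos at geometrically spaced frequencies
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :]            # (B, 160)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (B, 320)

# two linear layers with an activation in between, 320 -> 1280
time_mlp = nn.Sequential(nn.Linear(320, 1280), nn.SiLU(), nn.Linear(1280, 1280))

t = torch.tensor([10, 10])           # one timestep per batch element
t_emb = sinusoidal_embedding(t)      # (2, 320)
emb = time_mlp(t_emb)                # (2, 1280)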

2.pre-process

A convolution maps the input latent from (2,4,64,64) to (2,320,64,64).

sample = self.conv_in(sample)  # (2,320,64,64)
self.conv_in = nn.Conv2d(
    in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
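
A quick shape check, assuming the SD 1.5 defaults (in_channels=4, block_out_channels[0]=320, conv_in_kernel=3, hence conv_in_padding=1):

import torch
import torch.nn as nn

conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)
latent = torch.randn(2, 4, 64, 64)
print(conv_in(latent).shape)  # torch.Size([2, 320, 64, 64])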

3.down

The down path consists of three CrossAttnDownBlock2D blocks followed by one DownBlock2D. Its inputs are:

  • hidden_states: the latent

  • temb: the embedding of timestep t

  • encoder_hidden_states: the embeddings of the prompt and negative prompt

Network structure (a simplified sketch of the forward loop follows the listing):

CrossAttnDownBlock2D( 
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
    Transformer2DModel()   
    Downsample2D()  #(2,320,32,32)
)
CrossAttnDownBlock2D( 
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
    Transformer2DModel()   
    Downsample2D()  #(2,640,16,16)
)
CrossAttnDownBlock2D( 
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
    Transformer2DModel()   
    Downsample2D()  #(2,1280,8,8)
)
DownBlock2D(
    ResnetBlock2D()
    ResnetBlock2D()  #(2,1280,8,8)
)
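
A simplified sketch of how the forward pass walks through these blocks and collects the skip connections (paraphrased from UNet2DConditionModel.forward; attention masks and extra kwargs are omitted):

down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
    if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
        # CrossAttnDownBlock2D: also receives the prompt embeddings
        sample, res_samples = downsample_block(
            hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
        )
    else:
        # DownBlock2D: only the latent and the timestep embedding
        sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
    # every intermediate output is kept and later fed to the up blocks as skip connections
    down_block_res_samples += res_samples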

4.mid

UNetMidBlock2DCrossAttn contains three modules, resnet, attn, resnet; its input and output dimensions are unchanged. Its inputs are:

  • hidden_states: the latent
  • temb: the embedding of timestep t
  • encoder_hidden_states: the embeddings of the prompt and negative prompt

UNetMidBlock2DCrossAttn(
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
 )
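
The call itself is a single line; a simplified version (extra kwargs such as attention masks omitted) looks like:

sample = self.mid_block(
    sample, emb, encoder_hidden_states=encoder_hidden_states
)  # (2,1280,8,8), same shape in and out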

5.up

The up path consists of one UpBlock2D followed by three CrossAttnUpBlock2D blocks. Its inputs are:

  • hidden_states: the latent
  • temb: the embedding of timestep t
  • encoder_hidden_states: the embeddings of the prompt and negative prompt
  • res_hidden_states_tuple: the outputs of each block during downsampling, used as skip connections (a sketch of how they are consumed follows the listing)

UpBlock2D(
    ResnetBlock2D()
    ResnetBlock2D()
    ResnetBlock2D()
    Upsample2D()  #(2,1280,16,16)
)
CrossAttnUpBlock2D( 
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
    Transformer2DModel()   
    ResnetBlock2D()
    Transformer2DModel() 
    Upsample2D()  #(2,1280,32,32)
)
CrossAttnUpBlock2D( 
    ResnetBlock2D()
    Transformer2DModel()
    ResnetBlock2D()
    Transformer2DModel()   
    ResnetBlock2D()
    Transformer2DModel() 
    Upsample2D()  #(2,640,64,64)
)
CrossAttnUpBlock2D( 
    ResnetBlock2D() #(2,320,64,64)
    Transformer2DModel() 
    ResnetBlock2D()
    Transformer2DModel()   
    ResnetBlock2D()
    Transformer2DModel() 
)  
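
A simplified sketch of the up loop (paraphrased from UNet2DConditionModel.forward): each up block pops the matching skip connections off the tail of down_block_res_samples and concatenates them with its hidden states inside the block.

for upsample_block in self.up_blocks:
    # take the skip connections produced by the mirrored down block
    res_samples = down_block_res_samples[-len(upsample_block.resnets):]
    down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

    if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
        # CrossAttnUpBlock2D: also receives the prompt embeddings
        sample = upsample_block(
            hidden_states=sample,
            temb=emb,
            res_hidden_states_tuple=res_samples,
            encoder_hidden_states=encoder_hidden_states,
        )
    else:
        # UpBlock2D
        sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)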

6.post-process

A final normalization and convolution map the channel count back, giving the final output:

if self.conv_norm_out:
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
sample = self.conv_out(sample)  # (2,4,64,64)

The timestep t, class labels, and similar conditioning act in the ResNet blocks, where their embeddings are added directly to the input.
The encoder_hidden_states computed from the prompt and negative prompt act in the attention blocks, where they serve as key and value in the computation.

ResnetBlock2D

x is normalized, activated, and convolved, then temb is added; after another normalization, activation, and convolution the result forms the residual branch, which is added back to the input x.

hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)  # activation function
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
    if not self.skip_time_act:
        temb = self.nonlinearity(temb)
    temb = self.time_emb_proj(temb)[:, :, None, None]  # (2,320,1,1)

if self.time_embedding_norm == "default":
    if temb is not None:
        hidden_states = hidden_states + temb  # add temb
    hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

return output_tensor

Transformer2DModel: the attention part

Each attention block consists of two parts: Self-Attention and Cross-Attention.

# Self-Attention: encoder_hidden_states is None
attn_output = self.attn1(
    norm_hidden_states,
    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
    attention_mask=attention_mask,
    **cross_attention_kwargs,
)

# Cross-Attention: encoder_hidden_states is computed from the prompt and interacts with the latent here.
attn_output = self.attn2(
    norm_hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=encoder_attention_mask,
    **cross_attention_kwargs,
)

# query is computed from norm_hidden_states,
# key and value are computed from encoder_hidden_states.
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # (2,8,4096,40)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # (2,8,77,40)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # (2,8,77,40)
hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)  # (2,8,4096,40)
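
To see where these shapes come from: at 64x64 resolution the latent has 64*64 = 4096 tokens, the text encoder provides 77 tokens, and each of the 8 heads has dim 320/8 = 40. A self-contained sketch (plain PyTorch, not the diffusers attention processor; to_q/to_k/to_v here are stand-in projections):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, heads, head_dim = 2, 8, 40
inner_dim = heads * head_dim             # 320
latent_tokens, text_tokens = 64 * 64, 77

# stand-ins for attn.to_q / attn.to_k / attn.to_v
to_q = nn.Linear(inner_dim, inner_dim)
to_k = nn.Linear(768, inner_dim)
to_v = nn.Linear(768, inner_dim)

norm_hidden_states = torch.randn(batch, latent_tokens, inner_dim)  # from the latent
encoder_hidden_states = torch.randn(batch, text_tokens, 768)       # from the prompt

query = to_q(norm_hidden_states).view(batch, -1, heads, head_dim).transpose(1, 2)     # (2,8,4096,40)
key = to_k(encoder_hidden_states).view(batch, -1, heads, head_dim).transpose(1, 2)    # (2,8,77,40)
value = to_v(encoder_hidden_states).view(batch, -1, heads, head_dim).transpose(1, 2)  # (2,8,77,40)

out = F.scaled_dot_product_attention(query, key, value)  # (2,8,4096,40)
print(out.shape)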

Reference: stable diffusion 中使用的 UNet 2D Condition Model 结构解析(diffusers库) (an analysis of the UNet 2D Condition Model structure used in Stable Diffusion, diffusers library)
