文章目录
一、 Unet整体结构
Unet模型的输入包含三个部分:
- 大小为 [B, C, H, W] 的图像 image; 注意不用在意表示大小时所用的符号,应将它们视作接口,比如 UNetModel接收大小为 [B, Z, H/8, W/8] 的 noise latent image 作为输入时,这里的 C 就等于 Z, H 就等于H/8, W 就等于 W/8;
- 大小为 [B,] 的 timesteps;
- 大小为 [B, K, E] 的文本 embedding 表示 context, 其中 K 表示最大编码长度,E 表示 embedding大小。
二、Unet源码详细解读
1.参数初始化
1.1默认参数赋值
参数是根据模型的配置而定,路径位于configs/stable-diffusion/v2-inference-v.yaml,截图如下。
由于配置中没有提及的参数会用默认值,为了分析代码的便利,以下打印出默认配置下unet的参数,以便在分析代码时作为参考,至少可以知道代码会走哪些分支、不走哪些分支。
class UNetModel(nn.Module):
def __init__(
self = UNetModel()
image_size = 32 # unused
in_channels = 4
model_channels = 320
out_channels = 4
num_res_blocks = 2
attention_resolutions = [4, 2, 1]
dropout = 0
channel_mult = [1, 2, 4, 4]
conv_resample = True
dims = 2
num_classes = None
use_checkpoint = True
use_fp16 = False
use_bf16 = False
num_heads = -1
num_head_channels = 64
num_heads_upsample = -1
use_scale_shift_norm = False
resblock_updown = False
use_new_attention_order = False
use_spatial_transformer = True
transformer_depth = 1
context_dim = 1024
n_embed = None
legacy = False
disable_self_attentions = None
num_attention_blocks = None
disable_middle_self_attn = False
use_linear_in_transformer = True
adm_in_channels = None
):
1.2 非法参数值的判断
例如:如果use_spatial_transformer是否为真,并且当context_dim不为空时,才会继续执行。若context_dim为空,则会抛出断言错误,提示用户需指定交叉注意力机制的条件维度。
if use_spatial_transformer: #True context_dim=1024
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if