Stable_diffusion 2.1 源码详解——Unet部分


一、 Unet整体结构

Unet模型的输入包含三个部分

  • 大小为 [B, C, H, W] 的图像 image; 注意不用在意表示大小时所用的符号,应将它们视作接口,比如 UNetModel接收大小为 [B, Z, H/8, W/8] 的 noise latent image 作为输入时,这里的 C 就等于 Z, H 就等于H/8, W 就等于 W/8;
  • 大小为 [B,] 的 timesteps;
  • 大小为 [B, K, E] 的文本 embedding 表示 context, 其中 K 表示最大编码长度,E 表示 embedding大小。
    在这里插入图片描述

二、Unet源码详细解读

1.参数初始化

1.1默认参数赋值

  参数是根据模型的配置而定,路径位于configs/stable-diffusion/v2-inference-v.yaml,截图如下。
在这里插入图片描述  由于配置中没有提及的参数会用默认值,为了分析代码的便利,以下打印出默认配置下unet的参数,以便在分析代码时作为参考,至少可以知道代码会走哪些分支、不走哪些分支。

class UNetModel(nn.Module):
	def __init__(
		self = UNetModel()
		image_size = 32 # unused
		in_channels = 4
		model_channels = 320
		out_channels = 4
		num_res_blocks = 2
		attention_resolutions = [4, 2, 1]
		dropout = 0
		channel_mult = [1, 2, 4, 4]
		conv_resample = True
		dims = 2
		num_classes = None
		use_checkpoint = True
		use_fp16 = False
		use_bf16 = False
		num_heads = -1
		num_head_channels = 64
		num_heads_upsample = -1
		use_scale_shift_norm = False
		resblock_updown = False
		use_new_attention_order = False
		use_spatial_transformer = True
		transformer_depth = 1
		context_dim = 1024
		n_embed = None
		legacy = False
		disable_self_attentions = None
		num_attention_blocks = None
		disable_middle_self_attn = False
		use_linear_in_transformer = True
		adm_in_channels = None
		):

1.2 非法参数值的判断

例如:如果use_spatial_transformer是否为真,并且当context_dim不为空时,才会继续执行。若context_dim为空,则会抛出断言错误,提示用户需指定交叉注意力机制的条件维度。

if use_spatial_transformer: #True context_dim=1024
	assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
	assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
    from omegaconf.listconfig import ListConfig
    if 
### Stable Diffusion 项目代码实现 #### 导入必要的库 为了构建和训练Stable Diffusion模型,首先需要导入一系列必需的Python库。 ```python import torch from torch import nn, optim from torchvision import datasets, transforms from transformers import CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler ``` #### 初始化组件 根据描述,Stable Diffusion主要由三个部分构成:文本编码器、U-Net以及变分自编码器(VAE)[^2]。以下是初始化这些模块的方法: ##### 文本编码器 用于处理输入的文字提示(prompt),将其转换成向量表示形式供后续阶段使用。 ```python tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") ``` ##### U-Net网络结构 负责逐步去除噪声并生成最终图片的核心机制。 ```python unet = UNet2DConditionModel.from_pretrained( "CompVis/stable-diffusion-v1-4", subfolder="unet" ) ``` ##### 变分自编码器(VAE) 用来压缩原始图像数据,并在解码过程中重建高质量输出。 ```python vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse") ``` #### 设置优化策略和其他辅助函数 定义损失计算方式以及其他有助于提高训练效果的技术细节。 ```python scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") def train_step(input_ids, pixel_values): with torch.no_grad(): latents = vae.encode(pixel_values).latent_dist.sample().detach() optimizer.zero_grad() noise_pred = unet(latents, input_ids=input_ids).sample loss = F.mse_loss(noise_pred, target_latent) loss.backward() optimizer.step() return loss.item() ``` 以上展示了如何基于PyTorch框架搭建一个简易版本的Stable Diffusion工作流[^2]。需要注意的是,在实际应用中还需要考虑更多因素如超参数调整、硬件加速支持等。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值