Latent Diffusion Study Notes and a Worked Example

References

These are study notes based on a GitHub project. The network structure is identical to the original author's project, but the code has been lightly reworked to suit my own habits. Links to the original project:

Github: https://github.com/lansinuote/Diffusion_From_Scratch
Huggingface: https://huggingface.co/datasets/lansinuote/diffsion_from_scratch

Project repositories for this post:

Github: https://github.com/MarcYugo/a-practice-example-latent-diffusion
GitCode: https://gitcode.com/weixin_43385826/a-practice-example-latent-diffusion/

Dependencies

Besides PyTorch, the following packages are required.

transformers==4.26.1
datasets==2.9.0
diffusers==0.12.1

Overall Structure

The model consists of three parts: a Text Encoder that extracts text features, an Image Encoder that extracts image features, and a UNet used for denoising (the Diffusion Model), as shown in Figure 1.

Figure 1. Overall structure of the Latent Diffusion network
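
To make the data flow concrete before diving into each part, here is a minimal, hypothetical sketch (not part of the original project files) of how the three components defined later in this post fit together. It assumes the pretrained-params directory described in the training section is in place.

'''
    Hypothetical sketch of how the three components fit together (not in the original project)
    sketch.py
'''
import torch
from text_encoder import text_encoder_pretrained
from vision_auto_encoder import vae_pretrained
from unet import unet_pretrained

text_encoder = text_encoder_pretrained().eval()   # prompt tokens -> text features
vae = vae_pretrained().eval()                      # image <-> 4x64x64 latent
unet = unet_pretrained().eval()                    # noisy latent + text + timestep -> predicted noise

with torch.no_grad():
    input_ids = torch.zeros(1, 77, dtype=torch.long)   # dummy, already-tokenized prompt
    text_out = text_encoder(input_ids)                 # (1, 77, 768)
    latent = torch.randn(1, 4, 64, 64)                 # pure-noise latent
    t = torch.tensor([999])                            # one diffusion timestep
    noise_pred = unet(latent, text_out, t)             # (1, 4, 64, 64)
    image = vae.decoder(latent)                        # (1, 3, 512, 512)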

Text Encoder

The Text Encoder is a Transformer with multi-head attention and is initialized with the weights of a pretrained CLIP model. Concretely, it consists of an Embedding layer that maps token ids into the embedding dimension and adds positional encodings, 12 identical masked multi-head self-attention (MHSA) blocks, and a final LayerNorm layer.

'''
    Text encoder for extracting text features
    text_encoder.py
'''
import torch
from torch import nn
from transformers import CLIPTextModel

class Embed(nn.Module):
    def __init__(self, embed_dim=768, n_tokens=77, seq_len=49408):
        super().__init__()
        # note: seq_len is actually the CLIP vocabulary size; n_tokens is the prompt length (77)
        self.embed = nn.Embedding(seq_len, embed_dim)
        self.pos_embed = nn.Embedding(n_tokens, embed_dim)
        self.embed_dim = embed_dim
        self.n_tokens = n_tokens
        self.register_buffer('pos_ids', torch.arange(n_tokens).unsqueeze(0))
    def forward(self, input_ids):
        # input_ids: (b, 77)
        embed = self.embed(input_ids)
        pos_embed = self.pos_embed(self.pos_ids)
        return embed + pos_embed

class SelfAttention(nn.Module):
    def __init__(self, emb_dim=768, heads=12):
        super().__init__()
        self.wq = nn.Linear(emb_dim, emb_dim)
        self.wk = nn.Linear(emb_dim, emb_dim)
        self.wv = nn.Linear(emb_dim, emb_dim)
        self.out_proj = nn.Linear(emb_dim, emb_dim)
        self.emb_dim = emb_dim
        self.heads = heads
    def get_mask(self, b, n_tok):
        # causal mask: -inf above the diagonal, 0 elsewhere, shape (b, 1, n_tok, n_tok)
        mask = torch.empty(b, n_tok, n_tok)
        mask.fill_(-float('inf'))
        mask = mask.triu_(1).unsqueeze(1)
        return mask
    def forward(self, x):
        # (b, 77, 768)
        b, n_tok, _ = x.shape
        q = self.wq(x)/8 # scale by 1/sqrt(head_dim), head_dim = 768 // 12 = 64
        k = self.wk(x)
        v = self.wv(x)
        # split into attention heads
        q = q.reshape(b, n_tok, self.heads, self.emb_dim//self.heads).transpose(1,2).reshape(b*self.heads, n_tok, self.emb_dim//self.heads)
        k = k.reshape(b, n_tok, self.heads, self.emb_dim//self.heads).transpose(1,2).reshape(b*self.heads, n_tok, self.emb_dim//self.heads)
        v = v.reshape(b, n_tok, self.heads, self.emb_dim//self.heads).transpose(1,2).reshape(b*self.heads, n_tok, self.emb_dim//self.heads)
        # q·k attention scores
        atten = torch.bmm(q, k.transpose(1,2))
        atten = atten.reshape(b, self.heads, n_tok, n_tok)
        atten = atten + self.get_mask(b, n_tok).to(atten.device)
        atten = atten.reshape(b*self.heads, n_tok, n_tok)
        atten = atten.softmax(dim=-1)
        atten = torch.bmm(atten, v) # (b*12, 77, 64)
        atten = atten.reshape(b, self.heads, n_tok, self.emb_dim//self.heads).transpose(1,2).reshape(b, n_tok, self.emb_dim) # (b, 77, 768)
        out = self.out_proj(atten)
        return out

class QuickGELU(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return x*(x*1.702).sigmoid()

class Block(nn.Module):
    def __init__(self, embed_dim=768, expand_dim=3072):
        super().__init__()
        self.seq1 = nn.Sequential(
            nn.LayerNorm(embed_dim),
            SelfAttention())
        self.seq2 = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, expand_dim),
            QuickGELU(),
            nn.Linear(expand_dim, embed_dim))
    def forward(self, x):
        x = x + self.seq1(x)
        x = x + self.seq2(x)
        return x

class TextEncoder(nn.Module):
    def __init__(self, embed_dim=768):
        super().__init__()
        seq = [Embed()] + [Block() for _ in range(12)] + [nn.LayerNorm(embed_dim)]
        self.seq = nn.Sequential(*seq)
    def forward(self, x):
        out = self.seq(x)
        return out

def load_pretrained(model):
    params = CLIPTextModel.from_pretrained('./pretrained-params', subfolder='text_encoder')
    model.seq[0].embed.load_state_dict(
        params.text_model.embeddings.token_embedding.state_dict())
    model.seq[0].pos_embed.load_state_dict(
        params.text_model.embeddings.position_embedding.state_dict())

    for i in range(12):
        model.seq[i+1].seq1[0].load_state_dict(
            params.text_model.encoder.layers[i].layer_norm1.state_dict())
        model.seq[i+1].seq1[1].wq.load_state_dict(
            params.text_model.encoder.layers[i].self_attn.q_proj.state_dict())
        model.seq[i+1].seq1[1].wk.load_state_dict(
            params.text_model.encoder.layers[i].self_attn.k_proj.state_dict())
        model.seq[i+1].seq1[1].wv.load_state_dict(
            params.text_model.encoder.layers[i].self_attn.v_proj.state_dict())
        model.seq[i+1].seq1[1].out_proj.load_state_dict(
            params.text_model.encoder.layers[i].self_attn.out_proj.state_dict())
        model.seq[i+1].seq2[0].load_state_dict(
            params.text_model.encoder.layers[i].layer_norm2.state_dict())
        model.seq[i+1].seq2[1].load_state_dict(
            params.text_model.encoder.layers[i].mlp.fc1.state_dict())
        model.seq[i+1].seq2[3].load_state_dict(
            params.text_model.encoder.layers[i].mlp.fc2.state_dict())
    
    model.seq[13].load_state_dict(params.text_model.final_layer_norm.state_dict())
    return model

def text_encoder_pretrained():
    text_encoder = TextEncoder()
    text_encoder = load_pretrained(text_encoder)
    return text_encoder
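
As a quick sanity check (a hypothetical snippet that is not part of text_encoder.py; it loads the tokenizer directly from the pretrained-params directory shown later), a 77-token prompt maps to a (1, 77, 768) feature tensor:

import torch
from transformers import CLIPTokenizer
from text_encoder import text_encoder_pretrained

# assumes ./pretrained-params/tokenizer exists as in the project tree shown later
tokenizer = CLIPTokenizer.from_pretrained('./pretrained-params/tokenizer')
encoder = text_encoder_pretrained().eval()

ids = tokenizer('a drawing of a star', padding='max_length', max_length=77,
                truncation=True, return_tensors='pt').input_ids    # (1, 77)
with torch.no_grad():
    features = encoder(ids)
print(features.shape)   # torch.Size([1, 77, 768])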

Image Encoder

The Image Encoder is a VAE (Variational Autoencoder): an Encoder and a Decoder built by stacking ResNet blocks with a small number of self-attention layers. All normalization layers are GroupNorm and all activations are SiLU.

'''
    VAE for extracting visual features
    vision_auto_encoder.py
'''
import torch
from torch import nn
from diffusers import AutoencoderKL

class ResNetBlock(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.seq = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=dim_in, eps=1e-6, affine=True),
            nn.SiLU(),
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups=32, num_channels=dim_out, eps=1e-6, affine=True),
            nn.SiLU(),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1))
        
        self.resil = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1) if dim_in != dim_out else nn.Identity()
        self.dim_in = dim_in
        self.dim_out = dim_out
    def forward(self, x):
        res = self.resil(x)
        out = self.seq(x) + res
        return out

class SelfAttention(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        self.norm = nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6, affine=True)
        self.wq = torch.nn.Linear(embed_dim, embed_dim)
        self.wk = torch.nn.Linear(embed_dim, embed_dim)
        self.wv = torch.nn.Linear(embed_dim, embed_dim)
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim)
    def forward(self, x):
        # x: (b, 512, 64, 64)
        res = x
        b,c,h,w = x.shape
        x = self.norm(x)
        x = x.flatten(start_dim=2).transpose(1,2) # (1, 4096, 512)
        q = self.wq(x) # (1, 4096, 512)
        k = self.wk(x)
        v = self.wv(x)
        k = k.transpose(1,2) # (1, 512, 4096)
        #[1, 4096, 512] * [1, 512, 4096] -> [1, 4096, 4096]
        #0.044194173824159216 = 1 / 512**0.5
        atten = q.bmm(k) / 512**0.5

        atten = torch.softmax(atten, dim=2)
        atten = atten.bmm(v)
        atten = self.out_proj(atten)
        atten = atten.transpose(1, 2).reshape(b, c, h, w)
        atten = atten + res
        return atten

class Pad(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,x):
        return nn.functional.pad(x, (0, 1, 0, 1),
                                    mode='constant',
                                    value=0)

class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            #in
            nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1),
            #down
            nn.Sequential(
                ResNetBlock(128, 128),
                ResNetBlock(128, 128),
                nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(128, 128, 3, stride=2, padding=0),
                ),
            ),
            torch.nn.Sequential(
                ResNetBlock(128, 256),
                ResNetBlock(256, 256),
                torch.nn.Sequential(
                    Pad(),
                    nn.Conv2d(256, 256, 3, stride=2, padding=0),
                ),
            ),
            nn.Sequential(
                ResNetBlock(256, 512),
                ResNetBlock(512, 512),
                nn.Sequential(
                    Pad(),
                    nn.Conv2d(512, 512, 3, stride=2, padding=0),
                ),
            ),
            nn.Sequential(
                ResNetBlock(512, 512),
                ResNetBlock(512, 512),
            ),
            #mid
            nn.Sequential(
                ResNetBlock(512, 512),
                SelfAttention(),
                ResNetBlock(512, 512),
            ),
            #out
            nn.Sequential(
                nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6),
                nn.SiLU(),
                nn.Conv2d(512, 8, 3, padding=1),
            ),
            #latent Gaussian layer (quant_conv)
            nn.Conv2d(8, 8, 1))

        self.decoder = nn.Sequential(
            #latent Gaussian layer (post_quant_conv)
            nn.Conv2d(4, 4, 1),
            #in
            nn.Conv2d(4, 512, kernel_size=3, stride=1, padding=1),
            #middle
            nn.Sequential(ResNetBlock(512, 512), 
                                SelfAttention(), 
                                ResNetBlock(512, 512)),
            #up
            nn.Sequential(
                ResNetBlock(512, 512),
                ResNetBlock(512, 512),
                ResNetBlock(512, 512),
                nn.Upsample(scale_factor=2.0, mode='nearest'),
                nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ),
            nn.Sequential(
                ResNetBlock(512, 512),
                ResNetBlock(512, 512),
                ResNetBlock(512, 512),
                nn.Upsample(scale_factor=2.0, mode='nearest'),
                nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ),
            nn.Sequential(
                ResNetBlock(512, 256),
                ResNetBlock(256, 256),
                ResNetBlock(256, 256),
                nn.Upsample(scale_factor=2.0, mode='nearest'),
                nn.Conv2d(256, 256, kernel_size=3, padding=1),
            ),
            nn.Sequential(
                ResNetBlock(256, 128),
                ResNetBlock(128, 128),
                ResNetBlock(128, 128),
            ),
            #out
            nn.Sequential(
                nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6),
                nn.SiLU(),
                nn.Conv2d(128, 3, 3, padding=1),
            ))
    
    def sample(self, h):
        # reparameterization: the 8 channels hold the mean and log-variance of the latent Gaussian
        #h -> [1, 8, 64, 64]
        #[1, 4, 64, 64]
        mean = h[:, :4]
        logvar = h[:, 4:]
        std = logvar.exp()**0.5
        #[1, 4, 64, 64]
        h = torch.randn(mean.shape, device=mean.device)
        h = mean + std * h
        return h
    
    def forward(self, x):
        #x -> [1, 3, 512, 512]
        #[1, 3, 512, 512] -> [1, 8, 64, 64]
        h = self.encoder(x)
        #[1, 8, 64, 64] -> [1, 4, 64, 64]
        h = self.sample(h)
        #[1, 4, 64, 64] -> [1, 3, 512, 512]
        h = self.decoder(h)
        return h

def load_pretrained(model):
    params = AutoencoderKL.from_pretrained('./pretrained-params/', subfolder='vae')
    model.encoder[0].load_state_dict(params.encoder.conv_in.state_dict())
    #encoder.down
    for i in range(4):
        load_res(model.encoder[i + 1][0], params.encoder.down_blocks[i].resnets[0])
        load_res(model.encoder[i + 1][1], params.encoder.down_blocks[i].resnets[1])
        if i != 3:
            model.encoder[i + 1][2][1].load_state_dict(
                params.encoder.down_blocks[i].downsamplers[0].conv.state_dict())
    #encoder.mid
    load_res(model.encoder[5][0], params.encoder.mid_block.resnets[0])
    load_res(model.encoder[5][2], params.encoder.mid_block.resnets[1])
    load_atten(model.encoder[5][1], params.encoder.mid_block.attentions[0])
    #encoder.out
    model.encoder[6][0].load_state_dict(params.encoder.conv_norm_out.state_dict())
    model.encoder[6][2].load_state_dict(params.encoder.conv_out.state_dict())
    #encoder quant_conv
    model.encoder[7].load_state_dict(params.quant_conv.state_dict())
    #decoder post_quant_conv
    model.decoder[0].load_state_dict(params.post_quant_conv.state_dict())
    #decoder.in
    model.decoder[1].load_state_dict(params.decoder.conv_in.state_dict())
    #decoder.mid
    load_res(model.decoder[2][0], params.decoder.mid_block.resnets[0])
    load_res(model.decoder[2][2], params.decoder.mid_block.resnets[1])
    load_atten(model.decoder[2][1], params.decoder.mid_block.attentions[0])
    #decoder.up
    for i in range(4):
        load_res(model.decoder[i + 3][0], params.decoder.up_blocks[i].resnets[0])
        load_res(model.decoder[i + 3][1], params.decoder.up_blocks[i].resnets[1])
        load_res(model.decoder[i + 3][2], params.decoder.up_blocks[i].resnets[2])
        if i != 3:
            model.decoder[i + 3][4].load_state_dict(
                params.decoder.up_blocks[i].upsamplers[0].conv.state_dict())
    #decoder.out
    model.decoder[7][0].load_state_dict(params.decoder.conv_norm_out.state_dict())
    model.decoder[7][2].load_state_dict(params.decoder.conv_out.state_dict())
    return model

def load_res(model, param):
    model.seq[0].load_state_dict(param.norm1.state_dict())
    model.seq[2].load_state_dict(param.conv1.state_dict())
    model.seq[3].load_state_dict(param.norm2.state_dict())
    model.seq[5].load_state_dict(param.conv2.state_dict())
    if isinstance(model.resil, nn.Conv2d):
        model.resil.load_state_dict(param.conv_shortcut.state_dict())

def load_atten(model, param):
    model.norm.load_state_dict(param.group_norm.state_dict())
    model.wq.load_state_dict(param.to_q.state_dict())
    model.wk.load_state_dict(param.to_k.state_dict())
    model.wv.load_state_dict(param.to_v.state_dict())
    model.out_proj.load_state_dict(param.to_out[0].state_dict())

def vae_pretrained():
    vae = VAE()
    vae = load_pretrained(vae)
    return vae
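
A minimal round-trip sketch (hypothetical, not part of vision_auto_encoder.py; it assumes the pretrained-params directory is in place): encode a 512x512 input to an 8-channel latent, sample a 4-channel latent with the reparameterization trick, and decode it back.

import torch
from vision_auto_encoder import vae_pretrained

vae = vae_pretrained().eval()
img = torch.randn(1, 3, 512, 512)      # stand-in for an image normalized to [-1, 1]
with torch.no_grad():
    h = vae.encoder(img)               # (1, 8, 64, 64): mean and log-variance channels
    z = vae.sample(h)                  # (1, 4, 64, 64): sampled latent
    recon = vae.decoder(z)             # (1, 3, 512, 512)
print(z.shape, recon.shape)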

Denoising Network (Diffusion Model)

The Diffusion Model is a UNet built from alternating ResNet blocks and Transformer blocks. In each Transformer block, self-attention over the image features is followed by cross-attention against the text-encoder output, which is how the text condition is injected into the UNet.

'''
    Denoising network (Diffusion Model)
    unet.py
'''
import torch
from torch import nn
from diffusers import UNet2DConditionModel

class ResNetBlock(nn.Module):
    def __init__(self, dim_in, dim_out, time_emb_dim=1280):
        super().__init__()

        self.time_emb = nn.Sequential(
            nn.SiLU(),
            torch.nn.Linear(time_emb_dim, dim_out),
            nn.Unflatten(dim=1, unflattened_size=(dim_out, 1, 1)),
        )
        self.seq1 = nn.Sequential(
            nn.GroupNorm(num_groups=32,
                            num_channels=dim_in,
                            eps=1e-5,
                            affine=True),
            nn.SiLU(),
            nn.Conv2d(dim_in,
                        dim_out,
                        kernel_size=3,
                        stride=1,
                        padding=1))
        self.seq2 = nn.Sequential(
            nn.GroupNorm(num_groups=32,
                            num_channels=dim_out,
                            eps=1e-5,
                            affine=True),
            nn.SiLU(),
            nn.Conv2d(dim_out,
                        dim_out,
                        kernel_size=3,
                        stride=1,
                        padding=1))
        self.resil = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1) if dim_in!=dim_out else nn.Identity()

    def forward(self, x, time):
        #example shapes below assume dim_in=320, dim_out=640
        #x -> [1, 320, 64, 64]
        #time -> [1, 1280]
        res = x
        #[1, 1280] -> [1, 640, 1, 1]
        time = self.time_emb(time)
        #[1, 320, 64, 64] -> [1, 640, 64, 64], spatial size unchanged
        x = self.seq1(x) + time
        #shape unchanged: [1, 640, 64, 64]
        x = self.seq2(x)
        #shortcut: [1, 320, 64, 64] -> [1, 640, 64, 64]
        x = self.resil(res) + x
        return x

class CrossAttention(nn.Module):
    def __init__(self, dim_q=320, dim_kv=768, heads=8):
        super().__init__()
        #dim_q -> 320
        #dim_kv -> 768
        self.dim_q = dim_q
        self.heads = heads
        self.wq = nn.Linear(dim_q, dim_q, bias=False)
        self.wk = nn.Linear(dim_kv, dim_q, bias=False)
        self.wv = nn.Linear(dim_kv, dim_q, bias=False)
        self.out_proj = nn.Linear(dim_q, dim_q)
    def multihead_reshape(self, x):
        #x -> [1, 4096, 320]
        b, lens, dim = x.shape
        #[1, 4096, 320] -> [1, 4096, 8, 40]
        x = x.reshape(b, lens, self.heads, dim // self.heads)
        #[1, 4096, 8, 40] -> [1, 8, 4096, 40]
        x = x.transpose(1, 2)
        #[1, 8, 4096, 40] -> [8, 4096, 40]
        x = x.reshape(b * self.heads, lens, dim // self.heads)
        return x
    def multihead_reshape_inverse(self, x):
        #x -> [8, 4096, 40]
        b, lens, dim = x.shape
        #[8, 4096, 40] -> [1, 8, 4096, 40]
        x = x.reshape(b // self.heads, self.heads, lens, dim)
        #[1, 8, 4096, 40] -> [1, 4096, 8, 40]
        x = x.transpose(1, 2)
        #[1, 4096, 320]
        x = x.reshape(b // self.heads, lens, dim * self.heads)
        return x
    def forward(self, q, kv):
        #q -> [1, 4096, 320]
        #kv -> [1, 77, 768]
        #[1, 4096, 320] -> [1, 4096, 320]
        q = self.wq(q)
        #[1, 77, 768] -> [1, 77, 320]
        k = self.wk(kv)
        #[1, 77, 768] -> [1, 77, 320]
        v = self.wv(kv)

        #[1, 4096, 320] -> [8, 4096, 40]
        q = self.multihead_reshape(q)
        #[1, 77, 320] -> [8, 77, 40]
        k = self.multihead_reshape(k)
        #[1, 77, 320] -> [8, 77, 40]
        v = self.multihead_reshape(v)
        #[8, 4096, 40] * [8, 40, 77] -> [8, 4096, 77]
        atten = q.bmm(k.transpose(1, 2)) * (self.dim_q // self.heads)**-0.5
        atten = atten.softmax(dim=-1)
        #[8, 4096, 77] * [8, 77, 40] -> [8, 4096, 40]
        atten = atten.bmm(v)
        #[8, 4096, 40] -> [1, 4096, 320]
        atten = self.multihead_reshape_inverse(atten)
        #[1, 4096, 320] -> [1, 4096, 320]
        atten = self.out_proj(atten)
        return atten

class TransformerBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        #in
        self.norm_in = nn.GroupNorm(num_groups=32,
                                    num_channels=dim,
                                    eps=1e-6,
                                    affine=True)
        self.cnn_in = nn.Conv2d(dim,
                                dim,
                                kernel_size=1,
                                stride=1,
                                padding=0)
        #atten
        self.norm_atten0 = nn.LayerNorm(dim, elementwise_affine=True)
        self.atten1 = CrossAttention(dim, dim)
        self.norm_atten1 = nn.LayerNorm(dim, elementwise_affine=True)
        self.atten2 = CrossAttention(dim, 768)
        #act
        self.norm_act = nn.LayerNorm(dim, elementwise_affine=True)
        self.fc0 = nn.Linear(dim, dim * 8)
        self.act = nn.GELU()
        self.fc1 = nn.Linear(dim * 4, dim)
        #out
        self.cnn_out = nn.Conv2d(dim,
                                dim,
                                kernel_size=1,
                                stride=1,
                                padding=0)
    def forward(self, q, kv):
        #q -> [1, 320, 64, 64]
        #kv -> [1, 77, 768]
        b, _, h, w = q.shape
        res1 = q
        #----in----
        #shape unchanged: [1, 320, 64, 64]
        q = self.cnn_in(self.norm_in(q))
        #[1, 320, 64, 64] -> [1, 64, 64, 320] -> [1, 4096, 320]
        q = q.permute(0, 2, 3, 1).reshape(b, h * w, self.dim)
        #----atten----
        #shape unchanged: [1, 4096, 320]
        q = self.atten1(q=self.norm_atten0(q), kv=self.norm_atten0(q)) + q
        q = self.atten2(q=self.norm_atten1(q), kv=kv) + q
        #----act----
        #[1, 4096, 320]
        res2 = q
        #[1, 4096, 320] -> [1, 4096, 2560]
        q = self.fc0(self.norm_act(q))
        #1280
        d = q.shape[2] // 2
        #[1, 4096, 1280] * [1, 4096, 1280] -> [1, 4096, 1280]
        q = q[:, :, :d] * self.act(q[:, :, d:])
        #[1, 4096, 1280] -> [1, 4096, 320]
        q = self.fc1(q) + res2
        #----out----
        #[1, 4096, 320] -> [1, 64, 64, 320] -> [1, 320, 64, 64]
        q = q.reshape(b, h, w, self.dim).permute(0, 3, 1, 2).contiguous()
        #shape unchanged: [1, 320, 64, 64]
        q = self.cnn_out(q) + res1
        return q

class DownBlock(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.trans_block1 = TransformerBlock(dim_out)
        self.res_block1 = ResNetBlock(dim_in, dim_out)
        self.trans_block2 = TransformerBlock(dim_out)
        self.res_block2 = ResNetBlock(dim_out, dim_out)
        self.out = nn.Conv2d(dim_out,
                            dim_out,
                            kernel_size=3,
                            stride=2,
                            padding=1)
    def forward(self, vae_out, text_out, time):
        outs = []
        vae_out = self.res_block1(vae_out, time)
        vae_out = self.trans_block1(vae_out, text_out)
        outs.append(vae_out)

        vae_out = self.res_block2(vae_out, time)
        vae_out = self.trans_block2(vae_out, text_out)
        outs.append(vae_out)

        vae_out = self.out(vae_out)
        outs.append(vae_out)

        return vae_out, outs

class UpBlock(nn.Module):
    def __init__(self, dim_in, dim_out, dim_prev, add_up):
        super().__init__()

        self.res_block1 = ResNetBlock(dim_out + dim_prev, dim_out)
        self.res_block2 = ResNetBlock(dim_out + dim_out, dim_out)
        self.res_block3 = ResNetBlock(dim_in + dim_out, dim_out)

        self.trans_block1 = TransformerBlock(dim_out)
        self.trans_block2 = TransformerBlock(dim_out)
        self.trans_block3 = TransformerBlock(dim_out)

        self.out = torch.nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1),
        ) if add_up else nn.Identity()

    def forward(self, vae_out, text_out, time, out_down):
        vae_out = self.res_block1(torch.cat([vae_out, out_down.pop()], dim=1), time)
        vae_out = self.trans_block1(vae_out, text_out)

        vae_out = self.res_block2(torch.cat([vae_out, out_down.pop()], dim=1), time)
        vae_out = self.trans_block2(vae_out, text_out)

        vae_out = self.res_block3(torch.cat([vae_out, out_down.pop()], dim=1), time)
        vae_out = self.trans_block3(vae_out, text_out)

        vae_out = self.out(vae_out)
        return vae_out

class UNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        #in
        self.in_vae = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
        self.in_time = torch.nn.Sequential(
            torch.nn.Linear(320, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280),
        )
        #down
        self.down_block1 = DownBlock(320, 320)
        self.down_block2 = DownBlock(320, 640)
        self.down_block3 = DownBlock(640, 1280)

        self.down_res1 = ResNetBlock(1280, 1280)
        self.down_res2 = ResNetBlock(1280, 1280)
        #mid
        self.mid_res1 = ResNetBlock(1280, 1280)
        self.mid_trans = TransformerBlock(1280)
        self.mid_res2 = ResNetBlock(1280, 1280)
        #up
        self.up_res1 = ResNetBlock(2560, 1280)
        self.up_res2 = ResNetBlock(2560, 1280)
        self.up_res3 = ResNetBlock(2560, 1280)

        self.up_in = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2, mode='nearest'),
            torch.nn.Conv2d(1280, 1280, kernel_size=3, padding=1),
        )

        self.up_block1 = UpBlock(640, 1280, 1280, True)
        self.up_block2 = UpBlock(320, 640, 1280, True)
        self.up_block3 = UpBlock(320, 320, 640, False)
        #out
        self.out = torch.nn.Sequential(
            torch.nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5),
            torch.nn.SiLU(),
            torch.nn.Conv2d(320, 4, kernel_size=3, padding=1),
        )
        # self.load_pretrained()
    
    def get_time_embed(self, t):
        #-9.210340371976184 = -math.log(10000)
        e = torch.arange(160) * -9.210340371976184 / 160
        e = e.exp().to(t.device) * t
        #[160+160] -> [320] -> [1, 320]
        e = torch.cat([e.cos(), e.sin()]).unsqueeze(dim=0)
        return e

    def forward(self, vae_out, text_out, time):
        #vae_out -> [1, 4, 64, 64]
        #text_out -> [1, 77, 768]
        #time -> [1]
        #----in----
        #[1, 4, 64, 64] -> [1, 320, 64, 64]
        vae_out = self.in_vae(vae_out)
        #[1] -> [1, 320]
        time = self.get_time_embed(time)
        #[1, 320] -> [1, 1280]
        time = self.in_time(time)

        #----down----
        out_down = [vae_out]
        #[1, 320, 64, 64],[1, 77, 768],[1, 1280] -> [1, 320, 32, 32]
        #out -> [1, 320, 64, 64],[1, 320, 64, 64],[1, 320, 32, 32]
        vae_out, out = self.down_block1(vae_out=vae_out,
                                        text_out=text_out,
                                        time=time)
        out_down.extend(out)
        #[1, 320, 32, 32],[1, 77, 768],[1, 1280] -> [1, 640, 16, 16]
        #out -> [1, 640, 32, 32],[1, 640, 32, 32],[1, 640, 16, 16]
        vae_out, out = self.down_block2(vae_out=vae_out,
                                        text_out=text_out,
                                        time=time)
        out_down.extend(out)
        #[1, 640, 16, 16],[1, 77, 768],[1, 1280] -> [1, 1280, 8, 8]
        #out -> [1, 1280, 16, 16],[1, 1280, 16, 16],[1, 1280, 8, 8]
        vae_out, out = self.down_block3(vae_out=vae_out,
                                        text_out=text_out,
                                        time=time)
        out_down.extend(out)
        #[1, 1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]
        vae_out = self.down_res1(vae_out, time)
        out_down.append(vae_out)
        #[1, 1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]
        vae_out = self.down_res2(vae_out, time)
        out_down.append(vae_out)

        #----mid----
        #[1, 1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]
        vae_out = self.mid_res1(vae_out, time)
        #[1, 1280, 8, 8],[1, 77, 768] -> [1, 1280, 8, 8]
        vae_out = self.mid_trans(vae_out, text_out)
        #[1, 1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]
        vae_out = self.mid_res2(vae_out, time)

        #----up----
        #[1, 1280+1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]
        vae_out = self.up_res1(torch.cat([vae_out, out_down.pop()], dim=1),
                               time)
        #[1, 1280+1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]
        vae_out = self.up_res2(torch.cat([vae_out, out_down.pop()], dim=1),
                               time)
        #[1, 1280+1280, 8, 8],[1, 1280] -> [1, 1280, 8, 8]
        vae_out = self.up_res3(torch.cat([vae_out, out_down.pop()], dim=1),
                               time)
        #[1, 1280, 8, 8] -> [1, 1280, 16, 16]
        vae_out = self.up_in(vae_out)
        #[1, 1280, 16, 16],[1, 77, 768],[1, 1280] -> [1, 1280, 32, 32]
        #out_down -> [1, 640, 16, 16],[1, 1280, 16, 16],[1, 1280, 16, 16]
        vae_out = self.up_block1(vae_out=vae_out,
                                 text_out=text_out,
                                 time=time,
                                 out_down=out_down)
        #[1, 1280, 32, 32],[1, 77, 768],[1, 1280] -> [1, 640, 64, 64]
        #out_down -> [1, 320, 32, 32],[1, 640, 32, 32],[1, 640, 32, 32]
        vae_out = self.up_block2(vae_out=vae_out,
                                 text_out=text_out,
                                 time=time,
                                 out_down=out_down)
        #[1, 640, 64, 64],[1, 77, 768],[1, 1280] -> [1, 320, 64, 64]
        #out_down -> [1, 320, 64, 64],[1, 320, 64, 64],[1, 320, 64, 64]
        vae_out = self.up_block3(vae_out=vae_out,
                                 text_out=text_out,
                                 time=time,
                                 out_down=out_down)
        #----out----
        #[1, 320, 64, 64] -> [1, 4, 64, 64]
        vae_out = self.out(vae_out)
        return vae_out

# load pretrained parameters for the whole UNet
def load_pretrained(model):
    params = UNet2DConditionModel.from_pretrained('./pretrained-params/', subfolder='unet')
    #in
    model.in_vae.load_state_dict(params.conv_in.state_dict())
    model.in_time[0].load_state_dict(params.time_embedding.linear_1.state_dict())
    model.in_time[2].load_state_dict(params.time_embedding.linear_2.state_dict())
    # down
    load_down_block(model.down_block1, params.down_blocks[0])
    load_down_block(model.down_block2, params.down_blocks[1])
    load_down_block(model.down_block3, params.down_blocks[2])

    load_res_block(model.down_res1, params.down_blocks[3].resnets[0])
    load_res_block(model.down_res2, params.down_blocks[3].resnets[1])
    # mid
    load_transformer_block(model.mid_trans, params.mid_block.attentions[0])
    load_res_block(model.mid_res1, params.mid_block.resnets[0])
    load_res_block(model.mid_res2, params.mid_block.resnets[1])
    #up
    load_res_block(model.up_res1, params.up_blocks[0].resnets[0])
    load_res_block(model.up_res2, params.up_blocks[0].resnets[1])
    load_res_block(model.up_res3, params.up_blocks[0].resnets[2])
    model.up_in[1].load_state_dict(
        params.up_blocks[0].upsamplers[0].conv.state_dict())
    load_up_block(model.up_block1, params.up_blocks[1])
    load_up_block(model.up_block2, params.up_blocks[2])
    load_up_block(model.up_block3, params.up_blocks[3])
    #out
    model.out[0].load_state_dict(params.conv_norm_out.state_dict())
    model.out[2].load_state_dict(params.conv_out.state_dict())
    return model

# load parameters for a Transformer block
def load_transformer_block(model: TransformerBlock, param):
    model.norm_in.load_state_dict(param.norm.state_dict())
    model.cnn_in.load_state_dict(param.proj_in.state_dict())

    model.atten1.wq.load_state_dict(
        param.transformer_blocks[0].attn1.to_q.state_dict())
    model.atten1.wk.load_state_dict(
        param.transformer_blocks[0].attn1.to_k.state_dict())
    model.atten1.wv.load_state_dict(
        param.transformer_blocks[0].attn1.to_v.state_dict())
    model.atten1.out_proj.load_state_dict(
        param.transformer_blocks[0].attn1.to_out[0].state_dict())

    model.atten2.wq.load_state_dict(
        param.transformer_blocks[0].attn2.to_q.state_dict())
    model.atten2.wk.load_state_dict(
        param.transformer_blocks[0].attn2.to_k.state_dict())
    model.atten2.wv.load_state_dict(
        param.transformer_blocks[0].attn2.to_v.state_dict())
    model.atten2.out_proj.load_state_dict(
        param.transformer_blocks[0].attn2.to_out[0].state_dict())

    model.fc0.load_state_dict(
        param.transformer_blocks[0].ff.net[0].proj.state_dict())

    model.fc1.load_state_dict(
        param.transformer_blocks[0].ff.net[2].state_dict())

    model.norm_atten0.load_state_dict(
        param.transformer_blocks[0].norm1.state_dict())
    model.norm_atten1.load_state_dict(
        param.transformer_blocks[0].norm2.state_dict())
    model.norm_act.load_state_dict(
        param.transformer_blocks[0].norm3.state_dict())

    model.cnn_out.load_state_dict(param.proj_out.state_dict())

# load parameters for a ResNet block
def load_res_block(model: ResNetBlock, param):
    model.time_emb[1].load_state_dict(param.time_emb_proj.state_dict())

    model.seq1[0].load_state_dict(param.norm1.state_dict())
    model.seq1[2].load_state_dict(param.conv1.state_dict())

    model.seq2[0].load_state_dict(param.norm2.state_dict())
    model.seq2[2].load_state_dict(param.conv2.state_dict())

    if isinstance(model.resil, nn.Conv2d):
        model.resil.load_state_dict(param.conv_shortcut.state_dict())

# load parameters for a downsampling block
def load_down_block(model: DownBlock, param):
    load_transformer_block(model.trans_block1, param.attentions[0])
    load_transformer_block(model.trans_block2, param.attentions[1])

    load_res_block(model.res_block1, param.resnets[0])
    load_res_block(model.res_block2, param.resnets[1])
    model.out.load_state_dict(param.downsamplers[0].conv.state_dict())

# load parameters for an upsampling block
def load_up_block(model: UpBlock, param):
    load_transformer_block(model.trans_block1, param.attentions[0])
    load_transformer_block(model.trans_block2, param.attentions[1])
    load_transformer_block(model.trans_block3, param.attentions[2])

    load_res_block(model.res_block1, param.resnets[0])
    load_res_block(model.res_block2, param.resnets[1])
    load_res_block(model.res_block3, param.resnets[2])
    if isinstance(model.out, nn.Sequential):
        model.out[1].load_state_dict(param.upsamplers[0].conv.state_dict())

def unet_pretrained():
    unet = UNet()
    unet = load_pretrained(unet)
    return unet
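
A hypothetical shape check (not part of unet.py; it assumes the pretrained-params directory is in place): the UNet maps a noisy 4x64x64 latent, the text features, and a timestep to a 4x64x64 noise prediction.

import torch
from unet import unet_pretrained

unet = unet_pretrained().eval()
latent = torch.randn(1, 4, 64, 64)     # noisy latent from the VAE
text_out = torch.randn(1, 77, 768)     # stand-in for the text-encoder output
t = torch.tensor([500])                # diffusion timestep
with torch.no_grad():
    noise_pred = unet(latent, text_out, t)
print(noise_pred.shape)                # torch.Size([1, 4, 64, 64])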

Training and Image Generation

Project Structure

The working directory is latent_diffusion. Among its subdirectories, data holds the training data, output holds the generated images, and pretrained-params contains the pretrained weight and config files downloaded from the reference project's Hugging Face page.

# project.log
latent_diffusion
├── data
│   └── train.parquet
├── output
├── pretrained-params
│   ├── feature_extractor
│   │   └── preprocessor_config.json
│   ├── model_index.json
│   ├── scheduler
│   │   └── scheduler_config.json
│   ├── text_encoder
│   │   ├── config.json
│   │   └── pytorch_model.bin
│   ├── tokenizer
│   │   ├── merges.txt
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer_config.json
│   │   └── vocab.json
│   ├── unet
│   │   ├── config.json
│   │   └── diffusion_pytorch_model.bin
│   └── vae
│       ├── config.json
│       └── diffusion_pytorch_model.bin
├── project.log
├── test.py
├── text_encoder.py
├── train.py
├── train_record.log
├── unet.py
├── utils.py
└── vision_auto_encoder.py

10 directories, 25 files

The data and weight files in the tree above (e.g. pytorch_model.bin) can be downloaded from the reference project's Hugging Face page linked at the top of this post.
Because diffusers should load everything locally instead of downloading parameters, a few files need to be edited: in every config.json, set the value of the key "_name_or_path" to the absolute path of the corresponding subdirectory. A small helper sketch for this edit follows.
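
The sketch below is hypothetical and not part of the project files: it walks pretrained-params, rewrites the "_name_or_path" key in every config.json to the absolute path of its subdirectory, and saves the file back.

import json, os

root = os.path.abspath('./pretrained-params')
for sub in os.listdir(root):
    cfg = os.path.join(root, sub, 'config.json')
    if not os.path.isfile(cfg):
        continue                       # skip entries without a config.json
    with open(cfg) as f:
        conf = json.load(f)
    conf['_name_or_path'] = os.path.join(root, sub)   # absolute path of the subdirectory
    with open(cfg, 'w') as f:
        json.dump(conf, f, indent=2)
    print('patched', cfg)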

Training

Compared with the reference project's code, mixed-precision training and checkpointing have been added. Training does not update the whole network; only the UNet is trained. The text encoder and the VAE's encoder take part in the forward pass, but their parameters are kept frozen, as sketched below.
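
One explicit way to freeze those two modules is sketched here (a hypothetical addition; the actual training script that follows achieves the same effect by passing only unet.parameters() to the optimizer and running the frozen modules under torch.no_grad()).

import torch
from text_encoder import text_encoder_pretrained
from vision_auto_encoder import vae_pretrained
from unet import unet_pretrained

text_encoder = text_encoder_pretrained()
vision_encoder = vae_pretrained()
unet = unet_pretrained()

# freeze everything except the UNet
for frozen in (text_encoder, vision_encoder):
    frozen.requires_grad_(False)
    frozen.eval()

# only the UNet's parameters are optimized
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-5)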

'''
    Training script train.py
'''
import os,torch
from diffusers import DiffusionPipeline
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from text_encoder import text_encoder_pretrained
from vision_auto_encoder import vae_pretrained
from unet import unet_pretrained

from torch.cuda.amp import autocast, GradScaler
from PIL import Image
import io
import warnings
warnings.filterwarnings('ignore')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# noise scheduler and tokenizer from the diffusers pipeline
pipeline = DiffusionPipeline.from_pretrained('./pretrained-params/', safety_checker=None, local_files_only=True)
scheduler = pipeline.scheduler
tokenizer = pipeline.tokenizer
del pipeline

print('Device :', device)
print('Scheduler settings: ', scheduler)
print('Tokenizer settings: ', tokenizer)

# data processing and loading
dataset = load_dataset('parquet',data_files={'train':'./data/train.parquet'}, split='train')
compose = transforms.Compose([
    transforms.Resize((512,512), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop((512,512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

def pair_text_image_process(data):
    # preprocess the image data
    pixel = [compose(Image.open(io.BytesIO(i['bytes']))) for i in data['image']]
    # tokenize the captions
    text = tokenizer.batch_encode_plus(data['text'], padding='max_length', truncation=True, max_length=77).input_ids
    return {'pixel_values': pixel, 'input_ids': text}

dataset = dataset.map(pair_text_image_process, batched=True, num_proc=1, remove_columns=['image', 'text'])
dataset.set_format(type='torch')

def collate_fn(data):
    pixel = [i['pixel_values'] for i in data]
    text = [i['input_ids'] for i in data]
    pixel = torch.stack(pixel).to(device)
    text = torch.stack(text).to(device)
    return {'pixel': pixel, 'text': text}

loader = DataLoader(dataset, shuffle=True, collate_fn=collate_fn, batch_size=1)

# load the models
text_encoder = text_encoder_pretrained()
vision_encoder = vae_pretrained()
unet = unet_pretrained()

text_encoder.eval()
vision_encoder.eval()
unet.train()

text_encoder.to(device)
vision_encoder.to(device)
unet.to(device)
# optimizer, loss function, mixed precision
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-5, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-8)
criterion = torch.nn.MSELoss().to(device)
scaler = GradScaler()
# train for one epoch
def train_one_epoch(unet, text_encoder, vision_encoder, train_loader, optimizer, criterion, noise_scheduler, scaler):
    loss_epoch = 0.
    for step, pair in enumerate(train_loader):
        img = pair['pixel']
        text = pair['text']
        with torch.no_grad():
            # text features
            text_out = text_encoder(text)
            # image latent features
            vision_out = vision_encoder.encoder(img)
            vision_out = vision_encoder.sample(vision_out)
            vision_out = vision_out * 0.18215  # Stable Diffusion latent scaling factor

        # add noise
        noise = torch.randn_like(vision_out)
        noise_step = torch.randint(0, 1000, (1,)).long().to(device)
        vision_out_noise = noise_scheduler.add_noise(vision_out, noise, noise_step)

        with autocast():
            noise_pred = unet(vision_out_noise, text_out, noise_step)
            loss = criterion(noise_pred, noise)

        loss_epoch += loss.item()
        scaler.scale(loss).backward()
        # unscale the gradients before clipping, then step the optimizer through the scaler
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        print(f'step: {step}  loss: {loss.item():.8f}')
    
    return loss_epoch
# checkpoint saving
def save_checkpoint(model, optimizer, epoch, loss, last=False):
    state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss
    }
    torch.save(state, f'/hy-tmp/checkpoints/checkpoint_{epoch}.pth.tar')
    if last:
        torch.save(state, '/hy-tmp/checkpoints/last_checkpoint.pth.tar')

epochs = 100
loss_recorder = []
print('start training ...')
for epoch in range(epochs):
    epoch_loss = train_one_epoch(unet, text_encoder, vision_encoder, loader, optimizer, criterion, scheduler, scaler)
    
    save_checkpoint(unet, optimizer, epoch, epoch_loss, True)
    loss_recorder.append((epoch, epoch_loss))
    loss_recorder = sorted(loss_recorder, key=lambda e:e[-1])
    if len(loss_recorder) > 10:
        del_check = loss_recorder.pop()
        os.remove(f'/hy-tmp/checkpoints/checkpoint_{del_check[0]}.pth.tar')
        
    print(f'epoch: {epoch:03}  loss: {epoch_loss:.8f}')

    if epoch % 1 == 0:
        print('Top 10 checkpoints:')
        for i in loss_recorder:
            print(i)

print('end training.')
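
For completeness, a checkpoint saved by save_checkpoint() could be restored like this before resuming training or before generation (a hypothetical snippet; the keys and path match the function above).

import torch
from unet import unet_pretrained

unet = unet_pretrained()
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-5)

# keys written by save_checkpoint(): state_dict, optimizer, epoch, loss
ckpt = torch.load('/hy-tmp/checkpoints/last_checkpoint.pth.tar', map_location='cpu')
unet.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch'] + 1
print('resuming from epoch', start_epoch, 'with recorded loss', ckpt['loss'])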

I trained for 100 epochs as a token run, and the loss is somewhat unstable. The experiment needs about 24 GB of GPU memory.

Figure 2. Training loss curve

Text-to-Image Generation

Image generation from text uses the text encoder, the UNet, and the VAE's decoder. The UNet takes the text encoder's outputs for the positive prompt and the negative (empty) prompt and predicts the noise distribution at each sampling step. Finally, the denoised latent is passed to the VAE decoder to produce an image.

'''
    Text-to-image generation test.py
'''
from diffusers import DiffusionPipeline
import torch
from text_encoder import text_encoder_pretrained
from vision_auto_encoder import vae_pretrained
from unet import unet_pretrained

from PIL import Image
import numpy as np
import warnings
warnings.filterwarnings('ignore')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# noise scheduler and tokenizer from the diffusers pipeline
pipeline = DiffusionPipeline.from_pretrained('./pretrained-params/', safety_checker=None, local_files_only=True)
scheduler = pipeline.scheduler
tokenizer = pipeline.tokenizer
del pipeline

print('Device :', device)
print('Scheduler settings: ', scheduler)
print('Tokenizer settings: ', tokenizer)

# load the models
text_encoder = text_encoder_pretrained()
vision_encoder = vae_pretrained()
unet = unet_pretrained()

text_encoder.eval()
vision_encoder.eval()
unet.eval()

text_encoder.to(device)
vision_encoder.to(device)
unet.to(device)

# generate an image from a text prompt
@torch.no_grad()
def generate(text, flag='gen_img'):
    # tokenize: (1, 77)
    pos = tokenizer(text, padding='max_length', max_length=77, 
                    truncation=True, return_tensors='pt').input_ids.to(device)
    neg = tokenizer('', padding='max_length', max_length=77,
                    truncation=True, return_tensors='pt').input_ids.to(device)
    
    pos_out = text_encoder(pos) # (1, 77, 768)
    neg_out = text_encoder(neg) # -
    text_out = torch.cat((neg_out, pos_out), dim=0) # (2, 77, 768)
    # pure-noise latent
    vae_out = torch.randn(1,4,64,64, device=device)
    # sampling timesteps
    scheduler.set_timesteps(50, device=device)

    for time in scheduler.timesteps:
        noise = torch.cat((vae_out, vae_out), dim=0)
        noise = scheduler.scale_model_input(noise, time)
        # predict the noise for the unconditional and conditional branches
        pred_noise = unet(vae_out=noise, text_out=text_out, time=time)
        # classifier-free guidance with guidance scale 7.5
        pred_noise = pred_noise[0] + 7.5 * (pred_noise[1] - pred_noise[0])
        # scheduler step: compute the previous, less-noisy latent
        vae_out = scheduler.step(pred_noise, time, vae_out).prev_sample
    
    # undo the latent scaling and decode back into an image
    vae_out = 1 / 0.18215 * vae_out
    image = vision_encoder.decoder(vae_out)
    # convert to uint8 and save
    image = image.cpu()
    image = (image + 1) / 2
    image = image.clamp(0, 1)
    image = image.permute(0, 2, 3, 1)
    image = image.numpy()[0]
    image = Image.fromarray(np.uint8(image*255))
    image.save(f'./output/{flag}.jpg')

texts = [
    'a drawing of a star with a jewel in the center',
    'a drawing of a woman in a red cape',
    'a drawing of a dragon sitting on its hind legs',
    'a drawing of a blue sea turtle holding a rock',
    'a blue and white bird with its wings spread',
    'a blue and white stuffed animal sitting on top of a white surface',
    'a teddy bear sitting on a desk',
]
images = []
for i,text in enumerate(texts):
    image = generate(text, f'gen_img{i}')
    print(f'text: {text}, finished')        

Generated Images

Figure 3. Images generated from the text prompts