Latent Diffusion 学习实例
参考
根据 github 项目 的学习笔记,网络结构与原作者项目相同,但是根据博主的习惯对代码进行了些许改动。原作者项目的链接如下:
网站 | 链接 |
---|---|
Github | https://github.com/lansinuote/Diffusion_From_Scratch |
Huggingface | https://huggingface.co/datasets/lansinuote/diffsion_from_scratch |
这篇博文的项目地址
网站 | 链接 |
---|---|
Github | https://github.com/MarcYugo/a-practice-example-latent-diffusion |
GitCode | https://gitcode.com/weixin_43385826/a-practice-example-latent-diffusion/ |
依赖
除了 Pytorch 外还需要下面几个包。
transformers==4.26.1
datasets==2.9.0
diffusers==0.12.1
整体结构
该神经网络模型包括三部分,抽取文本特征的文本编码器(Text Encoder)、抽取图像特征的图像编码器(Image Encoder)和 用于去噪的 UNet 网络(Diffusion Model),如 图1 所示。
文本编码器 Text Encoder
Text Encoder 采用带有多头注意力的 Transformer 结构并使用 CLIP 预训练模型的参数。具体结构包括:一个用于扩增维度和位置编码的 Embedding 层,12 个相同结构的带掩码的多头自注意力(Multi-Head Self-Attention, MHSA)层,以及最后一层 LayerNorm 层。
'''
用于提取文本特征的编码器
text_encoder.py
'''
import torch
from torch import nn
from transformers import CLIPTextModel
class Embed(nn.Module):
    """Token embedding plus learned position embedding, CLIP-style.

    Maps integer token ids of shape (b, n_tokens) to float vectors of
    shape (b, n_tokens, embed_dim).
    """

    def __init__(self, embed_dim=768, n_tokens=77, seq_len=49408):
        super().__init__()
        # vocabulary lookup table and learned positional table
        self.embed = nn.Embedding(seq_len, embed_dim)
        self.pos_embed = nn.Embedding(n_tokens, embed_dim)
        self.embed_dim = embed_dim
        self.n_tokens = n_tokens
        # position indices [0..n_tokens) kept as a buffer so they follow
        # the module across devices without being trained
        self.register_buffer('pos_ids', torch.arange(n_tokens).unsqueeze(0))

    def forward(self, input_ids):
        # input_ids: (b, n_tokens) integer token ids
        token_vec = self.embed(input_ids)
        return token_vec + self.pos_embed(self.pos_ids)
class SelfAttention(nn.Module):
    """Causal multi-head self-attention used inside the text encoder.

    Mirrors CLIP's text-transformer attention: queries are pre-scaled by
    1/sqrt(head_dim) and an additive causal mask prevents each token from
    attending to later positions.
    """

    def __init__(self, emb_dim=768, heads=12):
        super().__init__()
        self.wq = nn.Linear(emb_dim, emb_dim)
        self.wk = nn.Linear(emb_dim, emb_dim)
        self.wv = nn.Linear(emb_dim, emb_dim)
        self.out_proj = nn.Linear(emb_dim, emb_dim)
        self.emb_dim = emb_dim
        self.heads = heads

    def get_mask(self, b, n_tok):
        """Return a (b, 1, n_tok, n_tok) additive causal mask."""
        mask = torch.full((b, n_tok, n_tok), float('-inf'))
        mask.triu_(1)  # -inf strictly above the diagonal, 0 elsewhere
        # BUG FIX: the original computed `mask.triu_(1).unsqueeze(1)` but
        # discarded the unsqueeze result, so the mask stayed (b, n, n) and
        # failed to broadcast against the (b, heads, n, n) attention scores
        # whenever batch size > 1 (and != heads).
        return mask.unsqueeze(1)

    def forward(self, x):
        # x: (b, n_tok, emb_dim)
        b, n_tok, _ = x.shape
        head_dim = self.emb_dim // self.heads
        # scale queries by 1/sqrt(head_dim) (1/8 for the 768/12 default);
        # generalized from the original hard-coded division by 8
        q = self.wq(x) * head_dim ** -0.5
        k = self.wk(x)
        v = self.wv(x)

        def split_heads(t):
            # (b, n_tok, emb) -> (b*heads, n_tok, head_dim)
            return t.reshape(b, n_tok, self.heads, head_dim).transpose(1, 2).reshape(b * self.heads, n_tok, head_dim)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        # q·k^T relation matrix
        atten = torch.bmm(q, k.transpose(1, 2))
        atten = atten.reshape(b, self.heads, n_tok, n_tok)
        atten = atten + self.get_mask(b, n_tok).to(atten.device)
        atten = atten.reshape(b * self.heads, n_tok, n_tok)
        atten = atten.softmax(dim=-1)
        atten = torch.bmm(atten, v)  # (b*heads, n_tok, head_dim)
        # merge heads back: (b, n_tok, emb_dim)
        atten = atten.reshape(b, self.heads, n_tok, head_dim).transpose(1, 2).reshape(b, n_tok, self.emb_dim)
        return self.out_proj(atten)
class QuickGELU(nn.Module):
    """Fast GELU approximation used by CLIP: x * sigmoid(1.702 * x)."""

    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Two residual branches: LayerNorm -> masked self-attention, then
    LayerNorm -> MLP with QuickGELU activation.
    """

    def __init__(self, embed_dim=768, expand_dim=3072):
        super().__init__()
        self.seq1 = nn.Sequential(
            nn.LayerNorm(embed_dim),
            # FIX: forward embed_dim instead of always using the default, so
            # non-default widths actually work; identical for embed_dim=768.
            SelfAttention(embed_dim))
        self.seq2 = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, expand_dim),
            QuickGELU(),
            nn.Linear(expand_dim, embed_dim))

    def forward(self, x):
        # x: (b, n_tok, embed_dim); both branches are residual
        x = x + self.seq1(x)
        x = x + self.seq2(x)
        return x
class TextEncoder(nn.Module):
    """CLIP-style text encoder: Embed -> 12 transformer Blocks -> LayerNorm.

    NOTE: load_pretrained below indexes self.seq positionally (seq[0] is the
    Embed, seq[1..12] the Blocks, seq[13] the final LayerNorm).
    """

    def __init__(self, embed_dim=768):
        super().__init__()
        layers = [Embed()]
        layers.extend(Block() for _ in range(12))
        layers.append(nn.LayerNorm(embed_dim))
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        # x: (b, 77) token ids -> (b, 77, embed_dim) features
        return self.seq(x)
def load_pretrained(model):
    """Copy CLIP text-encoder weights from ./pretrained-params into `model`.

    Relies on TextEncoder's exact layout: seq[0] is Embed, seq[1..12] are
    Blocks (seq1 = [LayerNorm, SelfAttention], seq2 = [LayerNorm, fc1,
    QuickGELU, fc2]), seq[13] is the final LayerNorm.
    """
    params = CLIPTextModel.from_pretrained('./pretrained-params', subfolder='text_encoder')
    # token and position embedding tables
    model.seq[0].embed.load_state_dict(
        params.text_model.embeddings.token_embedding.state_dict())
    model.seq[0].pos_embed.load_state_dict(
        params.text_model.embeddings.position_embedding.state_dict())
    # 12 transformer layers: norms, q/k/v/out projections, MLP
    for i in range(12):
        model.seq[i+1].seq1[0].load_state_dict(
            params.text_model.encoder.layers[i].layer_norm1.state_dict())
        model.seq[i+1].seq1[1].wq.load_state_dict(
            params.text_model.encoder.layers[i].self_attn.q_proj.state_dict())
        model.seq[i+1].seq1[1].wk.load_state_dict(
            params.text_model.encoder.layers[i].self_attn.k_proj.state_dict())
        model.seq[i+1].seq1[1].wv.load_state_dict(
            params.text_model.encoder.layers[i].self_attn.v_proj.state_dict())
        model.seq[i+1].seq1[1].out_proj.load_state_dict(
            params.text_model.encoder.layers[i].self_attn.out_proj.state_dict())
        model.seq[i+1].seq2[0].load_state_dict(
            params.text_model.encoder.layers[i].layer_norm2.state_dict())
        model.seq[i+1].seq2[1].load_state_dict(
            params.text_model.encoder.layers[i].mlp.fc1.state_dict())
        model.seq[i+1].seq2[3].load_state_dict(
            params.text_model.encoder.layers[i].mlp.fc2.state_dict())
    # final LayerNorm
    model.seq[13].load_state_dict(params.text_model.final_layer_norm.state_dict())
    return model
def text_encoder_pretrained():
    """Build a TextEncoder and populate it with CLIP pretrained weights."""
    return load_pretrained(TextEncoder())
图像编码器 Image Encoder
Image Encoder 使用的模型是VAE(Variational Autoencoder),由 ResNet 块结构和 少量的 自注意力(Self-Attention)层堆叠构成 的一个Encoder 和一个 Decoder 组成。其中,所有标准化层和激活函数是 GroupNorm 标准化层 和 SiLU 激活函数。
'''
用于提取视觉特征的编码器 VAE
vision_auto_encoder.py
'''
import torch
from torch import nn
from diffusers import AutoencoderKL
class ResNetBlock(nn.Module):
    """VAE residual block: two GroupNorm->SiLU->Conv3x3 stages plus shortcut.

    NOTE: load_res below indexes self.seq positionally (0=norm1, 2=conv1,
    3=norm2, 5=conv2) — keep the Sequential layout unchanged.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.seq = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=dim_in, eps=1e-6, affine=True),
            nn.SiLU(),
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups=32, num_channels=dim_out, eps=1e-6, affine=True),
            nn.SiLU(),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1),
        )
        # 1x1 projection on the shortcut only when the channel count changes
        if dim_in != dim_out:
            self.resil = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1)
        else:
            self.resil = nn.Identity()
        self.dim_in = dim_in
        self.dim_out = dim_out

    def forward(self, x):
        shortcut = self.resil(x)
        return self.seq(x) + shortcut
class SelfAttention(nn.Module):
    """Single-head spatial self-attention over feature maps (VAE mid-block).

    Flattens the (h, w) grid into a token sequence, applies scaled
    dot-product attention, and adds the result back to the input.
    """

    def __init__(self, embed_dim=512):
        super().__init__()
        # GENERALIZED: the original hard-coded 512 for the norm channels and
        # the attention scale regardless of embed_dim; behavior is identical
        # for the default embed_dim=512 and now correct for other widths.
        self.norm = nn.GroupNorm(num_channels=embed_dim, num_groups=32, eps=1e-6, affine=True)
        self.wq = torch.nn.Linear(embed_dim, embed_dim)
        self.wk = torch.nn.Linear(embed_dim, embed_dim)
        self.wv = torch.nn.Linear(embed_dim, embed_dim)
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.embed_dim = embed_dim

    def forward(self, x):
        # x: (b, c, h, w) feature map with c == embed_dim
        res = x
        b, c, h, w = x.shape
        x = self.norm(x)
        # (b, c, h, w) -> (b, h*w, c): one token per spatial location
        x = x.flatten(start_dim=2).transpose(1, 2)
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)
        k = k.transpose(1, 2)
        # scaled dot-product attention, scale = 1/sqrt(embed_dim)
        atten = q.bmm(k) / self.embed_dim ** 0.5
        atten = torch.softmax(atten, dim=2)
        atten = atten.bmm(v)
        atten = self.out_proj(atten)
        # back to (b, c, h, w) and residual add
        atten = atten.transpose(1, 2).reshape(b, c, h, w)
        return atten + res
class Pad(nn.Module):
    """Zero-pad one column on the right and one row at the bottom.

    Used before the stride-2 downsampling convs so they mimic diffusers'
    asymmetric padding.
    """

    def forward(self, x):
        # pad spec (left, right, top, bottom) = (0, 1, 0, 1)
        return nn.functional.pad(x, (0, 1, 0, 1), mode='constant', value=0)
class VAE(nn.Module):
    """Convolutional VAE mapping 3x512x512 images to 4x64x64 latents and back.

    NOTE: the nested Sequential indices (e.g. encoder[i+1][2][1], encoder[7],
    decoder[i+3][4]) are relied upon by load_pretrained below when copying
    weights from diffusers' AutoencoderKL — keep this layout unchanged.
    """
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # in: 3 -> 128 channels at full resolution
            nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1),
            # down path: 4 stages; the first 3 end with a stride-2 conv
            nn.Sequential(
                ResNetBlock(128, 128),
                ResNetBlock(128, 128),
                nn.Sequential(
                    Pad(),
                    torch.nn.Conv2d(128, 128, 3, stride=2, padding=0),
                ),
            ),
            torch.nn.Sequential(
                ResNetBlock(128, 256),
                ResNetBlock(256, 256),
                torch.nn.Sequential(
                    Pad(),
                    nn.Conv2d(256, 256, 3, stride=2, padding=0),
                ),
            ),
            nn.Sequential(
                ResNetBlock(256, 512),
                ResNetBlock(512, 512),
                nn.Sequential(
                    Pad(),
                    nn.Conv2d(512, 512, 3, stride=2, padding=0),
                ),
            ),
            nn.Sequential(
                ResNetBlock(512, 512),
                ResNetBlock(512, 512),
            ),
            # mid: ResNet blocks around one spatial self-attention layer
            nn.Sequential(
                ResNetBlock(512, 512),
                SelfAttention(),
                ResNetBlock(512, 512),
            ),
            # out: project to 8 channels (4 mean + 4 log-variance)
            nn.Sequential(
                nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6),
                nn.SiLU(),
                nn.Conv2d(512, 8, 3, padding=1),
            ),
            # distribution layer (diffusers' quant_conv)
            nn.Conv2d(8, 8, 1))
        self.decoder = nn.Sequential(
            # distribution layer (diffusers' post_quant_conv)
            nn.Conv2d(4, 4, 1),
            # in: 4 -> 512 channels
            nn.Conv2d(4, 512, kernel_size=3, stride=1, padding=1),
            # middle: ResNet blocks around one self-attention layer
            nn.Sequential(ResNetBlock(512, 512),
                          SelfAttention(),
                          ResNetBlock(512, 512)),
            # up path: 4 stages; the first 3 end with 2x nearest upsampling
            nn.Sequential(
                ResNetBlock(512, 512),
                ResNetBlock(512, 512),
                ResNetBlock(512, 512),
                nn.Upsample(scale_factor=2.0, mode='nearest'),
                nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ),
            nn.Sequential(
                ResNetBlock(512, 512),
                ResNetBlock(512, 512),
                ResNetBlock(512, 512),
                nn.Upsample(scale_factor=2.0, mode='nearest'),
                nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ),
            nn.Sequential(
                ResNetBlock(512, 256),
                ResNetBlock(256, 256),
                ResNetBlock(256, 256),
                nn.Upsample(scale_factor=2.0, mode='nearest'),
                nn.Conv2d(256, 256, kernel_size=3, padding=1),
            ),
            nn.Sequential(
                ResNetBlock(256, 128),
                ResNetBlock(128, 128),
                ResNetBlock(128, 128),
            ),
            # out: back to 3 image channels
            nn.Sequential(
                nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6),
                nn.SiLU(),
                nn.Conv2d(128, 3, 3, padding=1),
            ))
    def sample(self, h):
        """Reparameterized sample from the encoder output.

        h: (b, 8, 64, 64) — first 4 channels are the mean, last 4 the
        log-variance of a diagonal Gaussian. Returns (b, 4, 64, 64).
        """
        mean = h[:, :4]
        logvar = h[:, 4:]
        # std = exp(logvar)**0.5 == exp(logvar / 2)
        std = logvar.exp()**0.5
        # z = mean + std * eps, eps ~ N(0, I)
        h = torch.randn(mean.shape, device=mean.device)
        h = mean + std * h
        return h
    def forward(self, x):
        """Encode, sample a latent, and decode back to an image.

        x: (b, 3, 512, 512) -> returns (b, 3, 512, 512).
        """
        # (b, 3, 512, 512) -> (b, 8, 64, 64)
        h = self.encoder(x)
        # (b, 8, 64, 64) -> (b, 4, 64, 64)
        h = self.sample(h)
        # (b, 4, 64, 64) -> (b, 3, 512, 512)
        h = self.decoder(h)
        return h
def load_pretrained(model):
    """Copy diffusers AutoencoderKL weights from ./pretrained-params into `model`.

    Depends on the exact nested-Sequential layout of VAE above; the positional
    indices here must match that definition.
    """
    params = AutoencoderKL.from_pretrained('./pretrained-params/', subfolder='vae')
    model.encoder[0].load_state_dict(params.encoder.conv_in.state_dict())
    # encoder.down: 4 stages; only the first 3 have a downsampling conv
    for i in range(4):
        load_res(model.encoder[i + 1][0], params.encoder.down_blocks[i].resnets[0])
        load_res(model.encoder[i + 1][1], params.encoder.down_blocks[i].resnets[1])
        if i != 3:
            model.encoder[i + 1][2][1].load_state_dict(
                params.encoder.down_blocks[i].downsamplers[0].conv.state_dict())
    # encoder.mid: resnet / attention / resnet
    load_res(model.encoder[5][0], params.encoder.mid_block.resnets[0])
    load_res(model.encoder[5][2], params.encoder.mid_block.resnets[1])
    load_atten(model.encoder[5][1], params.encoder.mid_block.attentions[0])
    # encoder.out
    model.encoder[6][0].load_state_dict(params.encoder.conv_norm_out.state_dict())
    model.encoder[6][2].load_state_dict(params.encoder.conv_out.state_dict())
    # encoder distribution layer (quant_conv)
    model.encoder[7].load_state_dict(params.quant_conv.state_dict())
    # decoder distribution layer (post_quant_conv)
    model.decoder[0].load_state_dict(params.post_quant_conv.state_dict())
    # decoder.in
    model.decoder[1].load_state_dict(params.decoder.conv_in.state_dict())
    # decoder.mid
    load_res(model.decoder[2][0], params.decoder.mid_block.resnets[0])
    load_res(model.decoder[2][2], params.decoder.mid_block.resnets[1])
    load_atten(model.decoder[2][1], params.decoder.mid_block.attentions[0])
    # decoder.up: 4 stages; only the first 3 have an upsampling conv
    for i in range(4):
        load_res(model.decoder[i + 3][0], params.decoder.up_blocks[i].resnets[0])
        load_res(model.decoder[i + 3][1], params.decoder.up_blocks[i].resnets[1])
        load_res(model.decoder[i + 3][2], params.decoder.up_blocks[i].resnets[2])
        if i != 3:
            model.decoder[i + 3][4].load_state_dict(
                params.decoder.up_blocks[i].upsamplers[0].conv.state_dict())
    # decoder.out
    model.decoder[7][0].load_state_dict(params.decoder.conv_norm_out.state_dict())
    model.decoder[7][2].load_state_dict(params.decoder.conv_out.state_dict())
    return model
def load_res(model, param):
    """Copy weights from a diffusers ResnetBlock2D into our VAE ResNetBlock."""
    # positional mapping into ResNetBlock.seq: 0=norm1, 2=conv1, 3=norm2, 5=conv2
    pairs = [
        (model.seq[0], param.norm1),
        (model.seq[2], param.conv1),
        (model.seq[3], param.norm2),
        (model.seq[5], param.conv2),
    ]
    for dst, src in pairs:
        dst.load_state_dict(src.state_dict())
    # the 1x1 shortcut exists only when in/out channel counts differ
    if isinstance(model.resil, nn.Conv2d):
        model.resil.load_state_dict(param.conv_shortcut.state_dict())
def load_atten(model, param):
    """Copy weights from a diffusers attention block into our SelfAttention."""
    pairs = [
        (model.norm, param.group_norm),
        (model.wq, param.to_q),
        (model.wk, param.to_k),
        (model.wv, param.to_v),
        # diffusers wraps the output projection in a Sequential; index 0 is it
        (model.out_proj, param.to_out[0]),
    ]
    for dst, src in pairs:
        dst.load_state_dict(src.state_dict())
def vae_pretrained():
    """Build a VAE and populate it with the pretrained AutoencoderKL weights."""
    return load_pretrained(VAE())
去噪网络 Diffusion Model
Diffusion Model 使用 ResNet 块和 Transformer 块交替构成一个 UNet 网络。其中,Transformer 块中使用的是交叉注意力(Cross-Attention)机制,用来把文本特征注入到图像特征中。
'''
用于去噪的神经网络 Diffusion Model
unet.py
'''
import torch
from torch import nn
from diffusers import UNet2DConditionModel
class ResNetBlock(nn.Module):
    """Time-conditioned residual block for the UNet.

    The time embedding is projected to (dim_out, 1, 1) and added to the
    feature map between the two conv stages.

    NOTE: load_res_block below indexes time_emb[1], seq1[0], seq1[2],
    seq2[0], seq2[2] — keep the Sequential layouts unchanged.
    """

    def __init__(self, dim_in, dim_out, time_emb_dim=1280):
        super().__init__()
        # project the time vector so it broadcasts over the spatial dims
        self.time_emb = nn.Sequential(
            nn.SiLU(),
            torch.nn.Linear(time_emb_dim, dim_out),
            nn.Unflatten(dim=1, unflattened_size=(dim_out, 1, 1)),
        )
        self.seq1 = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=dim_in, eps=1e-5, affine=True),
            nn.SiLU(),
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1),
        )
        self.seq2 = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=dim_out, eps=1e-5, affine=True),
            nn.SiLU(),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1),
        )
        # 1x1 projection on the shortcut only when channel counts differ
        if dim_in != dim_out:
            self.resil = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1)
        else:
            self.resil = nn.Identity()

    def forward(self, x, time):
        # x: (b, dim_in, h, w); time: (b, time_emb_dim)
        shortcut = self.resil(x)
        t = self.time_emb(time)        # (b, dim_out, 1, 1)
        h = self.seq1(x) + t
        h = self.seq2(h)
        return shortcut + h
class CrossAttention(nn.Module):
    """Multi-head attention with separate query and key/value inputs.

    Queries come from image features (dim_q); keys/values from text features
    (dim_kv). With dim_kv == dim_q it doubles as self-attention.
    """

    def __init__(self, dim_q=320, dim_kv=768, heads=8):
        super().__init__()
        self.dim_q = dim_q
        self.heads = heads
        self.wq = nn.Linear(dim_q, dim_q, bias=False)
        self.wk = nn.Linear(dim_kv, dim_q, bias=False)
        self.wv = nn.Linear(dim_kv, dim_q, bias=False)
        self.out_proj = nn.Linear(dim_q, dim_q)

    def multihead_reshape(self, x):
        """(b, lens, dim) -> (b*heads, lens, dim/heads)."""
        b, lens, dim = x.shape
        head_dim = dim // self.heads
        x = x.reshape(b, lens, self.heads, head_dim).transpose(1, 2)
        return x.reshape(b * self.heads, lens, head_dim)

    def multihead_reshape_inverse(self, x):
        """(b*heads, lens, head_dim) -> (b, lens, head_dim*heads)."""
        bh, lens, head_dim = x.shape
        b = bh // self.heads
        x = x.reshape(b, self.heads, lens, head_dim).transpose(1, 2)
        return x.reshape(b, lens, head_dim * self.heads)

    def forward(self, q, kv):
        # q: (b, lens_q, dim_q); kv: (b, lens_kv, dim_kv)
        q = self.multihead_reshape(self.wq(q))
        k = self.multihead_reshape(self.wk(kv))
        v = self.multihead_reshape(self.wv(kv))
        # scaled dot-product attention, scale = 1/sqrt(head_dim)
        scale = (self.dim_q // self.heads) ** -0.5
        atten = torch.softmax(q.bmm(k.transpose(1, 2)) * scale, dim=-1)
        atten = self.multihead_reshape_inverse(atten.bmm(v))
        return self.out_proj(atten)
class TransformerBlock(nn.Module):
    """Spatial transformer for the UNet: self-attention, text cross-attention,
    and a GEGLU feed-forward, all on a flattened (h*w) token sequence.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # projection in/out of the token sequence
        self.norm_in = nn.GroupNorm(num_groups=32, num_channels=dim,
                                    eps=1e-6, affine=True)
        self.cnn_in = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)
        # attention: self (dim->dim), then cross against 768-d text features
        self.norm_atten0 = nn.LayerNorm(dim, elementwise_affine=True)
        self.atten1 = CrossAttention(dim, dim)
        self.norm_atten1 = nn.LayerNorm(dim, elementwise_affine=True)
        self.atten2 = CrossAttention(dim, 768)
        # GEGLU feed-forward: fc0 produces 2*(dim*4) for the gated product
        self.norm_act = nn.LayerNorm(dim, elementwise_affine=True)
        self.fc0 = nn.Linear(dim, dim * 8)
        self.act = nn.GELU()
        self.fc1 = nn.Linear(dim * 4, dim)
        # out
        self.cnn_out = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)

    def forward(self, q, kv):
        # q: (b, dim, h, w) image features; kv: (b, 77, 768) text features
        b, _, h, w = q.shape
        res_outer = q
        # project in and flatten the spatial grid into a token sequence
        x = self.cnn_in(self.norm_in(q))
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, self.dim)
        # pre-norm residual self-attention, then text cross-attention
        normed = self.norm_atten0(x)
        x = self.atten1(q=normed, kv=normed) + x
        x = self.atten2(q=self.norm_atten1(x), kv=kv) + x
        # GEGLU feed-forward: split the fc0 output in half, gate one half
        res_inner = x
        gated = self.fc0(self.norm_act(x))
        half = gated.shape[2] // 2
        x = gated[:, :, :half] * self.act(gated[:, :, half:])
        x = self.fc1(x) + res_inner
        # back to (b, dim, h, w) and project out with the outer residual
        x = x.reshape(b, h, w, self.dim).permute(0, 3, 1, 2).contiguous()
        return self.cnn_out(x) + res_outer
class DownBlock(torch.nn.Module):
    """UNet encoder stage: two (ResNet, Transformer) pairs then a stride-2 conv.

    Returns the downsampled features plus the three intermediate outputs used
    later as skip connections.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.trans_block1 = TransformerBlock(dim_out)
        self.res_block1 = ResNetBlock(dim_in, dim_out)
        self.trans_block2 = TransformerBlock(dim_out)
        self.res_block2 = ResNetBlock(dim_out, dim_out)
        # stride-2 conv halves the spatial resolution
        self.out = nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=2, padding=1)

    def forward(self, vae_out, text_out, time):
        skips = []
        x = self.res_block1(vae_out, time)
        x = self.trans_block1(x, text_out)
        skips.append(x)
        x = self.res_block2(x, time)
        x = self.trans_block2(x, text_out)
        skips.append(x)
        x = self.out(x)
        skips.append(x)
        return x, skips
class UpBlock(nn.Module):
    """UNet decoder stage: three (ResNet, Transformer) pairs over concatenated
    skip connections, optionally followed by 2x upsampling.
    """

    def __init__(self, dim_in, dim_out, dim_prev, add_up):
        super().__init__()
        # each ResNet consumes [current features ++ popped skip connection]
        self.res_block1 = ResNetBlock(dim_out + dim_prev, dim_out)
        self.res_block2 = ResNetBlock(dim_out + dim_out, dim_out)
        self.res_block3 = ResNetBlock(dim_in + dim_out, dim_out)
        self.trans_block1 = TransformerBlock(dim_out)
        self.trans_block2 = TransformerBlock(dim_out)
        self.trans_block3 = TransformerBlock(dim_out)
        # NOTE: load_up_block below indexes self.out[1] when add_up is True
        self.out = torch.nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1),
        ) if add_up else nn.Identity()

    def forward(self, vae_out, text_out, time, out_down):
        # NOTE: pops from (and thus consumes) the caller's skip-connection list
        x = vae_out
        stages = ((self.res_block1, self.trans_block1),
                  (self.res_block2, self.trans_block2),
                  (self.res_block3, self.trans_block3))
        for res_block, trans_block in stages:
            x = res_block(torch.cat([x, out_down.pop()], dim=1), time)
            x = trans_block(x, text_out)
        return self.out(x)
class UNet(torch.nn.Module):
    """Text-conditioned denoising UNet over 4x64x64 VAE latents.

    NOTE: load_pretrained below indexes in_time[0], in_time[2], up_in[1],
    out[0], out[2] positionally — keep the Sequential layouts unchanged.
    """
    def __init__(self):
        super().__init__()
        # in: latent 4 -> 320 channels; time MLP 320 -> 1280
        self.in_vae = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
        self.in_time = torch.nn.Sequential(
            torch.nn.Linear(320, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280),
        )
        # down path: 3 DownBlocks (each halves resolution) + 2 ResNet blocks
        self.down_block1 = DownBlock(320, 320)
        self.down_block2 = DownBlock(320, 640)
        self.down_block3 = DownBlock(640, 1280)
        self.down_res1 = ResNetBlock(1280, 1280)
        self.down_res2 = ResNetBlock(1280, 1280)
        # mid: resnet / transformer / resnet at the lowest resolution
        self.mid_res1 = ResNetBlock(1280, 1280)
        self.mid_trans = TransformerBlock(1280)
        self.mid_res2 = ResNetBlock(1280, 1280)
        # up path: 3 ResNet blocks over skips, an upsampler, then 3 UpBlocks
        self.up_res1 = ResNetBlock(2560, 1280)
        self.up_res2 = ResNetBlock(2560, 1280)
        self.up_res3 = ResNetBlock(2560, 1280)
        self.up_in = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2, mode='nearest'),
            torch.nn.Conv2d(1280, 1280, kernel_size=3, padding=1),
        )
        self.up_block1 = UpBlock(640, 1280, 1280, True)
        self.up_block2 = UpBlock(320, 640, 1280, True)
        self.up_block3 = UpBlock(320, 320, 640, False)
        # out: back to a 4-channel latent noise prediction
        self.out = torch.nn.Sequential(
            torch.nn.GroupNorm(num_channels=320, num_groups=32, eps=1e-5),
            torch.nn.SiLU(),
            torch.nn.Conv2d(320, 4, kernel_size=3, padding=1),
        )
    def get_time_embed(self, t):
        """Sinusoidal timestep embedding: scalar/[1] tensor t -> (1, 320)."""
        # -9.210340371976184 = -math.log(10000)
        e = torch.arange(160) * -9.210340371976184 / 160
        e = e.exp().to(t.device) * t
        # [160 cos] ++ [160 sin] -> [320] -> (1, 320)
        e = torch.cat([e.cos(), e.sin()]).unsqueeze(dim=0)
        return e
    def forward(self, vae_out, text_out, time):
        """Predict the noise in `vae_out`.

        vae_out: (b, 4, 64, 64) noised latents
        text_out: (b, 77, 768) text-encoder features
        time: scalar / shape-[1] timestep tensor
        returns: (b, 4, 64, 64) predicted noise
        """
        # ----in----
        # (b, 4, 64, 64) -> (b, 320, 64, 64)
        vae_out = self.in_vae(vae_out)
        # timestep -> (1, 320) -> (1, 1280)
        time = self.get_time_embed(time)
        time = self.in_time(time)
        # ----down---- (collect skip connections for the up path)
        out_down = [vae_out]
        # 64x64 -> 32x32; skips: two 320x64x64, one 320x32x32
        vae_out, out = self.down_block1(vae_out=vae_out,
                                        text_out=text_out,
                                        time=time)
        out_down.extend(out)
        # 32x32 -> 16x16; skips: two 640x32x32, one 640x16x16
        vae_out, out = self.down_block2(vae_out=vae_out,
                                        text_out=text_out,
                                        time=time)
        out_down.extend(out)
        # 16x16 -> 8x8; skips: two 1280x16x16, one 1280x8x8
        vae_out, out = self.down_block3(vae_out=vae_out,
                                        text_out=text_out,
                                        time=time)
        out_down.extend(out)
        # two more ResNet blocks at 8x8, each also recorded as a skip
        vae_out = self.down_res1(vae_out, time)
        out_down.append(vae_out)
        vae_out = self.down_res2(vae_out, time)
        out_down.append(vae_out)
        # ----mid---- (all at 1280x8x8)
        vae_out = self.mid_res1(vae_out, time)
        vae_out = self.mid_trans(vae_out, text_out)
        vae_out = self.mid_res2(vae_out, time)
        # ----up---- (skips are popped in reverse order of collection)
        vae_out = self.up_res1(torch.cat([vae_out, out_down.pop()], dim=1),
                               time)
        vae_out = self.up_res2(torch.cat([vae_out, out_down.pop()], dim=1),
                               time)
        vae_out = self.up_res3(torch.cat([vae_out, out_down.pop()], dim=1),
                               time)
        # 8x8 -> 16x16
        vae_out = self.up_in(vae_out)
        # 16x16 -> 32x32
        vae_out = self.up_block1(vae_out=vae_out,
                                 text_out=text_out,
                                 time=time,
                                 out_down=out_down)
        # 32x32 -> 64x64
        vae_out = self.up_block2(vae_out=vae_out,
                                 text_out=text_out,
                                 time=time,
                                 out_down=out_down)
        # stays 64x64 (no upsampler in the last block)
        vae_out = self.up_block3(vae_out=vae_out,
                                 text_out=text_out,
                                 time=time,
                                 out_down=out_down)
        # ----out----
        # (b, 320, 64, 64) -> (b, 4, 64, 64)
        vae_out = self.out(vae_out)
        return vae_out
# Load pretrained weights for the whole UNet.
def load_pretrained(model):
    """Copy weights from diffusers' UNet2DConditionModel into our UNet.

    Depends on the attribute names and Sequential indices defined in UNet.
    """
    params = UNet2DConditionModel.from_pretrained('./pretrained-params/', subfolder='unet')
    # in
    model.in_vae.load_state_dict(params.conv_in.state_dict())
    model.in_time[0].load_state_dict(params.time_embedding.linear_1.state_dict())
    model.in_time[2].load_state_dict(params.time_embedding.linear_2.state_dict())
    # down
    load_down_block(model.down_block1, params.down_blocks[0])
    load_down_block(model.down_block2, params.down_blocks[1])
    load_down_block(model.down_block3, params.down_blocks[2])
    load_res_block(model.down_res1, params.down_blocks[3].resnets[0])
    load_res_block(model.down_res2, params.down_blocks[3].resnets[1])
    # mid
    load_transformer_block(model.mid_trans, params.mid_block.attentions[0])
    load_res_block(model.mid_res1, params.mid_block.resnets[0])
    load_res_block(model.mid_res2, params.mid_block.resnets[1])
    # up
    load_res_block(model.up_res1, params.up_blocks[0].resnets[0])
    load_res_block(model.up_res2, params.up_blocks[0].resnets[1])
    load_res_block(model.up_res3, params.up_blocks[0].resnets[2])
    model.up_in[1].load_state_dict(
        params.up_blocks[0].upsamplers[0].conv.state_dict())
    load_up_block(model.up_block1, params.up_blocks[1])
    load_up_block(model.up_block2, params.up_blocks[2])
    load_up_block(model.up_block3, params.up_blocks[3])
    # out
    model.out[0].load_state_dict(params.conv_norm_out.state_dict())
    model.out[2].load_state_dict(params.conv_out.state_dict())
    return model
# Load weights for one transformer block.
def load_transformer_block(model: TransformerBlock, param):
    """Copy a diffusers Transformer2DModel's weights into our TransformerBlock.

    attn1 is self-attention, attn2 is text cross-attention; ff.net[0] is the
    GEGLU input projection (-> fc0) and ff.net[2] the output linear (-> fc1).
    """
    model.norm_in.load_state_dict(param.norm.state_dict())
    model.cnn_in.load_state_dict(param.proj_in.state_dict())
    model.atten1.wq.load_state_dict(
        param.transformer_blocks[0].attn1.to_q.state_dict())
    model.atten1.wk.load_state_dict(
        param.transformer_blocks[0].attn1.to_k.state_dict())
    model.atten1.wv.load_state_dict(
        param.transformer_blocks[0].attn1.to_v.state_dict())
    model.atten1.out_proj.load_state_dict(
        param.transformer_blocks[0].attn1.to_out[0].state_dict())
    model.atten2.wq.load_state_dict(
        param.transformer_blocks[0].attn2.to_q.state_dict())
    model.atten2.wk.load_state_dict(
        param.transformer_blocks[0].attn2.to_k.state_dict())
    model.atten2.wv.load_state_dict(
        param.transformer_blocks[0].attn2.to_v.state_dict())
    model.atten2.out_proj.load_state_dict(
        param.transformer_blocks[0].attn2.to_out[0].state_dict())
    model.fc0.load_state_dict(
        param.transformer_blocks[0].ff.net[0].proj.state_dict())
    model.fc1.load_state_dict(
        param.transformer_blocks[0].ff.net[2].state_dict())
    model.norm_atten0.load_state_dict(
        param.transformer_blocks[0].norm1.state_dict())
    model.norm_atten1.load_state_dict(
        param.transformer_blocks[0].norm2.state_dict())
    model.norm_act.load_state_dict(
        param.transformer_blocks[0].norm3.state_dict())
    model.cnn_out.load_state_dict(param.proj_out.state_dict())
# Load weights for one ResNet block.
def load_res_block(model: ResNetBlock, param):
    """Copy a diffusers ResnetBlock2D's weights into our UNet ResNetBlock."""
    model.time_emb[1].load_state_dict(param.time_emb_proj.state_dict())
    model.seq1[0].load_state_dict(param.norm1.state_dict())
    model.seq1[2].load_state_dict(param.conv1.state_dict())
    model.seq2[0].load_state_dict(param.norm2.state_dict())
    model.seq2[2].load_state_dict(param.conv2.state_dict())
    # the 1x1 shortcut exists only when in/out channel counts differ
    if isinstance(model.resil, nn.Conv2d):
        model.resil.load_state_dict(param.conv_shortcut.state_dict())
# Load weights for one downsampling block.
def load_down_block(model: DownBlock, param):
    """Copy a diffusers CrossAttnDownBlock2D's weights into our DownBlock."""
    load_transformer_block(model.trans_block1, param.attentions[0])
    load_transformer_block(model.trans_block2, param.attentions[1])
    load_res_block(model.res_block1, param.resnets[0])
    load_res_block(model.res_block2, param.resnets[1])
    model.out.load_state_dict(param.downsamplers[0].conv.state_dict())
# Load weights for one upsampling block.
def load_up_block(model: UpBlock, param):
    """Copy a diffusers CrossAttnUpBlock2D's weights into our UpBlock."""
    load_transformer_block(model.trans_block1, param.attentions[0])
    load_transformer_block(model.trans_block2, param.attentions[1])
    load_transformer_block(model.trans_block3, param.attentions[2])
    load_res_block(model.res_block1, param.resnets[0])
    load_res_block(model.res_block2, param.resnets[1])
    load_res_block(model.res_block3, param.resnets[2])
    # the last UpBlock has no upsampler (model.out is an Identity there)
    if isinstance(model.out, nn.Sequential):
        model.out[1].load_state_dict(param.upsamplers[0].conv.state_dict())
def unet_pretrained():
    """Build a UNet and populate it with the pretrained SD UNet weights."""
    return load_pretrained(UNet())
训练和生成图片
项目结构
工作区目录为 latent_diffusion。在子目录中,data 存放训练数据,output 存放生成的图片,pretrained-params 中的是从参考项目的 hugging face 中下载的训练权重文件和配置文件。
# project.log
latent_diffusion
├── data
│ └── train.parquet
├── output
├── pretrained-params
│ ├── feature_extractor
│ │ └── preprocessor_config.json
│ ├── model_index.json
│ ├── scheduler
│ │ └── scheduler_config.json
│ ├── text_encoder
│ │ ├── config.json
│ │ └── pytorch_model.bin
│ ├── tokenizer
│ │ ├── merges.txt
│ │ ├── special_tokens_map.json
│ │ ├── tokenizer_config.json
│ │ └── vocab.json
│ ├── unet
│ │ ├── config.json
│ │ └── diffusion_pytorch_model.bin
│ └── vae
│ ├── config.json
│ └── diffusion_pytorch_model.bin
├── project.log
├── test.py
├── text_encoder.py
├── train.py
├── train_record.log
├── unet.py
├── utils.py
└── vision_auto_encoder.py
10 directories, 25 files
上述目录结构中的数据文件和权重文件(例如 pytorch_model.bin)请到 参考的 Hugging Face项目中下载 数据链接 和 权重链接。
为了让 diffusers 在不联网的情况下直接加载本地训练参数,有几个文件需要改动:把所有 config.json 中键 "_name_or_path" 的值全部修改为对应子目录的绝对路径。
训练
与参考项目的代码相比,增加了混合精度和检查点。训练时并不训练全部的网络,仅训练 UNet 网络部分。文本编码器和VAE的编码器部分参与到了训练过程,但是这两个部分的参数都是冻结的。
'''
训练代码 train.py
'''
import os,torch
from diffusers import DiffusionPipeline
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from text_encoder import text_encoder_pretrained
from vision_auto_encoder import vae_pretrained
from unet import unet_pretrained
from torch.cuda.amp import autocast, GradScaler
from PIL import Image
import io
import warnings
warnings.filterwarnings('ignore')
# Run on GPU when available; everything below is moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the local Stable-Diffusion pipeline only to extract its noise
# scheduler and tokenizer; the heavy model weights are loaded separately.
pipeline = DiffusionPipeline.from_pretrained('./pretrained-params/', safety_checker=None, local_files_only=True)
scheduler = pipeline.scheduler
tokenizer = pipeline.tokenizer
del pipeline
print('Device :', device)
print('Scheduler settings: ', scheduler)
print('Tokenizer settings: ', tokenizer)
# Training data: one parquet file with 'image' (raw bytes) and 'text' columns.
dataset = load_dataset('parquet',data_files={'train':'./data/train.parquet'}, split='train')
# 512x512 RGB tensors normalized to [-1, 1] (mean 0.5, std 0.5 per channel).
compose = transforms.Compose([
    transforms.Resize((512,512), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop((512,512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
def pair_text_image_process(data):
    """Dataset.map hook: decode + transform images, tokenize captions.

    Returns 'pixel_values' (transformed image tensors) and 'input_ids'
    (token ids padded/truncated to 77).
    """
    pixels = [compose(Image.open(io.BytesIO(item['bytes'])))
              for item in data['image']]
    ids = tokenizer.batch_encode_plus(data['text'], padding='max_length',
                                      truncation=True, max_length=77).input_ids
    return {'pixel_values': pixels, 'input_ids': ids}
# Preprocess the whole dataset up front, replacing the raw 'image'/'text'
# columns with 'pixel_values'/'input_ids', then expose torch tensors.
dataset = dataset.map(pair_text_image_process, batched=True, num_proc=1, remove_columns=['image', 'text'])
dataset.set_format(type='torch')
def collate_fn(data):
    """Stack per-sample tensors into device-resident batch tensors."""
    pixel = torch.stack([item['pixel_values'] for item in data]).to(device)
    text = torch.stack([item['input_ids'] for item in data]).to(device)
    return {'pixel': pixel, 'text': text}
# batch_size=1: the post notes training already needs ~24 GB of VRAM.
loader = DataLoader(dataset, shuffle=True, collate_fn=collate_fn, batch_size=1)
# Build all three sub-models with their pretrained weights.
text_encoder = text_encoder_pretrained()
vision_encoder = vae_pretrained()
unet = unet_pretrained()
# Only the UNet is trained; the encoders stay in eval mode and are used
# under no_grad in the training loop (their params are not in the optimizer).
text_encoder.eval()
vision_encoder.eval()
unet.train()
text_encoder.to(device)
vision_encoder.to(device)
unet.to(device)
# AdamW over UNet parameters only; GradScaler drives mixed-precision training.
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-5, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-8)
criterion = torch.nn.MSELoss().to(device)
scaler = GradScaler()
# Train for one epoch.
def train_one_epoch(unet, text_encoder, vision_encoder, train_loader, optimizer, criterion, noise_scheduler, scaler):
    """Run one epoch of noise-prediction training for the UNet.

    The text encoder and VAE encoder run under no_grad (frozen feature
    extractors); only the UNet receives gradients. Uses AMP autocast with
    `scaler` for mixed precision. Returns the summed loss over all steps.
    """
    loss_epoch = 0.
    for step, pair in enumerate(train_loader):
        img = pair['pixel']
        text = pair['text']
        with torch.no_grad():
            # frozen text features: (b, 77, 768)
            text_out = text_encoder(text)
            # frozen image latents, scaled by the SD latent factor 0.18215
            vision_out = vision_encoder.encoder(img)
            vision_out = vision_encoder.sample(vision_out)
            vision_out = vision_out * 0.18215
        # corrupt the latents at a random timestep; the UNet must predict
        # exactly the noise that was added
        noise = torch.randn_like(vision_out)
        noise_step = torch.randint(0, 1000, (1,)).long().to(device)
        vision_out_noise = noise_scheduler.add_noise(vision_out, noise, noise_step)
        with autocast():
            noise_pred = unet(vision_out_noise, text_out, noise_step)
            loss = criterion(noise_pred, noise)
        loss_epoch += loss.item()
        scaler.scale(loss).backward()
        # BUG FIX: clip gradients BEFORE the optimizer step, and unscale them
        # first — the original clipped after scaler.step (no effect on the
        # update) and on still-scaled gradients (wrong threshold).
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        print(f'step: {step} loss: {loss.item():.8f}')
    return loss_epoch
# Checkpoint saving.
def save_checkpoint(model, optimizer, epoch, loss, last=False):
    """Write a per-epoch checkpoint; with last=True also refresh the
    'last_checkpoint' copy."""
    snapshot = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
    }
    targets = [f'/hy-tmp/checkpoints/checkpoint_{epoch}.pth.tar']
    if last:
        targets.append('/hy-tmp/checkpoints/last_checkpoint.pth.tar')
    for path in targets:
        torch.save(snapshot, path)
epochs = 100
loss_recorder = []
print('start training ...')
for epoch in range(epochs):
    epoch_loss = train_one_epoch(unet, text_encoder, vision_encoder, loader, optimizer, criterion, scheduler, scaler)
    save_checkpoint(unet, optimizer, epoch, epoch_loss, True)
    loss_recorder.append((epoch, epoch_loss))
    # keep only the 10 lowest-loss checkpoints: sort ascending by loss and,
    # once over the limit, drop the worst entry and delete its file on disk
    loss_recorder = sorted(loss_recorder, key=lambda e:e[-1])
    if len(loss_recorder) > 10:
        del_check = loss_recorder.pop()
        os.remove(f'/hy-tmp/checkpoints/checkpoint_{del_check[0]}.pth.tar')
    print(f'epoch: {epoch:03} loss: {epoch_loss:.8f}')
    # NOTE(review): `epoch % 1 == 0` is always true, so this prints each epoch
    if epoch % 1 == 0:
        print('Top 10 checkpoints:')
        for i in loss_recorder:
            print(i)
print('end training.')
象征性地训练了 100 个 epoch,训练过程有些不稳定。跑通该实验大约需要 24GB 显存。
文本生成图片
文本生成图片的过程使用到了文本编码器、UNet 网络和 VAE 的解码器部分。UNet 接受文本编码器的正负文本输出,并根据迭代步数不断预测噪声分布。最后,UNet 的输出交给 VAE 的解码器生成一张图片。
'''
文本生成图片 test.py
'''
from diffusers import DiffusionPipeline
import torch
from text_encoder import text_encoder_pretrained
from vision_auto_encoder import vae_pretrained
from unet import unet_pretrained
from PIL import Image
import numpy as np
import warnings
warnings.filterwarnings('ignore')
# Run on GPU when available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the local pipeline only for its scheduler and tokenizer.
pipeline = DiffusionPipeline.from_pretrained('./pretrained-params/', safety_checker=None, local_files_only=True)
scheduler = pipeline.scheduler
tokenizer = pipeline.tokenizer
del pipeline
print('Device :', device)
print('Scheduler settings: ', scheduler)
print('Tokenizer settings: ', tokenizer)
# All three sub-models in eval mode for inference.
text_encoder = text_encoder_pretrained()
vision_encoder = vae_pretrained()
unet = unet_pretrained()
text_encoder.eval()
vision_encoder.eval()
unet.eval()
text_encoder.to(device)
vision_encoder.to(device)
unet.to(device)
# 根据文本生成图像
@torch.no_grad()
def generate(text, flag='gen_img'):
    """Generate one image from `text` with classifier-free guidance
    (scale 7.5) over 50 scheduler steps; saves it as ./output/{flag}.jpg.
    """
    # tokenize the prompt and an empty negative prompt, each (1, 77)
    pos = tokenizer(text, padding='max_length', max_length=77,
                    truncation=True, return_tensors='pt').input_ids.to(device)
    neg = tokenizer('', padding='max_length', max_length=77,
                    truncation=True, return_tensors='pt').input_ids.to(device)
    pos_out = text_encoder(pos)  # (1, 77, 768)
    neg_out = text_encoder(neg)
    # negative conditioning first, so index 0 is the unconditional branch
    text_out = torch.cat((neg_out, pos_out), dim=0)  # (2, 77, 768)
    # start from pure Gaussian latent noise
    vae_out = torch.randn(1, 4, 64, 64, device=device)
    scheduler.set_timesteps(50, device=device)
    for time in scheduler.timesteps:
        # duplicate the latent so one UNet pass evaluates both conditionings
        noise = torch.cat((vae_out, vae_out), dim=0)
        noise = scheduler.scale_model_input(noise, time)
        pred_noise = unet(vae_out=noise, text_out=text_out, time=time)
        # classifier-free guidance: uncond + 7.5 * (cond - uncond)
        pred_noise = pred_noise[0] + 7.5 * (pred_noise[1] - pred_noise[0])
        # one reverse-diffusion step
        vae_out = scheduler.step(pred_noise, time, vae_out).prev_sample
    # undo the latent scaling and decode back to pixel space
    vae_out = 1 / 0.18215 * vae_out
    image = vision_encoder.decoder(vae_out)
    # map [-1, 1] -> [0, 1] -> uint8 and save
    image = image.cpu()
    image = ((image + 1) / 2).clamp(0, 1)
    image = image.permute(0, 2, 3, 1).numpy()[0]
    Image.fromarray(np.uint8(image * 255)).save(f'./output/{flag}.jpg')
# Prompts to render; each produces ./output/gen_img{i}.jpg.
texts = [
    'a drawing of a star with a jewel in the center',
    'a drawing of a woman in a red cape',
    'a drawing of a dragon sitting on its hind legs',
    'a drawing of a blue sea turtle holding a rock',
    'a blue and white bird with its wings spread',
    'a blue and white stuffed animal sitting on top of a white surface',
    'a teddy bear sitting on a desk',
]
# NOTE(review): `images` is never appended to and generate() returns None;
# results exist only as the files written to ./output/.
images = []
for i,text in enumerate(texts):
    image = generate(text, f'gen_img{i}')
    print(f'text: {text}, finished')