[Diffusion Models] A Residual Self-Attention UNet for generative models, with PyTorch code for DDPM

References:
[1] https://github.com/xiaohu2015/nngen/blob/main/models/diffusion_models/ddpm_cifar10.ipynb
[2] https://www.bilibili.com/video/BV1we4y1H7gG/?spm_id_from=333.337.search-card.all.click&vd_source=9e9b4b6471a6e98c3e756ce7f41eb134

1 The UNet

1.1 SelfAttention

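1) The self-attention module can call PyTorch's nn.MultiheadAttention(channels, num_heads, batch_first=True) directly, so there is no need to reinvent the wheel;
2) The order of operations is:

  • Reshape the input: (B,C,H,W) -> (B,C,H*W) -> (B,H*W,C)
  • Apply the LayerNorm module to get x_ln
  • Pass x_ln as the query, key and value to the multi-head attention module to get attention_value
  • Add a residual connection between attention_value and the original input x
  • (Optional) pass the result through a feed-forward network, again with a residual connection
  • Reshape attention_value back to (B,C,H,W)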
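import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        # 4 heads; batch_first=True means inputs are (B, seq_len, channels)
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        x = x.reshape(-1, self.channels, H * W).swapaxes(1, 2)
        x_ln = self.ln(x)
        # query, key and value are all x_ln; the module returns (output, attention weights)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff(attention_value) + attention_value
        # (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W)
        return attention_value.swapaxes(1, 2).reshape(-1, self.channels, H, W)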

Test:

    # test the SelfAttention module
    mha = SelfAttention(32)
    x = torch.rand(3, 32, 64, 64)
    out = mha(x)
    print(out.shape)

    # torch.Size([3, 32, 64, 64])
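
The shape handling in the steps above can also be checked in isolation. The following is a minimal sketch rather than part of the model: the tensor sizes are arbitrary choices for the illustration. It shows that nn.MultiheadAttention takes the query, key and value as separate arguments and returns a tuple, and that the flattening of the spatial dimensions is reversible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 32, 8, 8          # arbitrary sizes chosen for the illustration
x = torch.rand(B, C, H, W)

# (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C): one token per pixel
tokens = x.reshape(B, C, H * W).swapaxes(1, 2)

mha = nn.MultiheadAttention(C, num_heads=4, batch_first=True)
# Self-attention: the same tensor is passed as query, key and value.
# The call returns a tuple: (attended tokens, attention weights averaged over heads).
attended, weights = mha(tokens, tokens, tokens)

# (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W) restores the original layout
restored = tokens.swapaxes(1, 2).reshape(B, C, H, W)

print(attended.shape)             # torch.Size([2, 64, 32])
print(weights.shape)              # torch.Size([2, 64, 64])
print(torch.equal(restored, x))   # True
```

Because the call returns a tuple, the output has to be unpacked before the residual connection is added.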

1.2 DoubleConv

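This corresponds to the double conv block in UNet, except that some of the layers are swapped out (GroupNorm and GELU are used here) and an optional residual connection is added.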

class DoubleConv(nn.Module):
    def __init__(self, in_c, out_c, mid_c=None, residual=False):
        super().__init__()
        self.residual = residual
        if mid_c is None:
            mid_c = out_c
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_c, mid_c, kernel_size=3, padding=1),
            nn.GroupNorm(1, mid_c),
            nn.GELU(),
            nn.Conv2d(mid_c, out_c, kernel_size=3, padding=1),
            nn.GroupNorm(1, out_c),  # normalizes the output of the second conv, so out_c channels
        )
        # 1x1 conv on the shortcut when the channel count changes
        if in_c != out_c:
            self.shortcut = nn.Conv2d(in_c, out_c, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        if self.residual:
            return F.gelu(self.shortcut(x) + self.double_conv(x))
        else:
            return F.gelu(self.double_conv(x))

1.3 Down

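The Down module is a max-pooling layer followed by two DoubleConv blocks, whose channels go in_c -> out_c -> out_c.
It also has a timestep-embedding layer, i.e. an activation function plus a linear layer. Its job is to project the timestep embedding of shape (B, emb_dim) to out_c channels so that it can be added to the feature map of shape (B, out_c, h, w).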

class Down(nn.Module):
    def __init__(self, in_c, out_c, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),  # kernel_size=2; stride defaults to kernel_size
            DoubleConv(in_c, in_c, residual=True) if False else DoubleConv(in_c, out_c, residual=True),
            DoubleConv(out_c, out_c),
        )
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_c),
        )

    def forward(self, x, t):
        x = self.maxpool_conv(x)
        # add two trailing dims, then repeat h and w times so that emb has the same shape as x
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb
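
The way the timestep embedding is added to the feature map can be tried on plain tensors. This is a small sketch with made-up shapes, not code from the model: it compares the explicit repeat used in the Down module with ordinary broadcasting, which gives the same result without materializing the tiled tensor.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 128, 16, 16)      # feature map (B, out_c, h, w), toy sizes
emb = torch.randn(4, 128)            # projected timestep embedding (B, out_c)

# Variant used in the Down module: expand to (B, out_c, 1, 1), then tile over h and w
emb_repeated = emb[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
out_repeat = x + emb_repeated

# Broadcasting variant: the two trailing size-1 dims are expanded implicitly
out_broadcast = x + emb[:, :, None, None]

print(emb_repeated.shape)                       # torch.Size([4, 128, 16, 16])
print(torch.equal(out_repeat, out_broadcast))   # True
```

Either form works; the repeat version just makes the matching shapes explicit.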

1.4 Up

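The Up module first upsamples with bilinear interpolation, then concatenates the result with the skip connection along the channel dimension, and then applies two DoubleConv blocks.
It also needs a timestep embedding.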

class Up(nn.Module):
    def __init__(self, in_c, out_c, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # in_c is the channel count after concatenating with the skip connection
        self.conv = nn.Sequential(
            DoubleConv(in_c, in_c, residual=True),
            DoubleConv(in_c, out_c),
        )
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_c),
        )

    def forward(self, x, skip_x, t):
        x = self.up(x)
        x = torch.cat([x, skip_x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb

1.5 UNet模型

The pieces are assembled like a standard UNet, with a self-attention layer after every downsampling and upsampling step.

class UNet(nn.Module):
    def __init__(self, in_c=3, out_c=3, time_dim=256, device='cuda'):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(in_c, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256)
        self.down3 = Down(256, 512)
        self.sa3 = SelfAttention(512)

        self.bot1 = DoubleConv(512, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        # the first argument of Up is the channel count after concatenation with the skip connection
        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64)
        self.outc = nn.Conv2d(64, out_c, kernel_size=1)

    def pos_encoding(self, t, channels):
        # frequencies 1 / 10000^(i / channels) for i = 0, 2, 4, ...
        freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=self.device).float() / channels))
        args = t[:, None].float() * freq[None]
        embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if channels % 2 != 0:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, x, t):
        t = self.pos_encoding(t, self.time_dim)
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output

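Notes on the positional encoding:

  • freq = 1.0/(10000**(torch.arange(0,channels,2,device=self.device).float()/channels))
    The frequencies are $1/10000^{i/c}$ for $i = 0, 2, 4, \dots$ with $c$ = channels, so they decay from $1$ down to $1/10000^{(c-2)/c}$ (for $c=256$ that is $1/10000^{254/256}$). The exponent i/channels must be parenthesized; without the parentheses Python evaluates (10000**i)/channels instead.
  • args = t[:,None].float()*freq[None]
    t is a 1-D vector, so it is first expanded to [B,1]; freq has shape [channels//2], so args ends up with shape [B,channels//2], e.g. [3,128] for a batch of 3 and channels=256.
  • embedding = torch.cat([torch.sin(args), torch.cos(args)],dim=-1)
    Concatenate along the last dimension.
  • embedding = torch.cat([embedding,torch.zeros_like(embedding[:,:1])],dim=-1)
    This handles an odd number of channels: if channels is odd, a zero is appended along the last dimension.
  • Usage:
    After each residual block, emb_layer projects the timestep embedding to the required number of channels, and the result is simply added to the feature map.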
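
The encoding described above can be written as a standalone function to inspect its output. This is a minimal sketch under two assumptions: the function name sinusoidal_embedding is made up for the illustration, and channels is taken to be even, so the zero-padding branch is left out.

```python
import torch

def sinusoidal_embedding(t, channels):
    """Map integer timesteps t of shape (B,) to embeddings of shape (B, channels).

    Hypothetical helper for illustration; assumes channels is even.
    """
    i = torch.arange(0, channels, 2, device=t.device).float()
    freq = 1.0 / (10000 ** (i / channels))          # (channels // 2,)
    args = t[:, None].float() * freq[None]          # (B, channels // 2)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

t = torch.tensor([0, 10, 999])
emb = sinusoidal_embedding(t, 256)
print(emb.shape)          # torch.Size([3, 256])
print(emb[0, :3])         # sin half at t = 0: all zeros
print(emb[0, 128:131])    # cos half at t = 0: all ones
```

The first half of each row holds the sine terms and the second half the cosine terms, so every entry lies in [-1, 1] regardless of t.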

2 The diffusion part, with a recap

2.1 beta_schedule

linear_beta_schedule

    def linear_beta_schedule(self):
        scale = 1000/self.noise_steps
        beta_start = self.beta_start*scale
        beta_end = self.beta_end*scale
        return torch.linspace(beta_start, beta_end, self.noise_steps)

cosine_beta_schedule
The formulas are: $$f(t)=\cos\left(\frac{t/T+s}{1+s}\cdot\frac{\pi}{2}\right)^2,\qquad \bar\alpha_t=\frac{f(t)}{f(0)},\qquad \beta_t=1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}$$
Here $\bar\alpha_t$ is the cumulative product of the $\alpha_t$ (alphas_cumprod in the code).

    def cosine_beta_schedule(self, s=0.008):
        """
        As proposed in the Improved DDPM paper.
        """
        steps = self.noise_steps + 1
        x = torch.linspace(0, self.noise_steps, steps, dtype=torch.float64)  # from 0 to self.noise_steps
        alphas_cumprod = torch.cos(((x / self.noise_steps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        # alphas_cumprod holds noise_steps + 1 values: [1:] is \bar\alpha_t and [:-1] is \bar\alpha_{t-1}
        # (the 0-th value is 1 after normalization)
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0, 0.999)  # keep beta no larger than 0.999
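
To see how the two schedules differ, they can be evaluated outside the class. The sketch below restates both as free functions (the names linear_betas and cosine_betas are made up here, with the default values of the Diffusion class plugged in) and prints how much of the signal, $\bar\alpha_t$, is left halfway through and at the end of the process.

```python
import math
import torch

def linear_betas(T, beta_start=1e-4, beta_end=0.02):
    # same scaling as linear_beta_schedule: endpoints are rescaled by 1000 / T
    scale = 1000 / T
    return torch.linspace(beta_start * scale, beta_end * scale, T, dtype=torch.float64)

def cosine_betas(T, s=0.008):
    x = torch.linspace(0, T, T + 1, dtype=torch.float64)
    f = torch.cos(((x / T) + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = f / f[0]
    return torch.clip(1 - alpha_bar[1:] / alpha_bar[:-1], 0, 0.999)

T = 1000
for name, betas in [("linear", linear_betas(T)), ("cosine", cosine_betas(T))]:
    alpha_bar = torch.cumprod(1 - betas, dim=0)
    # alpha_bar is the fraction of signal variance that survives up to step t
    print(name, "mid:", float(alpha_bar[T // 2]), "end:", float(alpha_bar[-1]))
```

With these settings the cosine schedule keeps noticeably more signal at the midpoint than the linear one, which is the behaviour the noising experiment in section 4.1 visualizes.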

2.2 Initialization

A recap of all the DDPM formulas:

  1. Forward process
    $$x_t = \sqrt{\alpha_t}\,x_{t-1}+\sqrt{1-\alpha_t}\,\epsilon$$
    $$x_t = \sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,\epsilon$$

  2. Mean and variance of the posterior $q(x_{t-1}|x_t,x_0)$
    $$\mu_q =\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})x_t+\sqrt{\bar\alpha_{t-1}}(1-\alpha_t)\hat x_\theta}{1-\bar\alpha_t}= \frac{1}{\sqrt{\alpha_t}}x_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}\sqrt{\alpha_t}}\hat\epsilon(x_t,t)$$
    $$\Sigma=\frac{(1-\alpha_t)(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}I=\frac{\beta_t(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}I$$

  3. The estimate $\hat x_0$ and the sample $x_{t-1}$ obtained at each sampling step
    $$\hat x_0 = \frac{1}{\sqrt{\bar\alpha_t}}\left(x_t-\sqrt{1-\bar\alpha_t}\,\epsilon\right)$$
    $$x_{t-1} = \tilde\mu+\sigma_t z$$
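
The two forms of the posterior mean in item 2 can be checked numerically. The following is a small sanity check, not part of the implementation: it uses a linear schedule, an arbitrary timestep and a toy 5-pixel "image", all of which are assumptions made for the illustration. With $\hat x_\theta$ replaced by the true $x_0$ and $\hat\epsilon$ by the true noise, both expressions should give the same vector.

```python
import torch

torch.manual_seed(0)
T = 1000
beta = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)

t = 500                                    # arbitrary timestep with t >= 1
x0 = torch.randn(5, dtype=torch.float64)   # toy "image" with 5 pixels
eps = torch.randn(5, dtype=torch.float64)
x_t = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps

# First form: weighted combination of x_t and x_0
mean_from_x0 = (alpha[t].sqrt() * (1 - alpha_bar[t - 1]) * x_t
                + alpha_bar[t - 1].sqrt() * (1 - alpha[t]) * x0) / (1 - alpha_bar[t])

# Second form: x_t minus the scaled noise
mean_from_eps = (x_t - (1 - alpha[t]) / (1 - alpha_bar[t]).sqrt() * eps) / alpha[t].sqrt()

print(torch.allclose(mean_from_x0, mean_from_eps))   # True
```

The agreement follows from substituting $x_0 = (x_t-\sqrt{1-\bar\alpha_t}\,\epsilon)/\sqrt{\bar\alpha_t}$ into the first form and using $\bar\alpha_t=\alpha_t\bar\alpha_{t-1}$.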

import math
import torch
import torch.nn.functional as F


class Diffusion:
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, beta_schedule='linear', device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device
        if beta_schedule == 'linear':
            self.beta = self.linear_beta_schedule().to(device)
        elif beta_schedule == 'cosine':
            self.beta = self.cosine_beta_schedule().to(device)
        else:
            raise ValueError(f'Unknown beta schedule {beta_schedule}')

        # all parameters
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)  # \bar\alpha_t
        self.alpha_hat_prev = F.pad(self.alpha_hat[:-1], (1, 0), value=1.)  # \bar\alpha_{t-1}, padded with 1 at t=0
        self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat)
        self.sqrt_one_minus_alpha_hat = torch.sqrt(1. - self.alpha_hat)
        # used to estimate x_0, which is then used to compute the mean of p(x_{t-1}|x_t)
        self.sqrt_recip_alpha_hat = torch.sqrt(1. / self.alpha_hat)
        self.sqrt_recip_minus_alpha_hat = torch.sqrt(1. / self.alpha_hat - 1)
        # variance of p(x_{t-1}|x_t)
        self.posterior_variance = self.beta * (1. - self.alpha_hat_prev) / (1. - self.alpha_hat)
        # coefficients of x_0 and x_t in the mean of p(x_{t-1}|x_t)
        self.posterior_mean_coef1 = self.beta * torch.sqrt(self.alpha_hat_prev) / (1.0 - self.alpha_hat)
        self.posterior_mean_coef2 = (1.0 - self.alpha_hat_prev) * torch.sqrt(self.alpha) / (1.0 - self.alpha_hat)

2.3 Extracting the values for a given timestep from an array

    def _extract(self, arr, t, x_shape):
        # pick the entries of arr at timesteps t and reshape them so they broadcast against x_shape
        bs = x_shape[0]
        out = arr.to(t.device).gather(0, t).float()
        out = out.reshape(bs, *((1,) * (len(x_shape) - 1)))  # reshape to (bs,1,1,1)
        return out
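
What _extract does is easiest to see on a tiny example. The sketch below repeats the same gather-and-reshape steps on toy values; the coefficient array, the timesteps and the image batch are all made up for the illustration.

```python
import torch

arr = torch.tensor([0.9, 0.8, 0.7, 0.6])    # one coefficient per timestep (toy values)
t = torch.tensor([3, 0, 2])                  # timestep of each sample in a batch of 3
x = torch.ones(3, 1, 2, 2)                   # batch of toy images (B, C, H, W)

coef = arr.gather(0, t)                                     # picks arr[3], arr[0], arr[2]
coef = coef.reshape(x.shape[0], *((1,) * (x.dim() - 1)))    # (3,) -> (3, 1, 1, 1)

scaled = coef * x    # each image is multiplied by its own coefficient
print(coef.shape)            # torch.Size([3, 1, 1, 1])
print(scaled[:, 0, 0, 0])    # tensor([0.6000, 0.9000, 0.7000])
```

The trailing singleton dimensions are what let a per-sample scalar multiply a whole image through broadcasting.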

2.4 Sampling $x_t$ from $x_0$

This follows the formula $x_t = \sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,\epsilon$.
Training starts by drawing a random t and using it together with $x_0$ to produce the noised $x_t$ and the noise that was added, so the function returns two values.

    def q_sample(self, x, t, noise=None):
        # q(x_t|x_0)
        if noise is None:
            noise = torch.randn_like(x)
        sqrt_alpha_hat = self._extract(self.sqrt_alpha_hat, t, x.shape)
        sqrt_one_minus_alpha_hat = self._extract(self.sqrt_one_minus_alpha_hat, t, x.shape)
        # return both the noised sample x_t and the noise that was used
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

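2.5 Mean and variance of the true posterior $q(x_{t-1}|x_t,x_0)$

See the formulas above.
In practice the distribution $p(x_{t-1}|x_t)$ (the one the model fits) is given the same form: the model's reconstructed $x_0$ is passed into this function together with $x_t$ to obtain the mean and variance of $p$. Where does the reconstructed $x_0$ come from? The model predicts the noise, and $x_0$ is then computed from $x_t$ and the predicted noise.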

    def q_posterior_mean_variance(self, x, x_t, t):
        # calculate mean and variance of q(x_{t-1}|x_t,x_0), we send parameters x0 and x_t into this function
        # in fact we use this function to predict p(x_{t-1}|x_t)'s mean and variance by sending x_t, \hat x_0, t
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x.shape) * x
            + self._extract(self.posterior_mean_coef2, t, x.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x.shape)
        return posterior_mean, posterior_variance

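2.6 Estimating the reconstructed $x_0$

Using the formula above, $x_0$ is estimated from $x_t$ and the predicted noise pred_noise. This amounts to subtracting the scaled predicted noise from the scaled $x_t$: $\hat x_0 = \sqrt{1/\bar\alpha_t}\,x_t - \sqrt{1/\bar\alpha_t-1}\,\epsilon_{pred}$.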

    def estimate_x0_from_noise(self, x_t, t, noise):
        # \hat x_0 = sqrt(1/alpha_hat) * x_t - sqrt(1/alpha_hat - 1) * noise
        return (self._extract(self.sqrt_recip_alpha_hat, t, x_t.shape) * x_t - self._extract(self.sqrt_recip_minus_alpha_hat, t, x_t.shape) * noise)
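
A quick way to confirm the sign and the coefficients of this estimate is a round trip: noise a toy $x_0$ with the forward formula, then invert it with the true noise. The sketch below assumes a linear schedule and made-up tensor sizes, and uses the same two coefficients as the class, $\sqrt{1/\bar\alpha_t}$ and $\sqrt{1/\bar\alpha_t-1}$.

```python
import torch

torch.manual_seed(0)
beta = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alpha_bar = torch.cumprod(1.0 - beta, dim=0)

t = 300                                                      # arbitrary timestep
x0 = torch.rand(2, 3, 4, 4, dtype=torch.float64) * 2 - 1     # toy images in [-1, 1]
eps = torch.randn_like(x0)

# forward process q(x_t | x_0)
x_t = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps

# inversion: sqrt(1/alpha_bar) * x_t - sqrt(1/alpha_bar - 1) * eps
x0_hat = (1.0 / alpha_bar[t]).sqrt() * x_t - (1.0 / alpha_bar[t] - 1).sqrt() * eps

print(torch.allclose(x0_hat, x0))   # True
```

During sampling the true noise is unknown and the model's prediction takes its place, so the recovered $\hat x_0$ is only as good as that prediction.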

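2.7 Computing the mean and variance of $p_\theta$

First predict the noise from $x_t$ and $t$, then estimate the reconstructed x0 and clip it to [-1,1], and finally compute the mean and variance from it.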

    def p_mean_variance(self,model,x_t,t,clip_denoised=True):
        pred_noise = model(x_t,t)
        x_recon = self.estimate_x0_from_noise(x_t,t,pred_noise)
        if clip_denoised:
            x_recon = torch.clamp(x_recon,min=-1.,max=1.)
        p_mean,p_var = self.q_posterior_mean_variance(x_recon,x_t,t)
        return p_mean,p_var

2.8 Sampling

Sampling is $x_{t-1}= \mu+\sigma_t z$, where $\sigma_t$ is fixed and z is drawn at random; at t=0, i.e. the last step, no noise is added. The loop function samples from noise_steps - 1 down to 0.

    # requires: from tqdm import tqdm
    def p_sample(self, model, x_t, t, clip_denoised=True):
        model.eval()
        with torch.no_grad():
            p_mean, p_var = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised)
            noise = torch.randn_like(x_t)
            # 1 where t != 0 and 0 where t == 0, shaped (bs,1,1,1) for broadcasting
            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))
            pred_img = p_mean + nonzero_mask * torch.sqrt(p_var) * noise
        return pred_img

    def p_sample_loop(self, model, shape):
        model.eval()
        with torch.no_grad():
            bs = shape[0]
            device = next(model.parameters()).device
            img = torch.randn(shape, device=device)
            imgs = []
            for i in tqdm(reversed(range(0, self.noise_steps)), desc='sampling loop time step', total=self.noise_steps):
                img = self.p_sample(model, img, torch.full((bs,), i, device=device, dtype=torch.long))  # from T-1 down to 0
                imgs.append(img)
        return imgs  # one tensor per step; the last entry is the final sample

    @torch.no_grad()
    def sample(self, model, img_size, bs=8, channels=3):
        return self.p_sample_loop(model, (bs, channels, img_size, img_size))
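
The role of the mask in the last sampling step can be shown on toy tensors. This is only an illustrative sketch with made-up means and variances: the batch mixes timesteps, which does not happen in the real loop, so that the effect of t=0 is visible next to non-zero timesteps.

```python
import torch

torch.manual_seed(0)
x_t = torch.zeros(3, 1, 2, 2)             # toy batch (B, C, H, W)
t = torch.tensor([5, 0, 1])               # the second sample is at its last step
p_mean = torch.zeros_like(x_t)            # made-up posterior mean
p_var = torch.full((3, 1, 1, 1), 0.25)    # made-up posterior variance

# 1.0 where t != 0, 0.0 where t == 0, shaped (B, 1, 1, 1) for broadcasting
mask = (t != 0).float().view(-1, *([1] * (x_t.dim() - 1)))
noise = torch.randn_like(x_t)
x_prev = p_mean + mask * torch.sqrt(p_var) * noise

print(mask.flatten())             # tensor([1., 0., 1.])
print(x_prev[1].abs().max())      # tensor(0.): no noise added at t = 0
```

For the sample at t=0 the result is exactly the mean, while the other samples receive noise with standard deviation $\sigma_t$.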

3 Training

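import os
import logging
import torch
import torch.nn as nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# setup_logging, get_data and save_images are helper functions that are not shown in this post


def train(args):
    setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)
    model = UNet(3, 3, device=device).to(device)  # 3 input and 3 output channels (RGB)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    l = len(dataloader)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        model.train()  # sampling at the end of each epoch switches the model to eval mode
        pbar = tqdm(dataloader)
        for i, (images, _) in enumerate(pbar):
            images = images.to(device)
            # draw one uniformly random timestep in [0, noise_steps) per image
            t = torch.randint(0, diffusion.noise_steps, (images.shape[0],), device=device)
            x_t, noise = diffusion.q_sample(images, t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)

        # sample() returns the image at every step; keep only the final one
        sampled_images = diffusion.sample(model, args.image_size, bs=images.shape[0])[-1]
        save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
        torch.save(model.state_dict(), os.path.join("models", args.run_name, "ckpt.pt"))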


def launch():
    import argparse
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.run_name = "DDPM_Uncondtional"
    args.epochs = 500
    args.batch_size = 12
    args.image_size = 64
    args.dataset_path = r"C:\Users\dome\datasets\landscape_img_folder"
    args.device = "cuda"
    args.lr = 3e-4
    train(args)


if __name__ == '__main__':
    launch()

4 Other experiments

4.1 The noising process

import ddpm 
from PIL import Image
from torchvision import transforms
import torch
import matplotlib.pyplot as plt
import numpy as np

image = Image.open("giraffe.jpg")

image_size = 128
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


x_start = transform(image).unsqueeze(0)

# x_start lives on the CPU, so the schedules are built on the CPU as well
diffusion_linear = ddpm.Diffusion(noise_steps=500, device="cpu")
diffusion_cosine = ddpm.Diffusion(noise_steps=500, beta_schedule='cosine', device="cpu")

plt.figure(figsize=(16, 8))
for idx, t in enumerate([0, 50, 100, 200, 499]):
    x_noisy, _ = diffusion_linear.q_sample(x_start, t=torch.tensor([t]))  # use q_sample to generate x_t
    x_noisy2, _ = diffusion_cosine.q_sample(x_start, t=torch.tensor([t]))  # [1,3,128,128]
    # x_0 lies in [-1,1] but the noised x_t can fall outside that range,
    # so clamp to [-1,1] before adding 1 and multiplying by 127.5
    noisy_image = (x_noisy.clamp(-1, 1).squeeze().permute(1, 2, 0) + 1) * 127.5
    noisy_img2 = (x_noisy2.clamp(-1, 1).squeeze().permute(1, 2, 0) + 1) * 127.5  # [128,128,3], values in [0,255]
    noisy_image = noisy_image.numpy().astype(np.uint8)
    noisy_img2 = noisy_img2.numpy().astype(np.uint8)
    plt.subplot(2, 5, 1 + idx)
    plt.imshow(noisy_image)
    plt.axis("off")
    plt.title(f"t={t}")
    plt.subplot(2, 5, 6 + idx)
    plt.imshow(noisy_img2)
    plt.axis('off')
plt.figtext(0.5, 0.95, 'Linear Beta Schedule', ha='center', fontsize=16)  # title above the first row
plt.figtext(0.5, 0.48, 'Cosine Beta Schedule', ha='center', fontsize=16)  # title above the second row
plt.savefig('temp_img/add_noise_process.png')

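[Figure: the noising process at t = 0, 50, 100, 200, 499; top row: linear beta schedule, bottom row: cosine beta schedule]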

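4.2 Multi-GPU distributed code

4.3 The denoising process

Things I have forgotten a hundred thousand times (running notes)

  1. The difference between view and reshape
    https://zhuanlan.zhihu.com/p/593664378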
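
As a short reminder of that difference, here is a minimal sketch on a toy tensor: view never copies and therefore fails when the memory layout does not allow the requested shape, whereas reshape returns a view when it can and copies otherwise.

```python
import torch

x = torch.arange(6).reshape(2, 3)
xt = x.t()                       # transposed view: shape (3, 2), not contiguous
print(xt.is_contiguous())        # False

try:
    xt.view(6)                   # view needs a compatible memory layout
except RuntimeError as err:
    print("view failed:", type(err).__name__)

flat = xt.reshape(6)             # reshape falls back to copying when a view is impossible
print(flat)                      # tensor([0, 3, 1, 4, 2, 5])
print(xt.contiguous().view(6))   # same values after an explicit copy
```

This is why calling view after swapaxes or permute can fail for some target shapes, and why reshape is the safer default when a copy is acceptable.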