网络模块
首先,我做的是超分辨率任务,所以不需要处理条件的代码。
在我学习IDDPM和DDPM时,对Unet有两个疑惑:
1:时间t怎么用于模型的预测?
2:attention机制是怎么融入进去的?
先回答1,其实很简单,与nlp一样,使用Embedding将t融入模型的预测。
# Sinusoidal time-step embedding: a fixed [T, d_model] sin/cos table,
# refined by a small MLP (Linear -> Swish -> Linear) into a [*, dim] vector.
class TimeEmbedding(nn.Module):
def __init__(self, T, d_model, dim):
# T: number of diffusion steps; d_model: width of the sinusoidal table
# (must be even, sin/cos come in pairs); dim: output embedding width.
assert d_model % 2 == 0
super().__init__()
emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
emb = torch.exp(-emb) # [d_model/2] frequency terms 10000^(-2i/d_model)
pos = torch.arange(T).float() # [T] step indices 0 .. T-1
# emb = pos[:, None] * emb[None, :]
# emb = pos[T,1]*emb[1,d_model/2] = [T,d_model/2]
emb = pos.unsqueeze(1) * emb.unsqueeze(0)
assert list(emb.shape) == [T, d_model // 2]
emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1) # stack adds a new axis (cat would only concatenate along an existing one)
assert list(emb.shape) == [T, d_model // 2, 2]
emb = emb.view(T, d_model) # interleave sin/cos -> [T, d_model]
self.timembedding = nn.Sequential(
nn.Embedding.from_pretrained(emb), # lookup table initialised from the sinusoids; roughly an nn.Embedding(T, d_model) with precomputed weights (frozen by default)
nn.Linear(d_model, dim),
Swish(),
nn.Linear(dim, dim),
)
def forward(self, t):
# t holds 1-based step indices in [1, T]; input [B] -> output [B, dim]
emb = self.timembedding(t-1)
return emb
这段代码看起来很长,很唬人,其实可以简化为:
# Simplified variant: a plain learnable embedding table instead of the
# fixed sinusoidal initialisation above.
class TimeEmbedding(nn.Module):
def __init__(self, T, d_model, dim):
assert d_model % 2 == 0
super().__init__()
self.timembedding = nn.Sequential(
nn.Embedding(T, d_model), # NOTE(review): trainable, whereas from_pretrained() above is frozen by default — not strictly equivalent
nn.Linear(d_model, dim),
Swish(),
nn.Linear(dim, dim),
)
def forward(self, t):
# t holds 1-based step indices in [1, T]; input [B] -> output [B, dim]
emb = self.timembedding(t-1)
return emb
源码中前面一大块代码就是为了提供一个预训练。在TimeEmbedding中,使用nn.Embedding先将步骤T([B])嵌入为[B,d_model],再通过一个全连接层转换为[B,dim]
对于问题二
简单地说,源码中先实现了一个Resblock,每个Resblock中包含两个卷积,第一个负责改变通道数,第二个负责学习特征图内容。我把它做成图就一目了然了。GroupNorm可以用BN直接替代,它与BN互有胜负就不多说了。Resblock的输入为特征图和嵌入后的时间编码,里面的block1仅作用于特征图,然后与扩展后的时间编码融合后,经过block2。每个block中都有一个卷积。在Resblock的最后可以选择该block中有没有attention机制,如果有,就在该block中增加一个Multi-head self attention(MSA),并且MSA不会改变任何维度的大小。
# Residual block: two conv stages with the projected time embedding added in
# between, an optional self-attention tail, and a 1x1-conv shortcut whenever
# the channel count changes.
class ResBlock(nn.Module):
def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
super().__init__()
self.block1 = nn.Sequential(
nn.BatchNorm2d(in_ch), # normalises over the channel dimension (the reference DDPM code uses GroupNorm here)
Swish(),
nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
)
self.temb_proj = nn.Sequential(
# projects the time embedding from tdim to out_ch so it can be
# broadcast-added to the [B, out_ch, H, W] feature map
Swish(),
nn.Linear(tdim, out_ch),
)
self.block2 = nn.Sequential(
nn.BatchNorm2d(out_ch),
Swish(),
nn.Dropout(dropout),
nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
)
# 1x1 conv matches channels on the residual path when in_ch != out_ch
if in_ch != out_ch:
self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
else:
self.shortcut = nn.Identity()
if attn:
self.attn = Attention(out_ch)
else:
self.attn = nn.Identity()
def forward(self, x, temb):
h = self.block1(x)
h += self.temb_proj(temb)[:, :, None, None] # broadcast add over H and W
h = self.block2(h)
h = h + self.shortcut(x)
h = self.attn(h)
return h
代码非常的简单就不过多的解释
Unet网络
这是我根据源码设计简化的一个Unet网络
先说说与源码的区别,我的代码封装非常的差,也没有解耦,要不同的模型需要更改代码,不过更改起来也很简单。源码可以根据参数不同得到不同的网络结构。根据源码的超参数,每一层都使用了两个Resblock,并且使用了更多的attention机制。
再说我修改的代码,因为设备太差只能缩小模型结构。我只在得出IDDPM需要的预测的线性权重处使用了attention机制,也就是middleblock处,得到预测的线性权重相当于一个二分类网络,如果是DDPM可以直接将其去掉。我相信大伙光看这个结构图就已经懂了,接下来上代码,你可以根据你的设备和需要在Unet中修改你的网络结构。
# Simplified diffusion U-Net; also predicts the IDDPM variance weight from
# the bottleneck features.
class UNet(nn.Module):
def __init__(self, T=1000, in_ch=3, out_ch=3, dropout=0.1, tdim=512):
# tdim: width of the time embedding fed to every ResBlock
super(UNet, self).__init__()
# time embedding: [B] step indices -> [B, tdim]
self.time_embedding = TimeEmbedding(T, 128, tdim)
# encoder
# NOTE(review): in_ch is ignored here — the first block hard-codes 3 input channels; should be ResBlock(in_ch, 64, ...)
self.res1 = ResBlock(3,64,tdim,dropout)
self.pool1 = nn.MaxPool2d(2)
self.res2 = ResBlock(64, 128, tdim, dropout)
self.pool2 = nn.MaxPool2d(2)
self.res3 = ResBlock(128,256, tdim, dropout)
self.pool3 = nn.MaxPool2d(2)
self.res4 = ResBlock(256,512, tdim, dropout)
self.pool4 = nn.MaxPool2d(2)
# bottleneck, with self-attention
self.midblock = ResBlock(512, 1024,tdim,dropout,attn=True)
# head predicting the linear variance weight from the bottleneck result
self.var_weight = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(1024,512),
nn.Linear(512,1),
nn.Sigmoid()
)
# decoder: transposed convs upsample; skips double channels before each ResBlock
self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
self.res5 = ResBlock(1024,512,tdim,dropout)
self.up2 = nn.ConvTranspose2d(512,256, 2, stride=2)
self.res6 = ResBlock(512,256, tdim, dropout)
self.up3 = nn.ConvTranspose2d(256,128, 2, stride=2)
self.res7 = ResBlock(256,128, tdim, dropout)
self.up4 = nn.ConvTranspose2d(128,64, 2, stride=2)
self.res8 = ResBlock(128,64, tdim, dropout)
# output layer: project to the noise prediction
self.out = nn.Conv2d(64, out_ch, 1)
def forward(self, x, t):
temb = self.time_embedding(t)
# encoder
x1 = self.res1(x,temb)
x2 = self.res2(self.pool1(x1),temb)
x3 = self.res3(self.pool2(x2),temb)
x4 = self.res4(self.pool3(x3),temb)
# bottleneck, with self-attention
out = self.midblock(self.pool4(x4),temb)
var_weight = self.var_weight(out) # predicted linear variance weight
# decoder with skip connections
out = self.up1(out)
out = torch.cat([out, x4], dim=1)
out = self.res5(out,temb)
out = self.up2(out)
out = torch.cat([out, x3], dim=1)
out = self.res6(out,temb)
out = self.up3(out)
out = torch.cat([out, x2], dim=1)
out = self.res7(out,temb)
out = self.up4(out)
out = torch.cat([out, x1], dim=1)
out = self.res8(out,temb)
out = self.out(out)
return out,var_weight.squeeze(1) # returns the predicted noise and the predicted linear variance weight
Unet.py完整代码
import torch.nn as nn
import torch
import torch.nn.functional as F
import math
class Swish(nn.Module):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        return torch.sigmoid(x) * x
class TimeEmbedding(nn.Module):
    """Sinusoidal time-step embedding refined by a two-layer MLP.

    A fixed [T, d_model] sin/cos table (the transformer positional encoding,
    indexed by diffusion step) initialises a frozen ``nn.Embedding``; two
    linear layers with a SiLU in between then map it to width ``dim``.

    Args:
        T: number of diffusion steps (rows of the table).
        d_model: width of the sinusoidal table; must be even.
        dim: output embedding width fed to the ResBlocks.
    """

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0  # sin/cos come in pairs
        super().__init__()
        # Frequency terms 10000^(-2i/d_model) for i = 0 .. d_model/2 - 1.
        freqs = torch.exp(-torch.arange(0, d_model, step=2) / d_model * math.log(10000))
        steps = torch.arange(T).float()  # step indices 0 .. T-1
        table = steps.unsqueeze(1) * freqs.unsqueeze(0)  # [T, d_model/2]
        # stack (not cat) adds a trailing axis so sin/cos interleave after view.
        table = torch.stack([torch.sin(table), torch.cos(table)], dim=-1).view(T, d_model)
        self.timembedding = nn.Sequential(
            # from_pretrained freezes the table by default, keeping it sinusoidal.
            nn.Embedding.from_pretrained(table),
            nn.Linear(d_model, dim),
            nn.SiLU(),  # stdlib equivalent of the hand-rolled Swish: x * sigmoid(x)
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        # t holds 1-based step indices in [1, T]; shape [B] -> output [B, dim].
        return self.timembedding(t - 1)
class ResBlock(nn.Module):
    """Residual block with time-embedding injection and optional attention.

    ``block1`` (norm -> SiLU -> 3x3 conv) maps in_ch -> out_ch, the projected
    time embedding is broadcast-added over the spatial dimensions, ``block2``
    (norm -> SiLU -> dropout -> 3x3 conv) refines the features, and a shortcut
    (1x1 conv when the channel count changes) closes the residual connection.
    When ``attn`` is True a shape-preserving multi-head self-attention layer
    is applied last.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        tdim: width of the incoming time embedding.
        dropout: dropout rate inside block2.
        attn: append a self-attention layer if True.
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.BatchNorm2d(in_ch),  # the reference DDPM code uses GroupNorm here
            nn.SiLU(),  # stdlib equivalent of the hand-rolled Swish
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        # Projects the time embedding from tdim to out_ch so it can be added
        # to the [B, out_ch, H, W] feature map.
        self.temb_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.BatchNorm2d(out_ch),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        # 1x1 conv matches channel counts on the residual path when needed.
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        self.attn = Attention(out_ch) if attn else nn.Identity()

    def forward(self, x, temb):
        h = self.block1(x)
        # Out-of-place add (the original used the in-place `h +=` on an
        # autograd intermediate); broadcasts over H and W.
        h = h + self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)
        h = h + self.shortcut(x)
        return self.attn(h)
class Attention(nn.Module):
def __init__(self,dim,num_heads=8):
super(Attention, self).__init__()
self.num_heads = num_heads
head_dim = dim // num_heads # 1024 // 8 = 128
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3) # qkv经过一个linear得到。 768 --> 2304
self.proj = nn.Linear(dim, dim) # 一次新的映射
def forward(self, x):
B, C, H, W = x.shape
N = H*W
x = x.reshape(B,C,-1) # [B,C,N]
x = torch.einsum('xyz->xzy',x) # [B,N,C]
qkv = self.qkv(x) # [B,N,3*C]
qkv = qkv.reshape(B,N,3,self.num_heads,C//self.num_heads) # [B,N,3,8,128]
qkv = torch.einsum('abcde->cadbe',qkv) # [3,B,8,N,128]
q, k, v = qkv[0], qkv[1], qkv[2] # [B,8,N,128]
attn = (q @ torch.einsum('abcd->abdc',k)) * self.scale
# 按行进行softmax
attn = attn.softmax(dim=-1)
x = (attn @ v) # [B,8,N,128]
x = torch.einsum('abcd->acbd',x) # [B,N,8,128]
x = x.reshape(B,N,C) # [B,N,1024]
x = self.proj(x)
x = x.reshape(B,H,W,C) # [B,H,W,C]
x = torch.einsum('abcd->adbc',x)
return x
class UNet(nn.Module):
    """Simplified diffusion U-Net that also predicts the IDDPM variance weight.

    ``forward(x, t)`` returns ``(noise, var_weight)``: ``noise`` has
    ``out_ch`` channels and the same spatial size as ``x``; ``var_weight``
    is a per-sample scalar in (0, 1) taken from the bottleneck features.

    Args:
        T: number of diffusion steps (size of the time-embedding table).
        in_ch: channels of the input image.
        out_ch: channels of the predicted noise.
        dropout: dropout rate inside every ResBlock.
        tdim: width of the time embedding fed to every ResBlock.
    """

    def __init__(self, T=1000, in_ch=3, out_ch=3, dropout=0.1, tdim=512):
        super(UNet, self).__init__()
        # Time embedding: [B] step indices -> [B, tdim].
        self.time_embedding = TimeEmbedding(T, 128, tdim)
        # Encoder.  Bug fix: the first block used to hard-code 3 input
        # channels and silently ignore the in_ch parameter.
        self.res1 = ResBlock(in_ch, 64, tdim, dropout)
        self.pool1 = nn.MaxPool2d(2)
        self.res2 = ResBlock(64, 128, tdim, dropout)
        self.pool2 = nn.MaxPool2d(2)
        self.res3 = ResBlock(128, 256, tdim, dropout)
        self.pool3 = nn.MaxPool2d(2)
        self.res4 = ResBlock(256, 512, tdim, dropout)
        self.pool4 = nn.MaxPool2d(2)
        # Bottleneck, with self-attention.
        self.midblock = ResBlock(512, 1024, tdim, dropout, attn=True)
        # Head predicting the IDDPM linear variance weight from the pooled
        # bottleneck features (drop this head for plain DDPM).
        self.var_weight = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1024, 512),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )
        # Decoder: transposed convs upsample; each skip connection doubles
        # the channel count before its ResBlock.
        self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.res5 = ResBlock(1024, 512, tdim, dropout)
        self.up2 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.res6 = ResBlock(512, 256, tdim, dropout)
        self.up3 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.res7 = ResBlock(256, 128, tdim, dropout)
        self.up4 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.res8 = ResBlock(128, 64, tdim, dropout)
        # Output projection to the noise prediction.
        self.out = nn.Conv2d(64, out_ch, 1)

    def forward(self, x, t):
        temb = self.time_embedding(t)
        # Encoder.
        x1 = self.res1(x, temb)
        x2 = self.res2(self.pool1(x1), temb)
        x3 = self.res3(self.pool2(x2), temb)
        x4 = self.res4(self.pool3(x3), temb)
        # Bottleneck, with self-attention.
        out = self.midblock(self.pool4(x4), temb)
        var_weight = self.var_weight(out)  # predicted linear variance weight
        # Decoder with skip connections.
        out = self.up1(out)
        out = self.res5(torch.cat([out, x4], dim=1), temb)
        out = self.up2(out)
        out = self.res6(torch.cat([out, x3], dim=1), temb)
        out = self.up3(out)
        out = self.res7(torch.cat([out, x2], dim=1), temb)
        out = self.up4(out)
        out = self.res8(torch.cat([out, x1], dim=1), temb)
        out = self.out(out)
        # Predicted noise and per-sample variance weight of shape [B].
        return out, var_weight.squeeze(1)