文章目录
1、UNet网络结构
UNet网络的总体框架如下,右边是UNet网络的整体框架,左边是residual网络和attention网络,
下面是UNet网络的详解结构图,左边进行有规律地残差、下采样、attention,右边也是有规律地残差、上采样、attention,相关的代码在图中给出,
1.1 residual网络和attention网络的细节
熟悉CNN的同学应该能看懂下图中的大部分过程。其中的 t 是时间从0到1000的随机值,假如是888,经过Positional Embedding输出长度是128的向量,下面再经过全连接层和silu层等,下面会详细讲解Positional Embedding、residual网络和attention网络,
1.2 t 的作用
1、和原图像一起,计算出 t 时刻的图像 x t = 1 − α t ‾ ϵ + α t ‾ x 0 x_t=\sqrt{1-\overline{\alpha_t}}\epsilon+\sqrt{\overline{\alpha_t}}x_0 xt=1−αtϵ+αtx0
2、将 t 进行编码,编码后,加到模型中,使模型学习到当前在哪个时刻
1.3 DDPM 中的 Positional Embedding 的使用
左图是Transformer的Positional Embedding,行索引代表第几个单词,列索引代表每个单词的特征向量,右图是DDPM的Positional Embedding,DDPM的Positional Embedding和Transformer的Positional Embedding的区别是DDPM的Positional Embedding并不是给每个词位置编码的,只需要在1000行中随机取出一行就可以了;另一个区别是DDPM的Positional Embedding并没有按照奇数位和偶数位进行拼接,而是按照前后的sin和cos进行拼接的,虽然拼接方式不同,但是最终的效果是一样的。如下图所示,
位置编码只要能保证每一行的唯一性,以及每一行和其他行的关系性就可以了。
1.4 DDPM 中的 Positional Embedding 代码
代码:
class PositionalEmbedding(nn.Module):
__doc__ = r"""..."""
def init (self, dim, scale=1.0):
super().__init__()
assert dim % 2 == 0
self.dim = dim # 特征向量
self.scale = scale # 正弦函数和余弦函数的周期不做调整
def forward(self, x): # x:表示t,从0-1000中随机出来的一个数值,因为设置batch-size=2,所以假设x:tensor([645,958])
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / half_dim
emb = torch.exp(torch.arange(half_dim, device=device) * - emb)
emb = torch.outer(x * self.scale, emb)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
代码解释:
下图中的 e m b 2 × 64 emb_{2\times64} emb2×64中的2表示batch-size等于2,
使用位置:
1.5 residual block
原代码:
class ResidualBlock(nn.Module):
__doc__ = r"""Applies two conv blocks with resudual connection. Adds time and class conditioning by adding bias after first convolution.
Input:
x: tensor of shape (N, in_channels, H, W)
time_emb: time embedding tensor of shape (N, time_emb_dim) or None if the block doesn't use time conditioning
y: classes tensor of shape (N) or None if the block doesn't use class conditioning
Output:
tensor of shape (N, out_channels, H, W)
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None
num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
activation (function): activation function. Default: torch.nn.functional.relu
norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
num_groups (int): number of groups used in group normalization. Default: 32
use_attention (bool): if True applies AttentionBlock to the output. Default: False
"""
def __init__(
self,
in_channels,
out_channels,
dropout,
time_emb_dim=None,
num_classes=None,
activation=F.relu,
norm="gn",