A diffusion model involves two processes applied to an image:
a fixed (non-learned) forward diffusion process q, which gradually adds Gaussian noise to an image until it becomes pure noise;
a learned reverse denoising process p, which gradually removes that noise.
训练时，我们从真实的、未知且可能很复杂的数据分布 q(x_0) 中随机抽取一个样本 x_0；
在 1 到 T 之间均匀采样一个噪声水平 t（即随机时间步）；
从高斯分布中采样噪声，并利用前向过程的闭式性质 q(x_t | x_0) 在时间步 t 上对输入加噪，得到 x_t；
训练神经网络根据加噪后的图像 x_t 预测所加入的噪声，即按已知的噪声调度（schedule）施加在 x_0 上得到 x_t 的那份噪声。
U-Net 模型先对输入进行下采样（空间分辨率变小），再进行上采样（空间分辨率恢复）。
class SinusoidalPositionEmbeddings(nn.Cell):
    """Sinusoidal embedding of diffusion time steps.

    Maps a 1-D tensor of time steps of shape ``(batch,)`` to embeddings of
    shape ``(batch, 2 * (dim // 2))``: the first half of the columns are sines
    and the second half cosines of the time step multiplied by geometrically
    spaced frequencies ``10000 ** (-k / (dim // 2 - 1))``.

    Args:
        dim: embedding size. Must be >= 4; for an odd ``dim`` the output has
            ``dim - 1`` columns.

    Raises:
        ValueError: if ``dim`` < 4 (the frequency spacing would divide by 0).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        if half_dim < 2:
            # half_dim - 1 is used as a divisor below.
            raise ValueError(f"dim must be >= 4, got {dim}")
        emb = math.log(10000) / (half_dim - 1)
        emb = np.exp(np.arange(half_dim) * -emb)
        self.emb = Tensor(emb, ms.float32)

    def construct(self, x):
        # (batch,) -> (batch, 1) times (1, half_dim) -> (batch, half_dim).
        emb = x[:, None] * self.emb[None, :]
        return ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)


# Backward-compatible alias for the original (misspelled) class name.
SinusodialPositionEmbeedings = SinusoidalPositionEmbeddings
Shape: time steps of shape (batch_size,) ----> embeddings of shape (batch_size, dim) (for even dim).
Next, we use a ConvNeXt block, which plays the same role as a ResNet block.
class Block(nn.Cell):
    """Conv3x3 -> GroupNorm -> optional scale/shift modulation -> SiLU.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        groups: number of groups used by GroupNorm.
    """

    def __init__(self, dim, dim_out, groups=1):
        super().__init__()
        # 3x3 convolution with explicit padding 1 keeps the spatial size.
        self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def construct(self, x, scale_shift=None):
        """Apply the block; ``scale_shift`` is an optional ``(scale, shift)`` pair."""
        x = self.proj(x)
        # The normalized output must replace x (it was previously discarded).
        x = self.norm(x)
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        return self.act(x)
class ConvNextBlock(nn.Cell):
    """ConvNeXt-style residual block with optional time-embedding conditioning.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        time_emb_dim: size of the time embedding; when None the block has no
            conditioning MLP.
        mult: channel expansion factor inside the block.
        norm: whether to apply GroupNorm before the first 3x3 convolution.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        # Projects the time embedding to `dim` channels so it can be added to
        # the depthwise-conv output.
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )
        # Depthwise 7x7 convolution (group=dim); padding 3 keeps the spatial size.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )
        # 1x1 projection on the skip path only when the channel count changes.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            # Assumes time_emb is (batch, time_emb_dim) -- TODO confirm; the two
            # trailing axes let the condition broadcast over H and W.
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition
        h = self.net(h)
        return h + self.res_conv(x)
Then the U-Net is assembled from these building blocks:
首先，在噪声图像上应用一个卷积层，并为噪声水平（时间步）计算位置嵌入；
接下来，应用一系列下采样阶段，每个阶段 = 2 个 ConvNeXt 块 + GroupNorm + attention + 残差连接 + 下采样；
在网络中间，应用 ResNet 或 ConvNeXt 块，并与 attention 交替；
然后，应用一系列上采样阶段，每个阶段 = 2 个 ResNet/ConvNeXt 块 + GroupNorm + attention + 残差连接 + 上采样；
最后，应用一个 ResNet 或 ConvNeXt 块，再接一个卷积层。
class Unet(nn.Cell):
    """U-Net that predicts the noise present in an image at a given time step.

    Args:
        dim: base channel count; encoder stage i outputs ``dim * dim_mults[i]``
            channels.
        init_dim: channels produced by the initial 7x7 convolution
            (defaults to ``dim // 3 * 2``).
        out_dim: output channels (defaults to ``channels``).
        dim_mults: channel multipliers of the successive resolution stages.
        channels: number of input image channels.
        with_time_emb: whether to condition the blocks on a time embedding.
        convnext_mult: channel expansion factor inside each ConvNextBlock.
    """

    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        convnext_mult=2,
    ):
        super().__init__()
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)

        # Channel count at each stage boundary and the (in, out) pair per stage.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ConvNextBlock, mult=convnext_mult)

        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.SequentialCell(
                SinusoidalPositionEmbeddings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.CellList([])
        self.ups = nn.CellList([])
        num_resolutions = len(in_out)

        # Encoder: two blocks + linear attention per stage; every stage except
        # the last ends with Downsample.
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                nn.CellList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        # Bottleneck at the deepest stage.
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        # Decoder: one stage per encoder stage except the first. Each stage
        # takes the current features concatenated with an encoder skip
        # connection, hence dim_out * 2 input channels.
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # There are only num_resolutions - 1 decoder stages, so this is
            # never true and every decoder stage upsamples.
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(
                nn.CellList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        # The decoder (or, with a single stage, the bottleneck) ends with
        # dims[1] channels; this equals `dim` for the default dim_mults.
        final_dim = dims[1]
        self.final_conv = nn.SequentialCell(
            block_klass(final_dim, final_dim), nn.Conv2d(final_dim, out_dim, 1)
        )

    def construct(self, x, time):
        x = self.init_conv(x)
        t = self.time_mlp(time) if exists(self.time_mlp) else None

        # Skip connections: one per encoder stage, taken before downsampling.
        h = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Consume the skips from the deepest stage outwards. h[0] is unused
        # because there is one fewer decoder stage than encoder stages.
        skip_idx = len(h) - 1
        for block1, block2, attn, upsample in self.ups:
            x = ops.concat((x, h[skip_idx]), 1)
            skip_idx -= 1
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)
That completes the U-Net.
Next, we return to the forward process q and the reverse process p of the diffusion model.
First, we define a variance (beta) schedule over the time steps:
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` noise variances spaced linearly from start to end.

    Args:
        timesteps: number of diffusion steps T.
        beta_start: beta of the first step (default 1e-4).
        beta_end: beta of the last step (default 0.02).

    Returns:
        A float32 NumPy array of shape ``(timesteps,)``.
    """
    return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)
# Number of diffusion steps T and the per-step noise variances beta_t.
timesteps = 200
betas = linear_beta_schedule(timesteps=timesteps)

# alpha_t = 1 - beta_t and its running product alpha_bar_t.
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
# alpha_bar_{t-1}, with the value before the first step defined as 1.
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)

# Per-step coefficients consumed by q_sample / p_sample via extract().
sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))

# Variance of the posterior q(x_{t-1} | x_t, x_0); it is 0 at the first step.
posterior_varience = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# P2 loss weighting: (k + SNR_t) ** -gamma with SNR_t = alpha_bar_t / (1 - alpha_bar_t).
# gamma = 0 gives a uniform weight of 1 at every step. (The previous expression
# applied the exponent only to the denominator and produced 1 + alpha_bar_t.)
p2_loss_weight_gamma = 0.
p2_loss_weight_k = 1.
p2_loss_weight = (p2_loss_weight_k + alphas_cumprod / (1. - alphas_cumprod)) ** -p2_loss_weight_gamma
p2_loss_weight = Tensor(p2_loss_weight)
def extract(a, t, x_shape):
    """Pick ``a[t]`` for each batch element, reshaped to broadcast against x.

    Args:
        a: per-time-step coefficients (array or tensor).
        t: tensor of time-step indices, one per batch element.
        x_shape: shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` axes.
    """
    batch = t.shape[0]
    picked = Tensor(a).gather(t, -1)
    singleton_axes = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_axes)
Some auxiliary functions:
def randn_like(x, dtype=None):
    """Return standard-normal noise with the same shape as ``x``.

    Args:
        x: reference tensor whose shape is matched.
        dtype: dtype of the result; defaults to ``x.dtype``.
    """
    if dtype is None:
        # Previously assigned to a misspelled name, leaving dtype as None.
        dtype = x.dtype
    return ops.standard_normal(x.shape).astype(dtype)
def randn(shape, dtype=None):
    """Return a tensor of the given shape filled with N(0, 1) samples.

    ``dtype`` defaults to ``ms.float32`` when not given.
    """
    target_dtype = ms.float32 if dtype is None else dtype
    return ops.standard_normal(shape).astype(target_dtype)
Then the forward diffusion process:
def q_sample(x_start, t, noise=None):
    """Sample x_t from q(x_t | x_0) in closed form.

    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

    Args:
        x_start: clean inputs x_0.
        t: tensor of time-step indices, one per batch element.
        noise: optional pre-drawn noise; standard-normal noise is drawn if None.
    """
    noise = randn_like(x_start) if noise is None else noise
    shape = x_start.shape
    signal_coef = extract(sqrt_alphas_cumprod, t, shape)
    noise_coef = extract(sqrt_one_minus_alphas_cumprod, t, shape)
    return signal_coef * x_start + noise_coef * noise
def get_noisy_image(x_start, t):
    """Noise ``x_start`` up to step ``t`` and convert the first sample for viewing.

    ``compose`` and ``reverse_transform`` are defined elsewhere; presumably they
    turn a tensor back into a displayable image -- TODO confirm.
    """
    first_noisy = q_sample(x_start, t=t)[0]
    return compose(reverse_transform, first_noisy)
Training loss (p_losses) and reverse-process sampling:
def p_losses(unet_model, x_start, t, noise=None):
    """Weighted Smooth-L1 loss between the true and the predicted noise.

    Args:
        unet_model: network called as ``unet_model(x_noisy, t)``.
        x_start: clean inputs x_0.
        t: tensor of time-step indices, one per batch element.
        noise: optional pre-drawn noise; drawn from N(0, 1) when None.

    Returns:
        Scalar mean of the per-sample-weighted loss.
    """
    if noise is None:
        noise = randn_like(x_start)
    noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted = unet_model(noisy, t)
    # Relies on SmoothL1Loss's default reduction returning an element-wise
    # loss, which is then flattened to one row per sample.
    elementwise = nn.SmoothL1Loss()(noise, predicted)
    per_sample = elementwise.reshape(elementwise.shape[0], -1)
    weighted = per_sample * extract(p2_loss_weight, t, per_sample.shape)
    return weighted.mean()
def p_sample(model, x, t, t_index):
    """One reverse step: draw x_{t-1} given x_t and the predicted noise.

    Args:
        model: noise-prediction network, called as ``model(x, t)``.
        x: current sample x_t.
        t: tensor of time-step indices, one per batch element.
        t_index: the same step as a Python int; at 0 the mean is returned
            without adding noise.
    """
    shape = x.shape
    beta = extract(betas, t, shape)
    sqrt_one_minus_ab = extract(sqrt_one_minus_alphas_cumprod, t, shape)
    inv_sqrt_alpha = extract(sqrt_recip_alphas, t, shape)
    predicted_eps = model(x, t)
    mean = inv_sqrt_alpha * (x - beta * predicted_eps / sqrt_one_minus_ab)

    if t_index == 0:
        return mean

    variance = extract(posterior_varience, t, shape)
    fresh_noise = randn_like(x)
    return mean + ops.sqrt(variance) * fresh_noise
def p_sample_loop(model, shape):
    """Run the full reverse process starting from pure noise.

    Returns a list of ``timesteps`` NumPy arrays (one per step, from step
    ``timesteps - 1`` down to 0); the last entry is the final sample.
    """
    batch = shape[0]
    img = randn(shape, dtype=None)
    history = []
    steps = tqdm(reversed(range(0, timesteps)), desc="sampling loop time step", total=timesteps)
    for step in steps:
        step_tensor = ms.numpy.full((batch,), step, dtype=mstype.int32)
        img = p_sample(model, img, step_tensor, step)
        history.append(img.asnumpy())
    return history
def sample(model, image_size, batch_size=16, channels=3):
    """Generate ``batch_size`` square images of side ``image_size``.

    Returns the per-step list produced by ``p_sample_loop``.
    """
    target_shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=target_shape)