昇思25天学习打卡营第12天|Diffusion

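A diffusion model involves two processes over images:

        a fixed (predefined) forward diffusion process q, which gradually adds Gaussian noise to an image until only pure noise remains;

        a learned reverse denoising process p, in which a neural network is trained to gradually remove the noise, starting from pure noise;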

我们从真实未知和可能复杂的数据分布中随机抽取一个样本

我们均匀地采样1 和T 之间的噪声水平t 即是随机时间步长

我们从高斯分布中采样一些噪声,并使用上面定义的属性在t时间步上破坏输入

神经网络被训练以基于损坏的图像 x_t 来预测这种噪声,即基于已知的时间表 β_t 施加在 x_0 上的噪声

Unet 模型首先对输入进行下采样(空间分辨率变小),之后上采样

class SinusoidalPositionEmbeddings(nn.Cell):
    """Sinusoidal embedding of a batch of timesteps.

    A fixed table of ``dim // 2`` frequencies is built once at construction
    time; ``construct`` multiplies each timestep by every frequency and
    concatenates the sines and cosines along the last axis.

    Args:
        dim: Requested embedding width. The output width is
            ``2 * (dim // 2)``. ``dim`` of 2 or 3 makes ``half_dim - 1``
            zero and raises ``ZeroDivisionError``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        # Frequencies form a geometric progression from 1 down to 1/10000.
        emb = math.log(10000) / (half_dim - 1)
        emb = np.exp(np.arange(half_dim) * -emb)
        self.emb = Tensor(emb, ms.float32)

    def construct(self, x):
        """Return the embedding for ``x``.

        Args:
            x: Timesteps; indexed as ``x[:, None]``, so presumably a 1-D
                tensor of length batch_size (confirm against the caller).

        Returns:
            Tensor whose last axis holds ``sin`` values followed by ``cos``
            values of ``x * frequency``.
        """
        emb = x[:, None] * self.emb[None, :]
        emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
        return emb


# Backward-compatible alias for the original (misspelled) class name.
SinusodialPositionEmbeedings = SinusoidalPositionEmbeddings

The embedding module maps a tensor of shape (batch_size, 1) to a tensor of shape (batch_size, dim).

Next, we define the ConvNeXT block, which plays the same role here as a ResNet block.

class Block(nn.Cell):
    """3x3 convolution followed by GroupNorm, optional scale/shift, and SiLU.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        groups: Number of groups passed to ``nn.GroupNorm``.
    """

    def __init__(self, dim, dim_out, groups=1):
        super().__init__()
        # Padding of 1 with a 3x3 kernel keeps the spatial size unchanged.
        self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def construct(self, x, scale_shift=None):
        """Apply conv -> norm -> (optional) scale/shift -> activation.

        Args:
            x: Input feature map.
            scale_shift: Optional ``(scale, shift)`` pair; when given the
                normalized features become ``x * (scale + 1) + shift``.

        Returns:
            The activated feature map.
        """
        x = self.proj(x)
        # The normalized tensor must feed the rest of the pipeline.
        x = self.norm(x)
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        x = self.act(x)
        return x
class ConvNextBlock(nn.Cell):
    """ConvNeXT-style residual block with optional time conditioning.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        time_emb_dim: Width of the time embedding. When ``None`` no
            conditioning MLP is created.
        mult: Channel expansion factor of the hidden convolution.
        norm: Whether the inner network starts with a ``GroupNorm``.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        # Projects the time embedding to one value per input channel.
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )
        # group == dim makes this a depthwise 7x7 convolution.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )
        # A 1x1 convolution matches channel counts for the residual branch.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        """Run the block.

        Args:
            x: Input feature map.
            time_emb: Optional time embedding. It is used only when the
                block was built with ``time_emb_dim``.

        Returns:
            ``net(h) + res_conv(x)`` where ``h`` is the depthwise-convolved
            input, optionally shifted by the projected time embedding.
        """
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            # Add two trailing axes so the per-channel shift broadcasts
            # over the spatial dimensions.
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition
        h = self.net(h)
        return h + self.res_conv(x)

Next, the U-Net is assembled from these building blocks.

首先,卷积层应用在噪声图像上,并为噪声水平计算位置嵌入(position embedding)

接下来,应用一系列的下采样,每个下采样 = 2x ConvNeXT + groupnorm + attention + res connection + downsampling

网络中间应用Resnet 或 ConvNeXT block 与 attention 交织

下面应用一系列的上采样, 每个上采样由2个Resnet 和 Groupnorm  + attention + res connection + upsampling

最后 应用Resnet or ConvNeXT , 最后CONV

class Unet(nn.Cell):
    """U-Net built from ConvNeXT blocks with attention and skip connections.

    Args:
        dim: Base channel width; stage widths are ``dim * m`` for each ``m``
            in ``dim_mults``.
        init_dim: Channels produced by the first convolution. Defaults to
            ``dim // 3 * 2``.
        out_dim: Channels of the final output. Defaults to ``channels``.
        dim_mults: Per-stage channel multipliers.
        channels: Number of channels of the input image.
        with_time_emb: Whether to build the timestep-embedding MLP.
        convnext_mult: ``mult`` forwarded to every ``ConvNextBlock``.
    """

    def __init__(
            self,
            dim,
            init_dim=None,
            out_dim=None,
            dim_mults=(1, 2, 4, 8),
            channels=3,
            with_time_emb=True,
            convnext_mult=2,
    ):
        super().__init__()
        self.channels = channels
        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)

        # Channel widths of consecutive stages and their (in, out) pairs.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        block_klass = partial(ConvNextBlock, mult=convnext_mult)

        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.SequentialCell(
                # Spelled the way the embedding class is defined in this file.
                SinusodialPositionEmbeedings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.CellList([])
        self.ups = nn.CellList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                nn.CellList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # This loop has num_resolutions - 1 iterations, so is_last is
            # never True here and every up stage ends with an Upsample,
            # matching the num_resolutions - 1 Downsample layers above.
            is_last = ind >= (num_resolutions - 1)
            self.ups.append(
                nn.CellList(
                    [
                        # Input is doubled by the concatenated skip feature.
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.SequentialCell(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def construct(self, x, time):
        """Run the network on ``x`` at timesteps ``time``.

        Args:
            x: Input image batch.
            time: Timesteps fed to the time-embedding MLP (ignored when the
                model was built with ``with_time_emb=False``).

        Returns:
            Output of the final convolution.
        """
        x = self.init_conv(x)
        t = self.time_mlp(time) if exists(self.time_mlp) else None

        # Features saved before each downsample, consumed as skips below.
        h = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Walk the saved features from the deepest one backwards; the
        # starting index must be the last valid position of h.
        len_h = len(h) - 1
        for block1, block2, attn, upsample in self.ups:
            x = ops.concat((x, h[len_h]), 1)
            len_h -= 1
            # Condition the up blocks on time just like the down blocks.
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)
        return self.final_conv(x)

Above is the construction of Unet.

Next, we return to the forward process q and the reverse process p of diffusion.

First, we define the variance (beta) schedule over the timesteps:

   

def linear_beta_schedule(timesteps):
    """Return ``timesteps`` evenly spaced betas from 0.0001 to 0.02.

    Args:
        timesteps: Number of values in the schedule.

    Returns:
        A 1-D ``numpy.float32`` array of length ``timesteps``.
    """
    first_beta, last_beta = 0.0001, 0.02
    schedule = np.linspace(first_beta, last_beta, timesteps)
    return schedule.astype(np.float32)

# Number of diffusion steps and the per-step noise variances.
timesteps = 200
betas = linear_beta_schedule(timesteps=timesteps)

# Quantities derived from the schedule, one entry per timestep.
alphas = 1 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
# Cumulative product shifted right by one step, with 1 as the first entry.
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)
sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1 - alphas_cumprod))
# Kept as a NumPy array; name spelling preserved because other code reads it.
posterior_varience = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# Weight is (1 + SNR) ** -gamma with SNR = alphas_cumprod / (1 - alphas_cumprod);
# with the exponent -0. every weight evaluates to 1.
p2_loss_weight = (1 + alphas_cumprod / (1 - alphas_cumprod)) ** -0.
p2_loss_weight = Tensor(p2_loss_weight)

def extract(a, t, x_shape):
    """Pick the entries of ``a`` at indices ``t`` and shape them for broadcasting.

    Args:
        a: Per-timestep values; wrapped in ``Tensor`` before gathering.
        t: Index tensor; its first dimension is taken as the batch size.
        x_shape: Shape of the tensor the result will be combined with.

    Returns:
        The gathered values reshaped to ``(batch, 1, ..., 1)`` with
        ``len(x_shape) - 1`` trailing singleton axes.
    """
    batch = t.shape[0]
    picked = Tensor(a).gather(t, -1)
    singleton_axes = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_axes)

Some auxiliary functions for sampling Gaussian noise:

        

def randn_like(x, dtype=None):
    """Return standard-normal noise with the same shape as ``x``.

    Args:
        x: Tensor whose shape (and, by default, dtype) is copied.
        dtype: Target dtype; falls back to ``x.dtype`` when ``None``.

    Returns:
        A tensor drawn from ``ops.standard_normal`` cast to the target dtype.
    """
    if dtype is None:
        # Default to the dtype of the reference tensor.
        dtype = x.dtype
    return ops.standard_normal(x.shape).astype(dtype)
def randn(shape, dtype=None):
    """Return standard-normal noise of the given shape.

    Args:
        shape: Shape of the tensor to draw.
        dtype: Target dtype; ``ms.float32`` is used when ``None``.

    Returns:
        A tensor drawn from ``ops.standard_normal`` cast to the target dtype.
    """
    target_dtype = ms.float32 if dtype is None else dtype
    return ops.standard_normal(shape).astype(target_dtype)

then forward diffusion:

def q_sample(x_start, t, noise=None):
    """Forward diffusion: produce a noised version of ``x_start`` at step ``t``.

    Args:
        x_start: Clean input batch.
        t: Timestep indices, one per batch element.
        noise: Noise to mix in; drawn with ``randn_like(x_start)`` when
            ``None``.

    Returns:
        ``sqrt_alphas_cumprod[t] * x_start
        + sqrt_one_minus_alphas_cumprod[t] * noise``.
    """
    if noise is None:
        noise = randn_like(x_start)
    signal_scale = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_scale * x_start + noise_scale * noise
def get_noisy_image(x_start, t):
    """Noise ``x_start`` at step ``t`` and convert the first sample for viewing.

    Args:
        x_start: Clean input batch.
        t: Timestep indices passed to ``q_sample``.

    Returns:
        The result of ``compose(reverse_transform, ...)`` applied to the
        first element of the noised batch.
    """
    first_noisy = q_sample(x_start, t=t)[0]
    return compose(reverse_transform, first_noisy)

p_loss:

def p_losses(unet_model, x_start, t, noise=None):
    """Compute the noise-prediction training loss.

    Args:
        unet_model: Callable taking ``(x_noisy, t)`` and returning the
            predicted noise.
        x_start: Clean input batch.
        t: Timestep indices, one per batch element.
        noise: Target noise; drawn with ``randn_like(x_start)`` when ``None``.

    Returns:
        Mean of the SmoothL1 loss between the true and predicted noise,
        flattened per sample and multiplied by ``p2_loss_weight[t]``.
    """
    if noise is None:
        noise = randn_like(x_start)
    noisy_input = q_sample(x_start=x_start, t=t, noise=noise)
    noise_estimate = unet_model(noisy_input, t)

    criterion = nn.SmoothL1Loss()
    per_element = criterion(noise, noise_estimate)
    # Flatten everything except the leading (batch) axis.
    per_sample = per_element.reshape(per_element.shape[0], -1)
    weighted = per_sample * extract(p2_loss_weight, t, per_sample.shape)
    return weighted.mean()
def p_sample(model, x, t, t_index):
    """Perform one reverse (denoising) step.

    Args:
        model: Callable taking ``(x, t)`` and returning the predicted noise.
        x: Current noisy batch.
        t: Timestep indices, one per batch element.
        t_index: Integer timestep; when it is 0 no noise is added.

    Returns:
        The model mean at ``t_index == 0``; otherwise the mean plus
        ``sqrt(posterior_varience[t])`` times fresh Gaussian noise.
    """
    beta_t = extract(betas, t, x.shape)
    noise_scale_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    recip_sqrt_alpha_t = extract(sqrt_recip_alphas, t, x.shape)

    predicted_noise = model(x, t)
    mean = recip_sqrt_alpha_t * (x - beta_t * predicted_noise / noise_scale_t)

    if t_index != 0:
        variance_t = extract(posterior_varience, t, x.shape)
        fresh_noise = randn_like(x)
        return mean + ops.sqrt(variance_t) * fresh_noise
    return mean
def p_sample_loop(model, shape):
    """Run the full reverse process starting from Gaussian noise.

    Args:
        model: Noise-prediction model passed to ``p_sample``.
        shape: Shape of the batch to generate; ``shape[0]`` is the batch size.

    Returns:
        A list with one NumPy array per timestep, ordered from step
        ``timesteps - 1`` down to step 0.
    """
    batch = shape[0]
    current = randn(shape, dtype=None)
    snapshots = []
    countdown = reversed(range(0, timesteps))
    for step in tqdm(countdown, desc="sampling loop time step", total=timesteps):
        # Every batch element uses the same timestep index.
        step_batch = ms.numpy.full((batch,), step, dtype=mstype.int32)
        current = p_sample(model, current, step_batch, step)
        snapshots.append(current.asnumpy())
    return snapshots
def sample(model, image_size, batch_size=16, channels=3):
    """Generate square images with ``p_sample_loop``.

    Args:
        model: Noise-prediction model.
        image_size: Height and width of the generated images.
        batch_size: Number of images to generate.
        channels: Number of image channels.

    Returns:
        Whatever ``p_sample_loop`` returns for the shape
        ``(batch_size, channels, image_size, image_size)``.
    """
    target_shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=target_shape)

  • 5
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值