如何理解图像生成(diffusion)中“标准分布的累计分布函数的差分去模拟离散的高斯分布”?

对于如何理解图像生成(diffusion)中“标准分布的累计分布函数的差分去模拟离散的高斯分布”?这个问题的回答如下:
图像生成需要计算整个图片的似然概率的大小,但是通常来说(比如diffusion中求L0)能获得的通常都是图片对应的连续高斯分布(均值和方差),但是图片的像素值是0-255的离散值,因此计算图片的似然需要离散的高斯分布。如何通过连续的分布获取离散的分布有不同的处理方法,其中一种做法的定义是离散的p(x)=连续高斯分布(PDF)概率密度函数中以x为中心一个单位所占的面积大小,此时计算面积大小就可以通过累积分布函数(CDF)去算,具体为离散的p(x)=CDF(x+半个单位)-CDF(x-半个单位),这就是所谓的差分。值得注意的是,累积分布函数通常都是标准正态分布才能近似,因此如果你有的是一个高斯分布(非标准),需要将其归一化后转变为标准正态分布再进行后续处理。

def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    '''
    高斯分布概率密度函数(PDF)和累积分布函数(CDF)
    x是由[0...255]rescaled到[-1,1]的,也就是说相邻的像素值的差值大小为2/255.0,也就是一个单位大小为2/255.0;
    已知x所属高斯分布函数的均值和方差,由于像素值是离散的值,所以需要获得离散的高斯分布函数;
    定义:离散的p(x)等于连续高斯分布概率密度函数中[x-1/255,x+1/255]所占的面积(即以x为中心一个单位内的面积)
    要求这个值就需要获得其累积分布函数,由其累积分布函数的差分得到,离散的p(x)=CDF(x+半个单位)-CDF(x-半个单位);
    由于高斯分布概率密度函数是不可积的,无原函数,因此只能通过近似得到,并且通常是近似标准的高斯分布的累积分布函数
    因此需要将已知的高斯分布转换到标准的高斯分布,再通过标准的高斯分布的累积分布函数的差分获得离散的p(x);
    当x < -0.999时,考虑到这个位置概率较小,所以其离散的值不是以x为中心一个单位内的面积,而是x+半个单位左边的全部面积,即log_cdf_plus
    当x>0.9999时,考虑到这个位置概率较小,所以其离散的值不是以x为中心一个单位内的面积,而是x-半个单位右边的全部面积,即1-cdf_min(log_log_one_minus_cdf_min);
    '''
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs

这个函数就是IDDPM中计算L0损失时实现的“标准分布的累计分布函数的差分去模拟离散的高斯分布”函数。

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值