Flow Model

Flow Model

无监督,生成模型

动机

  1. 拟合连续分布
  2. 给出 x x x可以计算 p ( x ) p(x) p(x)
  3. 从拟合到的分布中采样得到接近于真实的 x x x
  4. 学习到latent representation

方法概要

离散分布由于取值往往是有限个,所以总能在得到每个取值的logits之后做softmax保证得到的概率恒正且和为1(满足分布的条件),但是连续分布无限个取值不能做softmax,所以将其映射到z上并使z满足一个设计好的先验分布(感觉有点像连续版本的softmax),同样用极大似然的方法估计参数。
在这里插入图片描述
在做极大似然的时候,需要注意我们最后目标是 p ( x ) p(x) p(x),但是我们模型得到的是 p ( z ) p(z) p(z),所以存在一个 p ( x ) p(x) p(x) p ( z ) p(z) p(z)之间的转换
在这里插入图片描述在这里插入图片描述
换元的f要求可微且可逆
可逆给出一个intuition的解释
∫ − ∞ + ∞ p ( z ) d z = ∫ − ∞ + ∞ p ( x ) d x ∫ − ∞ + ∞ p ( f ( x ) ) d f ( x ) = ∫ − ∞ + ∞ p ( x ) d x ∫ − ∞ + ∞ p ( f ( x ) f ′ ( x ) d x = ∫ − ∞ + ∞ p ( x ) d x p ( f ( x ) ) f ′ ( x ) = p ( x ) \int_{-\infty}^{+\infty}p(z)dz = \int_{-\infty}^{+\infty}p(x)dx\\ \int_{-\infty}^{+\infty}p(f(x))df(x) = \int_{-\infty}^{+\infty}p(x)dx\\ \int_{-\infty}^{+\infty}p(f(x)f'(x)dx = \int_{-\infty}^{+\infty}p(x)dx\\ p(f(x))f'(x)=p(x) +p(z)dz=+p(x)dx+p(f(x))df(x)=+p(x)dx+p(f(x)f(x)dx=+p(x)dxp(f(x))f(x)=p(x)由于概率恒正,所以需要 f ′ ( x ) > 0 f'(x)>0 f(x)>0恒成立,也就需要 f f f可逆
由于使用概率分布的CDF巨有可逆可微的性质,而且导数恰为PDF方便计算,所以通常用CDF作为 f f f

1-D

在这里插入图片描述
从一个连续分布中采样,红色表示随机变量 x x x的取值,粉色表示该值取到的次数(注意这里的柱状图表示在一定范围内的采样数之和,不然没法画图,取到多次像0.124231这样的离散值的概率很小),黄色表示经过 f f f映射后 z z z的取值,粉黄图表示的就是将训练集的 x x x映射到 z z z后的分布情况。映射函数 f f f通常为混合高斯或混合logistics的CDF。
在这里插入图片描述

2-D

在这里插入图片描述在这里插入图片描述在这里插入图片描述
类似Autoregressive Model的做法,不过离散版AR中MLP的 f ( x < i ) f(x_{< i}) f(x<i)得到的是概率,而这里得到的是用于计算CDF的对应分布的参数,所以2-D版本的ARFlow需要有的参数有

  1. 用于x1-CDF的5个logits,locs,scales
  2. 一个用于生成 f ( ; x 1 ) f(;x_1) f(;x1)的MLP(in:1( x 1 x_1 x1),out:3(logits,locs,scales)*5(5个Gaussian))
    在这里插入图片描述
    之后 x 1 , x 2 x_1,x_2 x1,x2根据自己分布的参数分别得到 z 1 , z 2 z_1,z_2 z1,z2,并且同时通过log_prob()得到 f ′ f' f
    在这里插入图片描述
    然后就可以通过极大似然估计优化参数

N-D

当输入输出的维度变高时,需要设计好 f f f满足:

  1. 可逆
  2. 便于计算Jacobian行列式
1.ARFlow

在这里插入图片描述
在这里插入图片描述
跟2-D的ARFlow类似,使用PixelCNN的Mask方法得到 f ( ; x < i ) f(;x_{<i}) f(;x<i),得到的是CDF分布的参数,所以高维ARFlow模型的参数只有一个Masked PixelCNN(in:H x W,out:H x W x n_comp*3(logits,locs,scales))在这里插入图片描述
在这里插入图片描述
得到H x W个分布的参数后计算CDF得到z,并通过PDF得到Jacobian的行列式
在这里插入图片描述

2.RealNVP

x x x z z z都是高维时,计算 p ( x ) p(x) p(x)的方法推广为
在这里插入图片描述
具体原因,因为1-D是用面积元推导出来的, p ( x ) d x = p ( z ) d z p(x)dx=p(z)dz p(x)dx=p(z)dz,而N-D则对应的体积元,行列式的几何意义就是空间的拉伸程度。
当维度很高时,出现两个问题:

  1. 如何保证 f f f可逆
  2. 如何高效计算行列式

使用affine flow
z = A − 1 ( x − b ) , x = A z + b z=A^{-1}(x-b),x=Az+b z=A1(xb),x=Az+b
每个flow固定一半x,并且通过固定的一半x生成A和b,并对另一半x做affine
在这里插入图片描述
在这里插入图片描述
生成A的时候使用元素级的方法,保证A为对角阵,这样将两半z合并后对原来的x求Jacobian得到的是一个三角阵,左上块因为固定的x,所以求导为单位矩阵,因此可以很容易的计算得到整个的Jacobian行列式为
在这里插入图片描述
如何选择哪一半固定,哪一半做Affine对结果影响很大,文章给出了棋盘式和channel式两种。棋盘式就是像国际象棋的棋盘中黑白的颜色区分;channel式就是沿着channel分。
在这里插入图片描述

棋盘式类似pixelCNN中的Mask方式,做一个0-1mask;channel直接沿着channel分割两份,因为affine是element-wise,MLP输出2*C x H x W,2代表 g θ g_\theta gθ得到的scale和shift
在这里插入图片描述
整体结构如下,做固定x的和做affinex的是交换的,不然到最后有一边的x是一点没变的。

class WeightNormConv2d(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, stride=1, padding=0,
                 bias=True):
        super(WeightNormConv2d, self).__init__()
        self.conv = nn.utils.weight_norm(
            nn.Conv2d(in_dim, out_dim, kernel_size,
                      stride=stride, padding=padding, bias=bias))

    def forward(self, x):
        return self.conv(x)

# 取消了affine和running_mean/std的BN
class ActNorm(nn.Module):
    def __init__(self, n_channels):
        super(ActNorm, self).__init__()
        self.log_scale = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        self.shift = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        self.n_channels = n_channels
        self.initialized = False

    def forward(self, x, reverse=False):
        if reverse:
            return (x - self.shift) * torch.exp(-self.log_scale), self.log_scale
        else:
            if not self.initialized:
                self.shift.data = -torch.mean(x, dim=[0, 2, 3], keepdim=True)
                self.log_scale.data = - torch.log(
                    torch.std(x.permute(1, 0, 2, 3).reshape(self.n_channels, -1), dim=1).reshape(1, self.n_channels, 1,
                                                                                                 1))
                self.initialized = True
                result = x * torch.exp(self.log_scale) + self.shift
            return x * torch.exp(self.log_scale) + self.shift, self.log_scale

class ResnetBlock(nn.Module):
    def __init__(self, dim):
        super(ResnetBlock, self).__init__()
        self.block = nn.Sequential(
            WeightNormConv2d(dim, dim, (1, 1), stride=1, padding=0),
            nn.ReLU(),
            WeightNormConv2d(dim, dim, (3, 3), stride=1, padding=1),
            nn.ReLU(),
            WeightNormConv2d(dim, dim, (1, 1), stride=1, padding=0))

    def forward(self, x):
        return x + self.block(x)

# 用于从x生成s,t的MLP
class SimpleResnet(nn.Module):
    def __init__(self, in_channels=3, out_channels=6, n_filters=128, n_blocks=8):
        super(SimpleResnet, self).__init__()
        layers = [WeightNormConv2d(in_channels, n_filters, (3, 3), stride=1, padding=1),
                  nn.ReLU()]
        for _ in range(n_blocks):
            layers.append(ResnetBlock(n_filters))
        layers.append(nn.ReLU())
        layers.append(WeightNormConv2d(n_filters, out_channels, (3, 3), stride=1, padding=1))
        self.resnet = nn.Sequential(*layers)

    def forward(self, x):
        return self.resnet(x)

# 棋盘式
class AffineCheckerboardTransform(nn.Module):
    def __init__(self, type=1.0):
        super(AffineCheckerboardTransform, self).__init__()
        self.mask = self.build_mask(type=type)
        self.scale = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.scale_shift = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.resnet = SimpleResnet()

    def build_mask(self, type=1.0):
        # if type == 1.0, the top left corner will be 1.0
        # if type == 0.0, the top left corner will be 0.0
        mask = np.arange(32).reshape(-1, 1) + np.arange(32)
        # 取模得到0-1mask
        mask = np.mod(type + mask, 2)
        mask = mask.reshape(-1, 1, 32, 32)
        return torch.tensor(mask.astype('float32')).to(device)

    def forward(self, x, reverse=False):
        # returns transform(x), log_det
        batch_size, n_channels, _, _ = x.shape
        mask = self.mask.repeat(batch_size, 1, 1, 1)
        x_ = x * mask

        log_s, t = self.resnet(x_).split(n_channels, dim=1)
        log_s = self.scale * torch.tanh(log_s) + self.scale_shift
        t = t * (1.0 - mask)
        log_s = log_s * (1.0 - mask)

        if reverse:  # inverting the transformation
            x = (x - t) * torch.exp(-log_s)
        else:
            x = x * torch.exp(log_s) + t
        return x, log_s

# channel式
class AffineChannelTransform(nn.Module):
    def __init__(self, modify_top):
        super(AffineChannelTransform, self).__init__()
        self.modify_top = modify_top
        self.scale = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.scale_shift = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.resnet = SimpleResnet(in_channels=6, out_channels=12)

    def forward(self, x, reverse=False):
        n_channels = x.shape[1]
        if self.modify_top:
            on, off = x.split(n_channels // 2, dim=1)
        else:
            off, on = x.split(n_channels // 2, dim=1)
        log_s, t = self.resnet(off).split(n_channels // 2, dim=1)
        log_s = self.scale * torch.tanh(log_s) + self.scale_shift

        if reverse:  # inverting the transformation
            on = (on - t) * torch.exp(-log_s)
        else:
            on = on * torch.exp(log_s) + t

        if self.modify_top:
            return torch.cat([on, off], dim=1), torch.cat([log_s, torch.zeros_like(log_s)], dim=1)
        else:
            return torch.cat([off, on], dim=1), torch.cat([torch.zeros_like(log_s), log_s], dim=1)

class RealNVP(nn.Module):
    def __init__(self):
        super(RealNVP, self).__init__()

        self.prior = torch.distributions.Normal(torch.tensor(0.).to(device), torch.tensor(1.).to(device))
        # 固定的和affine的一直在改变
        self.checker_transforms1 = nn.ModuleList([
            AffineCheckerboardTransform(1.0),
            ActNorm(3),
            AffineCheckerboardTransform(0.0),
            ActNorm(3),
            AffineCheckerboardTransform(1.0),
            ActNorm(3),
            AffineCheckerboardTransform(0.0)
        ])

        self.channel_transforms = nn.ModuleList([
            AffineChannelTransform(True),
            ActNorm(12),
            AffineChannelTransform(False),
            ActNorm(12),
            AffineChannelTransform(True),
        ])

        self.checker_transforms2 = nn.ModuleList([
            AffineCheckerboardTransform(1.0),
            ActNorm(3),
            AffineCheckerboardTransform(0.0),
            ActNorm(3),
            AffineCheckerboardTransform(1.0)
        ])
	# 拓宽channel好进行channel式的flow
    def squeeze(self, x):
        # C x H x W -> 4C x H/2 x W/2
        [B, C, H, W] = list(x.size())
        x = x.reshape(B, C, H // 2, 2, W // 2, 2)
        x = x.permute(0, 1, 3, 5, 2, 4)
        x = x.reshape(B, C * 4, H // 2, W // 2)
        return x

    def undo_squeeze(self, x):
        #  4C x H/2 x W/2  ->  C x H x W
        [B, C, H, W] = list(x.size())
        x = x.reshape(B, C // 4, 2, 2, H, W)
        x = x.permute(0, 1, 4, 2, 5, 3)
        x = x.reshape(B, C // 4, H * 2, W * 2)
        return x

    def g(self, z):
        # z -> x (inverse of f)
        x = z
        for op in reversed(self.checker_transforms2):
            x, _ = op.forward(x, reverse=True)
        x = self.squeeze(x)
        for op in reversed(self.channel_transforms):
            x, _ = op.forward(x, reverse=True)
        x = self.undo_squeeze(x)
        for op in reversed(self.checker_transforms1):
            x, _ = op.forward(x, reverse=True)
        return x

    def f(self, x):
        # maps x -> z, and returns the log determinant (not reduced)
        z, log_det = x, torch.zeros_like(x)
        for op in self.checker_transforms1:
            z, delta_log_det = op.forward(z)
            log_det += delta_log_det
        z, log_det = self.squeeze(z), self.squeeze(log_det)
        for op in self.channel_transforms:
            z, delta_log_det = op.forward(z)
            log_det += delta_log_det
        z, log_det = self.undo_squeeze(z), self.undo_squeeze(log_det)
        for op in self.checker_transforms2:
            z, delta_log_det = op.forward(z)
            log_det += delta_log_det
        return z, log_det

    def log_prob(self, x):
        z, log_det = self.f(x)
        return torch.sum(log_det, [1, 2, 3]) + torch.sum(self.prior.log_prob(z), [1, 2, 3])

    def sample(self, num_samples):
        z = self.prior.sample([num_samples, 3, 32, 32])
        return self.g(z)

Dequantization

谈的比较浅,意思就是像图像这样的离散数据 { 0 , 1 , 2...255 } \{0,1,2...255\} {0,1,2...255},如果用连续的方法去拟合会造成性能下降,所以对原始数据添加一个噪音会提高拟合离散数据的性能。
在这里插入图片描述

其他一些点

  1. 利用Pytorch的Distribution库计算GMM模型,B个数据,C个component,可以通过B x C的均值和方差初始化出来B*C个Normal,每个数据的每个component的 π \pi π通过B x C个logits做softmax(dim=1)得到,然后Normal* π \pi π并sum(dim=1)得到B个GMM。通过这种component分开计算的方式,可以利用pytorch自带的cdf、log_prob等方法。
    在这里插入图片描述
  2. ARFlow计算概率时使用GMM的CDF,如果想要sample时,需要计算GMM的CDF的逆函数。这里使用的方法是,令最后得到的z服从Uniform(0,1)。原来的sample就变成了从均匀分布中采样一个数,然后按CDF求逆得到对应的x。这个过程就跟从GMM中采样得到一个数是一样的,就不用就算GMM的CDF的逆了。而从GMM中采样,可以先根据logits采样,选择以哪个component作为center,然后从选择到的Gaussian采样得到最后的x。
    在这里插入图片描述
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
流模型是一种建立原始分布与简单分布之间映射关系的方法,通过对更简单的分布建模来绕开后验过于复杂无法求解的问题。其中一种流模型叫做Autoregressive Flow (ARFlow),它使用了重要的change of variables theorem来建立映射关系,并通过变化后的目标函数和梯度求解方法进行优化。ARFlow可以通过PixelCNN的Mask方法得到条件密度函数(CDF)分布的参数,并计算CDF得到变量z,并通过PDF计算Jacobian矩阵的行列式。另一种流模型是RealNVP,类似于Autoregressive Model,但不同之处在于RealNVP通过MLP获得用于计算CDF的对应分布的参数,然后根据这些参数计算变量z,并通过log_prob()方法得到概率密度函数。最后,如果处理离散数据,如图像中的像素值,连续方法可能会导致性能下降。因此,在流模型中可以给原始数据添加噪音来提高对离散数据的拟合性能。对于离散数据的建模,可以使用Gaussian Mixture Model (GMM),通过PyTorch的Distribution库计算GMM模型,其中每个数据的每个component的权重通过logits进行softmax计算。在ARFlow中,可以使用GMM的CDF计算概率,如果要进行采样,可以通过计算GMM的CDF的逆函数。采样过程类似于从GMM中采样一个数,可以先根据logits选择Gaussian component,然后从所选的Gaussian中采样得到最终的样本。<span class="em">1</span><span class="em">2</span><span class="em">3</span><span class="em">4</span>

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值