像素级创意:深入浅出PixelCNN图像合成技术

参考
https://arxiv.org/pdf/1601.06759
https://blog.csdn.net/zcyzcyjava/article/details/126559327
需要熟悉熵的一些理论、和极大释然估计等价于最小化交叉熵等知识

1. pixelcnn建模方法

pixelcnn做生成模型的想必都有耳闻。它是一种自回归模型,什么是自回归呢?简单的来说自回归模型意味着模型在生成数据时会条件依赖于之前已经生成的数据部分。我们知道无论是GAN还是VAE最初都在找一个思路,那就是想对 p θ ( x ) p_\theta(x) pθ(x)去建模型,事实上输入x的分布是难以确定的,因此,GAN和VAE都绕过了这条路,通过引入额外的网络,避开直接求解 p θ ( x ) p_\theta(x) pθ(x),比如GAN是引入对抗网络D,VAE是引入编码网络。而pixcelcnn不一样,它算是直接暴力求解 p θ ( x ) p_\theta(x) pθ(x),是像素级别的求解,认为第i个像素是由前i-1个像素决定的,因此释然函数可以写成如下,假设有n*n个像素:
p ( x ) = ∏ i = 1 n 2 p ( x i ∣ x 1 , x 2 , . . . , x i − 1 ) = ∏ i = 1 n 2 p ( x i ∣ X < i ) p(x) = \prod_{i=1}^{n^2}p(x_i|x_1,x_2,...,x_{i-1})= \prod_{i=1}^{n^2}p(x_i|X_{<i}) p(x)=i=1n2p(xix1,x2,...,xi1)=i=1n2p(xiX<i)
我们知道,最大化释然函数参数的求解写成对数如下:
θ = a r g m a x θ ∑ i n 2 l o g f i ( x i ∣ θ ) \theta=\underset{\theta}{argmax}\sum_i^{n^2} logf_i(x_i|\theta) θ=θargmaxin2logfi(xiθ)
其中 f i ( x i ∣ θ ) f_i(x_i|\theta) fi(xiθ)为网络预测的结果,另外,作者假设每一个预测的像素值,为0~255中的一个分类,因此,预测分布等价于one-hot形式的多分类。损失函数进而转化为多分类的交叉熵问题,为: l o s s = E n t r y C r o s s ( y i , x i ) loss=EntryCross(y_i,x_i) loss=EntryCross(yi,xi) 其中 y i y_i yi为预测的像素值, x i x_i xi为输入的像素值。

2. 网络结构

在这里插入图片描述在这里插入图片描述
以上是论文中提及到的网络结构,我们这里只看CNN结构,不看RNN结构。看这个网络结构其实很简单,也就是77卷积(maskA)+多个卷积残差结构(33 maskB)+2个1*1卷积(maskB)。但是,这里面有两个不是我们认识的那个CNN,一个是maskA一个是maskB。
maskA:

在这里插入图片描述
maskB:
在这里插入图片描述
简单来说,maskA是不包含中心元素的上半部分卷积,maskB是包含中心元素的上半部分卷积。这样做的目的是什么,具体详细原由可看原论文,意思是这样做生成的feature map每一个像素的感受野只会看到它上半部分的像素,不包含自身像素,这也满足之前pixcelcnn的建模:每一个像素都是由其前面i-1个像素决定的。因此pixelcnn网络结构是确定的,跟其建模是一致的。

3 MINIST测试pixelcnn

3.1 maskcnn构建

pixelcnn 关键是mask卷积的构造,我们看一下具体如何实现:

class MaskConv2d(nn.Module):
    """ 通过使用 mask 来构建 maskA和maskB Conv2d,方法是通过mask乘上卷积的权重"""
    def __init__(self, conv_type, *args, **kwargs):
        """
        :param conv_type: maskA还是maskB
        :param args:
        :param kwargs:
        """
        super(MaskConv2d, self).__init__()
        self.conv = nn.Conv2d(*args, **kwargs)
        k_h, k_w = self.conv.weight.shape[-2:]
        mask = torch.zeros((k_h, k_w), dtype=torch.float32)

        # maskA
        mask[0:k_h//2] = 1
        mask[k_h//2, 0:k_w//2] = 1

        # maskB
        if conv_type == 'B':
            mask[k_h//2, k_w//2] = 1

        mask = mask.reshape((1,1,k_h, k_w))
        self.register_buffer('mask', mask, False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res

通过以上代码,我们可以看出,先构造一个nn.Conv2d,然后构建maskA和maskB的只有0和1值的大小为kernel size的矩阵,然后乘上nn.Conv2d的权重,来实现两种Mask CNN,非常简单。

3.2 整个网络结构

整个网络结构是按照论文里面构建的,不在多说,直接看代码:

import torch
import torch.nn as nn


class MaskConv2d(nn.Module):
    """ 通过使用 mask 来构建 maskA和maskB Conv2d,方法是通过mask乘上卷积的权重"""
    def __init__(self, conv_type, *args, **kwargs):
        """
        :param conv_type: maskA还是maskB
        :param args:
        :param kwargs:
        """
        super(MaskConv2d, self).__init__()
        self.conv = nn.Conv2d(*args, **kwargs)
        k_h, k_w = self.conv.weight.shape[-2:]
        mask = torch.zeros((k_h, k_w), dtype=torch.float32)

        # maskA
        mask[0:k_h//2] = 1
        mask[k_h//2, 0:k_w//2] = 1

        # maskB
        if conv_type == 'B':
            mask[k_h//2, k_w//2] = 1

        mask = mask.reshape((1,1,k_h, k_w))
        self.register_buffer('mask', mask, False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res


class ResidualBlock(nn.Module):
    """ 残差块 """

    def __init__(self, h, bn=True):
        super(ResidualBlock, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(2*h, h, 1)
        self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()

        self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()

        self.conv3 = nn.Conv2d(h, 2*h, 1)
        self.bn3 = nn.BatchNorm2d(2*h) if bn else nn.Identity()

    def forward(self, x):

        y = self.relu(x)
        y = self.conv1(y)
        y = self.bn1(y)

        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)

        y = self.relu(y)
        y = self.conv3(y)
        y = self.bn3(y)
        return x + y


class PixelCNN(nn.Module):
    def __init__(self, n_block=15, h=128, bn=True, color_level=256):
        super(PixelCNN, self).__init__()

        # 7*7 conv
        self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)
        self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()

        # residual
        self.residual_blocks = nn.ModuleList()
        for _ in range(n_block):
            self.residual_blocks.append(ResidualBlock(h, bn))
        self.relu = nn.ReLU()

        # 2个1*1 maskB,
        self.head = nn.Sequential(
            MaskConv2d('B', 2*h, h, 1),
            nn.ReLU(),
            MaskConv2d('B', h, h, 1),
            nn.ReLU(),
            nn.Conv2d(h, color_level, 1)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)

        for block in self.residual_blocks:
            x = block(x)
        x = self.relu(x)

        x = self.head(x)
        return x

if __name__ == '__main__':
    from torchinfo import summary
    pixelcnn = PixelCNN()

    summary(pixelcnn, input_size=(1, 1, 28, 28), depth=2)

打印看一下网络结构,如果想看详细的层级结构,把summary中的depth改成3

PixelCNN [1, 256, 28, 28] –
├─MaskConv2d: 1-1 [1, 256, 28, 28] –
│ └─Conv2d: 2-1 [1, 256, 28, 28] 12,800
├─BatchNorm2d: 1-2 [1, 256, 28, 28] 512
├─ModuleList: 1-3 – –
│ └─ResidualBlock: 2-2 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-3 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-4 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-5 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-6 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-7 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-8 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-9 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-10 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-11 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-12 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-13 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-14 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-15 [1, 256, 28, 28] 214,528
│ └─ResidualBlock: 2-16 [1, 256, 28, 28] 214,528
├─ReLU: 1-4 [1, 256, 28, 28] –
├─Sequential: 1-5 [1, 256, 28, 28] –
│ └─MaskConv2d: 2-17 [1, 128, 28, 28] 32,896
│ └─ReLU: 2-18 [1, 128, 28, 28] –
│ └─MaskConv2d: 2-19 [1, 128, 28, 28] 16,512
│ └─ReLU: 2-20 [1, 128, 28, 28] –
│ └─Conv2d: 2-21 [1, 256, 28, 28] 33,024

3.2 训练

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F


class MaskConv2d(nn.Module):
    """ 通过使用 mask 来构建 maskA和maskB Conv2d,方法是通过mask乘上卷积的权重"""
    def __init__(self, conv_type, *args, **kwargs):
        """
        :param conv_type: maskA还是maskB
        :param args:
        :param kwargs:
        """
        super(MaskConv2d, self).__init__()
        self.conv = nn.Conv2d(*args, **kwargs)
        k_h, k_w = self.conv.weight.shape[-2:]
        mask = torch.zeros((k_h, k_w), dtype=torch.float32)

        # maskA
        mask[0:k_h//2] = 1
        mask[k_h//2, 0:k_w//2] = 1

        # maskB
        if conv_type == 'B':
            mask[k_h//2, k_w//2] = 1

        mask = mask.reshape((1,1,k_h, k_w))
        self.register_buffer('mask', mask, False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res


class ResidualBlock(nn.Module):
    """ 残差块 """

    def __init__(self, h, bn=True):
        super(ResidualBlock, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(2*h, h, 1)
        self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()

        self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()

        self.conv3 = nn.Conv2d(h, 2*h, 1)
        self.bn3 = nn.BatchNorm2d(2*h) if bn else nn.Identity()

    def forward(self, x):

        y = self.relu(x)
        y = self.conv1(y)
        y = self.bn1(y)

        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)

        y = self.relu(y)
        y = self.conv3(y)
        y = self.bn3(y)
        return x + y


class PixelCNN(nn.Module):
    def __init__(self, n_block=15, h=128, bn=True, color_level=256):
        super(PixelCNN, self).__init__()

        # 7*7 conv
        self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)
        self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()

        # residual
        self.residual_blocks = nn.ModuleList()
        for _ in range(n_block):
            self.residual_blocks.append(ResidualBlock(h, bn))
        self.relu = nn.ReLU()

        # 2个1*1 maskB,
        self.head = nn.Sequential(
            MaskConv2d('B', 2*h, h, 1),
            nn.ReLU(),
            MaskConv2d('B', h, h, 1),
            nn.ReLU(),
            nn.Conv2d(h, color_level, 1)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)

        for block in self.residual_blocks:
            x = block(x)
        x = self.relu(x)

        x = self.head(x)
        return x


def train(num_epochs, batch, gpuid):
    device = torch.device(f"cuda:{gpuid}")
    trian_data = datasets.MNIST(root='data', train=True, download=True, transform=transforms.ToTensor())
    train_dataloader = DataLoader(trian_data, batch_size=batch, shuffle=True)
    model = PixelCNN()
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        for x, _ in train_dataloader:
            x = x.to(device)
            label = torch.ceil(x*255).long()
            label = label.squeeze(1)
            loss = loss_fn(model(x), label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"epoch:{epoch}, loss:{loss.item()}")

        sample(model, device, 64)


def sample(model, device, n_sample=64):
    model.eval()
    C, H, W = (1, 28, 28)
    x = torch.zeros((n_sample, C, H, W)).to(device)
    with torch.no_grad():
        for i in range(H):
            for j in range(W):
                output = model(x)
                prob_dist = F.softmax(output[:,:,i,j], dim=1).data
                pixel = torch.multinomial(prob_dist, 1).float() / 255
                x[:,:,i,j] = pixel

    # Saving images row wise
    torchvision.utils.save_image(x, 'imgs.png', nrow=8, padding=0)


if __name__ == '__main__':

    train(100, 128, 0)

上面是整个完整代码,训练很简单,就是比较常规,看一眼都能明白。关键是采样算法,采样算法的过程是,先初始一个全0的图像,然后得到第一个像素,然后把第一个像素赋值给输入图像的第一个元素,以此类推得到最终的生成图像。关键步骤为下面三步:

  • prob_dist = F.softmax(output[:,:,i,j], dim=1).data 得到当前像素的概率分布
  • pixel = torch.multinomial(prob_dist, 1).float() / 255 从这个概率分布中随机采样一个值对应的索引,也就是一个0-255的像素值,因为softmax返回的是0-255(其实是/255归一化后)对应的概率,因此取索引值即为推理的像素值,然后/255归一化成minist输入
  • x[:,:,i,j] = pixel,这个就是把推理的当前像素值赋值给输入x,去推理下一帧输入

结果:质量上还是比GAN差。以下是50个epoch的结果。
在这里插入图片描述

4 缺点

pixelCNN作为一类基于卷积神经网络的生成模型,在图像生成领域有着其独特之处,但也存在一些缺点,主要包括但不限于:

  • 生成速度慢:PixelCNN的核心缺点之一是生成速度缓慢。因为它采用自回归的方式生成图像,即模型需要依次生成每一个像素,每个像素的生成都依赖于之前的所有像素。这种方式导致在生成高分辨率图像时,所需的计算时间和步骤显著增加。

  • 训练时间长:据报道,即便是增强版的PixelCNN如PixelCNN++,也需要在强大的硬件配置(如8块Titan GPU)上训练多天才能收敛,而且这还仅是在处理相对较小的数据集(如CIFAR)时的情况。训练时间长不仅增加了资源消耗,也影响了研究与应用的效率。

  • 采样效率低:由于逐像素生成的特性,PixelCNN在采样过程中无法并行化,这意味着即使在现代GPU上也无法有效利用硬件加速带来的并行计算优势,进一步降低了生成效率。

  • 长程依赖建模能力有限:尽管PixelCNN使用了卷积层来捕捉局部特征,但自回归的生成顺序限制了模型对图像中远距离像素间依赖关系的建模能力,可能影响生成图像的全局一致性与细节丰富度。

  • 内存使用:逐像素生成的过程中需要存储中间状态以供后续像素生成使用,这可能导致较高的内存使用,尤其是在处理大尺寸图像时。

  • 样本质量:相比于同期的一些生成对抗网络(GANs)模型,PixelCNN生成的样本质量可能略逊一筹,尤其是在生成高保真度和视觉复杂度高的图像方面。

尽管有这些缺点,PixelCNN及其后续变体通过引入如门控机制、更高效的网络结构设计等方法,在一定程度上改善了这些问题,并在图像生成任务中保持着一定的竞争力。

  • 17
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

idealmu

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值