【GAN】CycleGAN学习--官方源码特殊地方

在cycle_gan_model.py的146行,如下代码中有self.fake_A_pool.query(self.fake_A)。

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

那么这个self.fake_A_pool.query(self.fake_A)究竟干了什么事情呢?

这个函数的具体实现在image_pool.py中,接下来我们一句句分析代码,看他具体想干嘛。

首先

ImagePool的初始化,将pool_size=50,其中self.num_imgs = 0和self.images = []是新增图片计数器和保存容器。

class ImagePool():
    """This class implements an image buffer that stores previously generated images.

    This buffer enables us to update discriminators using a history of generated images
    rather than the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        """Initialize the ImagePool class

        Parameters:
            pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []

接着

使用到了self.fake_A_pool.query(self.fake_A),可以看到每一次输入的图像都是self.fake_A,self.fake_A就是一个形状为(1,3,256,256)的生成图像,然后判别是否小于50,由于self.num_imgs的初始化为0,所以就进入该条件,self.num_imgs+1=1,然后self.images中添加了一张生成图片,返回图像为输入的self.fake_A。经过50次迭代后,self.images中就有了50张图片,进入到下面的else中,有50%的几率选中新输入的self.fake_A,还有50%的几率从self.images的50张切片中选出一张返回,并将新的一张图片给添加进去,使得self.images中的数量始终保持50。

    def query(self, images):
        """Return an image from the pool.

        Parameters:
            images: the latest generated images from the generator

        Returns images from the buffer.

        By 50/100, the buffer will return input images.
        By 50/100, the buffer will return images previously stored in the buffer,
        and insert the current images to the buffer.
        """
        if self.pool_size == 0:  # if the buffer size is 0, do nothing
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:   # if the buffer is not full; keep inserting current images to the buffer
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:       # by another 50% chance, the buffer will return the current image
                    return_images.append(image)
        return_images = torch.cat(return_images, 0)   # collect all the images and return
        return return_images

为什么需要这么做?

首先得了解,这个处理过程发生在训练判别器的阶段,判断器也就是分类器,为了能使这个分类器能力足够强,那么不管是你新生成的图片还是以前生成的图片,我都能判别的出来,这样的分类器才是好分类器。

那么这时就需要以前生成的部分fake图片和现在生成的fake图片,为了满足硬件的需要,就只能设置小一点,设个50就行了。

  • 7
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值