At line 146 of cycle_gan_model.py, the code below calls self.fake_A_pool.query(self.fake_A):
def backward_D_B(self):
    """Calculate GAN loss for discriminator D_B"""
    fake_A = self.fake_A_pool.query(self.fake_A)
    self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
So what exactly does self.fake_A_pool.query(self.fake_A) do?
The function is implemented in image_pool.py. Let's go through the code step by step to see what it is doing.
First,
in ImagePool's initialization, pool_size is set to 50; self.num_imgs = 0 is a counter for newly added images, and self.images = [] is the container that stores them.
class ImagePool():
    """This class implements an image buffer that stores previously generated images.

    This buffer enables us to update discriminators using a history of generated images
    rather than the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        """Initialize the ImagePool class

        Parameters:
            pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []
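To see the initial state concretely, here is a standalone sketch (the real pool is created inside the CycleGAN model with something like self.fake_A_pool = ImagePool(opt.pool_size); the stripped-down class below mirrors only the constructor shown above and is for illustration, not the library code):

```python
class ImagePool:
    """Stripped-down copy of the constructor above, for illustration only."""
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0   # counter of stored images
            self.images = []    # container for stored images

pool = ImagePool(50)
print(pool.num_imgs, len(pool.images))  # 0 0
```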
Next,
self.fake_A_pool.query(self.fake_A) is called. Each call passes in self.fake_A, a generated image of shape (1, 3, 256, 256), and the method first checks whether fewer than 50 images are stored. Since self.num_imgs starts at 0, the first call takes that branch: self.num_imgs becomes 1, the generated image is appended to self.images, and the input self.fake_A is returned unchanged. After 50 iterations, self.images holds 50 images and execution falls into the else branch below: with 50% probability the newly input self.fake_A is returned, and with 50% probability one of the 50 images stored in self.images is chosen at random and returned, while the new image takes its slot, so self.images always holds exactly 50 images.
def query(self, images):
    """Return an image from the pool.

    Parameters:
        images: the latest generated images from the generator

    Returns images from the buffer.

    By 50/100, the buffer will return input images.
    By 50/100, the buffer will return images previously stored in the buffer,
    and insert the current images to the buffer.
    """
    if self.pool_size == 0:  # if the buffer size is 0, do nothing
        return images
    return_images = []
    for image in images:
        image = torch.unsqueeze(image.data, 0)
        if self.num_imgs < self.pool_size:  # if the buffer is not full; keep inserting current images to the buffer
            self.num_imgs = self.num_imgs + 1
            self.images.append(image)
            return_images.append(image)
        else:
            p = random.uniform(0, 1)
            if p > 0.5:  # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
                random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                tmp = self.images[random_id].clone()
                self.images[random_id] = image
                return_images.append(tmp)
            else:  # by another 50% chance, the buffer will return the current image
                return_images.append(image)
    return_images = torch.cat(return_images, 0)  # collect all the images and return
    return return_images
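The behaviour described above can be simulated without torch by treating each image as an opaque object. The following is a sketch of the same logic (a hypothetical MiniPool class, not the library code): the first 50 queries store and return the input, and every later query either returns the fresh image or swaps it into the buffer and returns an old one, so the buffer size never exceeds 50.

```python
import random

class MiniPool:
    """Pure-Python sketch of ImagePool's query logic; images are opaque objects."""
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.num_imgs = 0
        self.images = []

    def query(self, image):
        if self.pool_size == 0:
            return image
        if self.num_imgs < self.pool_size:   # warm-up: store and return the input
            self.num_imgs += 1
            self.images.append(image)
            return image
        if random.uniform(0, 1) > 0.5:       # 50%: return a stored image, keep the new one
            idx = random.randint(0, self.pool_size - 1)
            old, self.images[idx] = self.images[idx], image
            return old
        return image                          # 50%: return the fresh image unchanged

pool = MiniPool(50)
for step in range(200):
    out = pool.query(f"fake_{step}")
print(len(pool.images))  # 50 -- the buffer never grows past pool_size
```

After 200 queries the buffer still holds exactly 50 images, and each returned image is either the current fake or one from the history.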
Why is this necessary?
First, note that this happens while training the discriminator. The discriminator is a classifier, and for it to be strong enough it should be able to tell apart not only the images the generator produces right now, but also the ones it produced earlier; a classifier that handles both is a better classifier.
So the discriminator needs to see a mix of previously generated fake images and freshly generated ones. To keep the memory cost within hardware limits, the buffer has to be kept small, and setting it to 50 is sufficient.
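The memory cost of that choice is easy to estimate: assuming float32 images of shape (1, 3, 256, 256), as in the walkthrough above, 50 of them occupy about 37.5 MB.

```python
pool_size = 50
bytes_per_image = 3 * 256 * 256 * 4        # C * H * W * 4 bytes per float32
total_mb = pool_size * bytes_per_image / 2**20  # convert bytes to MiB
print(total_mb)  # 37.5
```

This is small enough to keep on the GPU alongside the networks, which is why a pool size of 50 is a comfortable default.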