(8) StackGAN-v1: Paper Notes and Hands-on Implementation

1. Background and the Problem Addressed

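Generating high-quality images from text descriptions is a highly challenging problem in computer vision, and training a GAN to generate high-resolution photo-realistic images from text descriptions is particularly difficult. Simply adding more up-sampling layers to state-of-the-art GAN models to generate high-resolution (e.g., 256×256) images usually leads to unstable training and nonsensical outputs. The main difficulty is that the supports of the natural image distribution and the implied model distribution may not overlap in the high-dimensional pixel space, and the problem gets worse as the image resolution increases.
To address this, StackGAN-v1 proposes a model that generates images in two stages: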

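Stage I: conditioned on the text description, it sketches the primitive shape and basic colors of the object and draws the background layout from a random noise vector, producing a low-resolution image.
Stage II: it corrects defects in the low-resolution Stage-I image and completes the details of the object by reading the text description again, producing a high-resolution photo-realistic image.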

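By conditioning again on both the Stage-I result and the text, the Stage-II GAN learns to capture text information that the Stage-I GAN omitted and draws more details for the object. The model distribution generated from a roughly aligned low-resolution image has a better chance of intersecting the real image distribution, which is the fundamental reason the Stage-II GAN can produce better high-resolution images.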

2. Conditioning Augmentation

Network architecture:
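As shown in the figure above, the text description $t$ is first encoded by an encoder, producing a text embedding $\varphi_t$. In previous work, the text embedding was nonlinearly transformed to produce the conditioning latent variable fed to the generator. However, the latent space of the text embedding is usually high-dimensional (> 100 dimensions), and with limited data this usually causes discontinuity in the latent data manifold, which is undesirable. To mitigate this, Conditioning Augmentation is introduced to produce additional conditioning variables $\hat c$. Rather than being fixed, $\hat c$ is randomly sampled from an independent Gaussian distribution $\mathcal{N}(\mu(\varphi_t), \Sigma(\varphi_t))$, where the mean $\mu(\varphi_t)$ and the diagonal covariance matrix $\Sigma(\varphi_t)$ are functions of the text embedding $\varphi_t$. Conditioning Augmentation yields more training pairs from a small number of image-text pairs and makes the model robust to small perturbations along the conditioning manifold. To further enforce smoothness over the conditioning manifold and avoid overfitting, the authors add the following regularization term to the generator objective during training: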

$$D_{KL}\left(\mathcal{N}(\mu(\varphi_t), \Sigma(\varphi_t)) \,\|\, \mathcal{N}(0, I)\right)$$

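This term is the Kullback-Leibler (KL) divergence between the conditional Gaussian distribution and the standard Gaussian distribution. Together with the randomness introduced by sampling $\hat c$, it lets the model generate, for the same sentence, images that satisfy the text condition but differ in shape or pose, which increases the diversity of the generated images.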
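Both pieces of Conditioning Augmentation can be checked numerically with a minimal pure-Python sketch (the function names are hypothetical, not from the paper or the code below). It samples $\hat c$ with the reparameterization trick and evaluates the closed-form KL divergence against $\mathcal{N}(0, I)$, $-\tfrac{1}{2}\sum(1 + \log\sigma^2 - \mu^2 - \sigma^2)$. The KL_loss function in the training code below averages these per-element terms instead of summing them, which only rescales the regularizer.

```python
import math
import random

def reparameterize(mu, log_var, rng):
    """Sample c_hat = mu + sigma * eps, eps ~ N(0, I), sigma = exp(0.5 * log_var)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))

rng = random.Random(0)
mu = [0.5, -1.0, 0.0]
log_var = [0.0, math.log(0.25), math.log(4.0)]   # variances 1, 0.25, 4
c_hat = reparameterize(mu, log_var, rng)
print("sampled c_hat:", c_hat)
print("KL:", kl_to_standard_normal(mu, log_var))
```

The KL is zero only when $\mu = 0$ and $\sigma = 1$ in every dimension, which is what pulls the conditioning distribution toward a smooth, standard-normal-like manifold.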

3. How It Works
(1) Stage-I GAN

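Stage I trains $D_0$ and $G_0$ by alternately maximizing $\mathcal{L}_{D_0}$ in Eq. (3) and minimizing $\mathcal{L}_{G_0}$ in Eq. (4):
[Figure: Eq. (3) and Eq. (4), the Stage-I objectives]
where $I_0$ is a real image and $\lambda = 1$ is the weight of the KL regularizer.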
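For reference, here are the two Stage-I objectives written out, as we recall them from the StackGAN paper (please check them against the original before relying on them):

$$
\mathcal{L}_{D_0} = \mathbb{E}_{(I_0, t) \sim p_{data}}\left[\log D_0(I_0, \varphi_t)\right] + \mathbb{E}_{z \sim p_z,\, t \sim p_{data}}\left[\log\left(1 - D_0(G_0(z, \hat c_0), \varphi_t)\right)\right]
$$

$$
\mathcal{L}_{G_0} = \mathbb{E}_{z \sim p_z,\, t \sim p_{data}}\left[\log\left(1 - D_0(G_0(z, \hat c_0), \varphi_t)\right)\right] + \lambda\, D_{KL}\left(\mathcal{N}(\mu_0(\varphi_t), \Sigma_0(\varphi_t)) \,\|\, \mathcal{N}(0, I)\right)
$$

The training code below updates $G_0$ with the common non-saturating variant: cal_G_loss applies BCE against the "real" label, so it maximizes $\log D_0(G_0(z, \hat c_0), \varphi_t)$ rather than minimizing $\log(1 - D_0(G_0(z, \hat c_0), \varphi_t))$.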

Model architecture:

For the generator $G_0$, to obtain the conditioning variable $\hat c_0$, the text embedding $\varphi_t$ is first fed into a fully connected layer to produce $\mu_0$ and $\sigma_0$ (the per-dimension standard deviations, so that $\Sigma_0 = \mathrm{diag}(\sigma_0^2)$), which define the Gaussian distribution $\mathcal{N}(\mu_0(\varphi_t), \Sigma_0(\varphi_t))$ described above. $\hat c_0$ is then sampled from this Gaussian as $\hat c_0 = \mu_0 + \sigma_0 \odot \epsilon$, where $\odot$ denotes element-wise multiplication and $\epsilon$ is a random vector sampled from $\mathcal{N}(0, I)$. Next, $\hat c_0$ is concatenated with a noise vector $z$, and the result is passed through a series of up-sampling blocks to generate the low-resolution image.
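For the discriminator $D_0$, the text embedding $\varphi_t$ is first compressed to $N_d$ dimensions by a fully connected layer and then spatially replicated to form an $M_d \times M_d \times N_d$ tensor. Meanwhile, the image is fed through a series of down-sampling blocks until its spatial size is $M_d \times M_d$. The image feature map is then concatenated with the text tensor along the channel dimension, and the resulting tensor is fed into a 1×1 convolutional layer to jointly learn features across the image and the text. Finally, a fully connected layer with a single node produces the decision score.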
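The replicate-and-concatenate step is the same operation that D_GET_LOGITS performs with view and repeat in the code below. Here is a dependency-light NumPy sketch (the function name is hypothetical; PyTorch's Tensor.repeat behaves like np.tile for this purpose), using the defaults $M_d = 4$, $N_d = 128$ and 512 image-feature channels:

```python
import numpy as np

def join_image_and_text(img_feat, text_code):
    """Replicate a (B, N_d) text code to (B, N_d, M_d, M_d) and concatenate it
    with (B, C, M_d, M_d) image features along the channel axis."""
    _, _, h, w = img_feat.shape
    text_map = np.tile(text_code[:, :, None, None], (1, 1, h, w))
    return np.concatenate([img_feat, text_map], axis=1)

B, C, M_d, N_d = 2, 512, 4, 128
img_feat = np.zeros((B, C, M_d, M_d), dtype=np.float32)
text_code = np.arange(B * N_d, dtype=np.float32).reshape(B, N_d)
joint = join_image_and_text(img_feat, text_code)
print(joint.shape)  # (2, 640, 4, 4)
```

Every spatial position carries the full text code, so the following convolution can relate each local image feature to the whole description.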

(2) Stage-II GAN

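Conditioned on the low-resolution result $s_0 = G_0(z, \hat c_0)$ and the Gaussian latent variable $\hat c$, the discriminator $D$ and the generator $G$ of Stage II are trained by alternately maximizing $\mathcal{L}_D$ and minimizing $\mathcal{L}_G$:
[Figure: Eq. (5) and Eq. (6), the Stage-II objectives]
In other words, the noise $z$ of Stage I is replaced by $s_0$; the paper uses no random noise at this stage, assuming the randomness is already preserved in $s_0$. The Gaussian conditioning variable $\hat c$ used in this stage and $\hat c_0$ used in Stage I share the same pre-trained text encoder, which generates the same text embedding $\varphi_t$. However, the Stage-I and Stage-II Conditioning Augmentation modules have different fully connected layers, which generate different means and standard deviations. In this way, the Stage-II GAN learns to capture useful information in the text embedding that the Stage-I GAN omitted.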
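Written out the same way (again as recalled from the paper, to be checked against the original), the Stage-II objectives are the Stage-I ones with $z$ replaced by $s_0$, $\hat c_0$ by $\hat c$, and $G_0, D_0$ by $G, D$:

$$
\mathcal{L}_{D} = \mathbb{E}_{(I, t) \sim p_{data}}\left[\log D(I, \varphi_t)\right] + \mathbb{E}_{s_0 \sim p_{G_0},\, t \sim p_{data}}\left[\log\left(1 - D(G(s_0, \hat c), \varphi_t)\right)\right]
$$

$$
\mathcal{L}_{G} = \mathbb{E}_{s_0 \sim p_{G_0},\, t \sim p_{data}}\left[\log\left(1 - D(G(s_0, \hat c), \varphi_t)\right)\right] + \lambda\, D_{KL}\left(\mathcal{N}(\mu(\varphi_t), \Sigma(\varphi_t)) \,\|\, \mathcal{N}(0, I)\right)
$$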

Model architecture:

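The Stage-II generator is designed as an encoder-decoder network with residual blocks. As in Stage I, the text embedding $\varphi_t$ is used to generate an $N_g$-dimensional conditioning vector $\hat c$, which is spatially replicated to form an $M_g \times M_g \times N_g$ tensor. Meanwhile, the Stage-I result $s_0$ is fed into several down-sampling blocks until its spatial size is $M_g \times M_g$. The image features and the text features are then concatenated along the channel dimension, passed through several residual blocks, and finally decoded by a series of up-sampling blocks into the high-resolution image.
The discriminator is similar to the Stage-I discriminator, with only extra down-sampling blocks because the image is larger at this stage. To explicitly force the GAN to learn better alignment between the image and the conditioning text, the matching-aware discriminator proposed by Reed et al. is used instead of a vanilla discriminator, in both stages. During training, the discriminator takes real images with their corresponding text descriptions as positive pairs, while the negative pairs come in two groups: real images with mismatched text embeddings, and synthetic images with their corresponding text embeddings.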
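The spatial sizes in the Stage-II generator follow from the standard convolution output formula $\lfloor (n + 2p - k)/s \rfloor + 1$. A small pure-Python trace (helper names are hypothetical) under the defaults used in the code below ($ngf = 128$, $N_g = 128$, $M_g = 16$, four residual blocks) shows how a 64×64 $s_0$ is encoded to 16×16 and decoded to 256×256:

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a 2-D convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

def stage2_generator_shapes(ngf=128, n_g=128, in_size=64, n_res=4):
    """Trace (stage, channels, height) through the Stage-II generator below."""
    shapes = [("s0", 3, in_size)]
    size = conv_out(in_size, 3, 1, 1)                      # 3x3 conv, stride 1
    shapes.append(("conv3x3", ngf, size))
    size = conv_out(size, 4, 2, 1)                         # 4x4 conv, stride 2
    shapes.append(("down1", ngf * 2, size))
    size = conv_out(size, 4, 2, 1)
    shapes.append(("down2", ngf * 4, size))
    shapes.append(("concat c_hat", ngf * 4 + n_g, size))   # M_g x M_g x N_g text map appended
    shapes.append(("joint conv3x3", ngf * 4, size))
    shapes.append((f"{n_res} residual blocks", ngf * 4, size))  # shape-preserving
    ch = ngf * 4
    for i in range(4):                                     # nearest x2 upsample + 3x3 conv, channels halved
        ch, size = ch // 2, size * 2
        shapes.append((f"up{i + 1}", ch, size))
    shapes.append(("to RGB", 3, size))
    return shapes

for name, c, s in stage2_generator_shapes():
    print(f"{name:>20}: {c} x {s} x {s}")
```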
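The three kinds of discriminator inputs can be enumerated as (image, text) index pairs. This pure-Python sketch (the function name is hypothetical) mirrors the shift-by-one trick used in cal_D_loss below: real image i is paired with the text of sample i + 1 to form a mismatched negative. It assumes neighbouring samples in a shuffled batch carry different descriptions, which usually but not always holds.

```python
def discriminator_pairs(batch_size):
    """Build (source, image index, text index) triples for a matching-aware D."""
    real_matched = [("real", i, i) for i in range(batch_size)]               # positive pairs
    real_mismatched = [("real", i, i + 1) for i in range(batch_size - 1)]    # negatives, group 1
    fake_matched = [("fake", i, i) for i in range(batch_size)]               # negatives, group 2
    return real_matched, real_mismatched, fake_matched

pos, wrong, fake = discriminator_pairs(4)
print(wrong)  # [('real', 0, 1), ('real', 1, 2), ('real', 2, 3)]
```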

4. Implementation Details

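The up-sampling blocks consist of nearest-neighbor upsampling followed by a 3×3 stride-1 convolution. Batch normalization and ReLU activation are applied after every convolution except the last one. The residual blocks consist of 3×3 stride-1 convolutions, batch normalization and ReLU. Two residual blocks are used in the 128×128 StackGAN model and four in the 256×256 model. The down-sampling blocks consist of a 4×4 stride-2 convolution, batch normalization and LeakyReLU, except that the first one has no batch normalization. By default, $N_g = 128$, $N_z = 100$, $M_g = 16$, $M_d = 4$ and $N_d = 128$.

For training, the authors first iteratively train $D_0$ and $G_0$ of the Stage-I GAN for 600 epochs. They then iteratively train $D$ and $G$ of the Stage-II GAN for another 600 epochs, with the Stage-I GAN fixed, to refine the Stage-I results. All networks are trained with the Adam solver, a batch size of 64 and an initial learning rate of 0.0002; the learning rate is decayed to 1/2 of its previous value every 100 epochs.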
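The step decay described above, and the 50-epoch variant used in the training scripts below, can be written as closed-form functions. This is a minimal sketch with hypothetical names. In PyTorch, the paper's schedule could also be obtained with torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5), calling scheduler.step() once per epoch.

```python
def paper_lr(epoch, base_lr=0.0002, step=100):
    """Paper schedule: 0-based epochs, halved every `step` epochs."""
    return base_lr * 0.5 ** (epoch // step)

def code_lr(epoch, base_lr=0.0002, step=50):
    """Schedule of the training scripts below: 1-based epochs,
    halved when epoch % 50 == 1 (i.e. at epochs 51, 101, ...)."""
    return base_lr * 0.5 ** ((epoch - 1) // step)

for e in (1, 50, 51, 100, 101):
    print(e, code_lr(e), paper_lr(e))
```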

5. Complete Code

Parameter setup:

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import os
import torchvision
parser = argparse.ArgumentParser()
parser.add_argument('--text_dimension', type = int, default=1024, help='the dimension of text embedding')
parser.add_argument('--condition_dim', type = int, default=128, help='the dimension of condition')
parser.add_argument('--gf_dim', type = int, default=128)
parser.add_argument('--df_dim', type = int, default=64)
parser.add_argument('--z_dim', type = int, default=100)
parser.add_argument('--residual_num', type = int, default = 4, help='the number of residual block')
parser.add_argument('--lr', type = float, default=0.0002)
parser.add_argument('--b1', type = float, default=0.5)
parser.add_argument('--b2', type = float, default=0.999)
parser.add_argument('--dataDir', default='data')
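parser.add_argument('--s1_start_epoch', type = int, default=1) # change as needed; 1 = train from scratch, otherwise resume from that checkpoint
parser.add_argument('--s2_start_epoch', type = int, default=20) # change as needed; 1 = train from scratch, otherwise resume from that checkpoint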
parser.add_argument('--num_epochs', type = int, default=200)
parser.add_argument('--batch_size', type = int, default=64)
opt = parser.parse_args(args = [])
random.seed(999)
torch.manual_seed(999)

Network architecture:

'''Network models'''
def conv3x3(in_fea, out_fea, stride = 1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_fea, out_fea, kernel_size=3, stride=stride,padding=1,bias=False)

# Upscale the spatial size by a factor of 2
def upBlock(in_fea, out_fea):
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv3x3(in_fea,out_fea),
        nn.BatchNorm2d(out_fea),
        nn.ReLU(True)
    )
    return block

# residual block
class ResBlock(nn.Module):
    def __init__(self, channel_num):
        super(ResBlock, self).__init__()
        self.block = nn.Sequential(
            conv3x3(channel_num,channel_num),
            nn.BatchNorm2d(channel_num),
            nn.ReLU(True),
            conv3x3(channel_num,channel_num),
            nn.BatchNorm2d(channel_num)
        )
        self.relu = nn.ReLU(True)
    def forward(self, x):
        residual = x
        out = self.block(x)
        out = out + residual
        out = self.relu(out)
        return out

class CA_NET(nn.Module):
    # Conditioning Augmentation: text embedding -> (mu, log_var) -> sampled c_code
    # some code is modified from vae examples
    # (https://github.com/pytorch/examples/blob/master/vae/main.py)
    def __init__(self):
        super(CA_NET, self).__init__()
        self.t_dim = opt.text_dimension
        self.c_dim = opt.condition_dim
        self.fc = nn.Linear(self.t_dim, self.c_dim*2, bias=False)
        self.relu = nn.ReLU(True)
        
    def encode(self, text_embedding):
        x = self.relu(self.fc(text_embedding))
        mu = x[:, :self.c_dim]       # first c_dim outputs: mean
        log_var = x[:, self.c_dim:]  # last c_dim outputs: log variance
        return mu, log_var
    def reparameter(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        # eps ~ N(0, I) on the same device and dtype as std; unlike
        # torch.cuda.FloatTensor, this also works on a CPU-only machine
        eps = torch.randn_like(std)
        return mu + torch.mul(std, eps)
    def forward(self, text_embedding):
        mu, log_var = self.encode(text_embedding)
        c_code = self.reparameter(mu, log_var)
        return c_code, mu, log_var

class D_GET_LOGITS(nn.Module):
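    '''Produces the final discriminator score. It takes two inputs: the condition code
       c_code of shape (b, 128), replicated to (b, 128, 4, 4), and the image features
       from the discriminator's encoder, of shape (b, 512, 4, 4). The two are
       concatenated along dim 1 (channels) to form the input.
       A 3x3 conv followed by a 4x4 conv stands in for the paper's 1x1 conv + one-node FC layer.'''
    def __init__(self, ndf, nc):
        super(D_GET_LOGITS, self).__init__()
        self.df_dim = ndf # default 64
        self.c_dim = nc # default 128
        self.out_logits = nn.Sequential( # input [b, 512+128, 4, 4]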
            conv3x3(ndf * 8 + nc, ndf * 8), # --> [b, 512, 4, 4]
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), # --> [b, 1, 1, 1] 
            nn.Sigmoid()
        )
    def forward(self, img_code, c_code):
        # conditioning output
        c_code = c_code.view(-1, self.c_dim, 1, 1)
        c_code = c_code.repeat(1, 1, 4, 4)
        # state size (ndf*8 + nc) x 4 x 4
        img_c_code = torch.cat((img_code, c_code), dim = 1)
        output = self.out_logits(img_c_code)
        return output.view(-1)

# ############# Networks for stageI GAN #############
class STAGE1_G(nn.Module):
    def __init__(self):
        super(STAGE1_G, self).__init__()
        self.gf_dim = opt.gf_dim # default 128
        self.c_dim = opt.condition_dim # default 128
        self.z_dim = opt.z_dim # default 100
        self.define_module()
        
    def define_module(self):
        ninput = self.z_dim + self.c_dim
        ngf = self.gf_dim
        # TEXT.DIMENSION -> GAN.CONDITION_DIM
        self.ca_net = CA_NET()
        
        # -> ngf x 4 x 4
        self.fc = nn.Sequential(
            nn.Linear(ninput, ngf * 4 * 4, bias = False),
            nn.BatchNorm1d(ngf * 4 * 4),
            nn.ReLU(True)
        )
        # ngf x 4 x 4 -> ngf/2 x 8 x 8
        self.upsample1 = upBlock(ngf, ngf//2)
        # -> ngf/4 x 16 x 16
        self.upsample2 = upBlock(ngf//2, ngf//4)
        # -> ngf/8 x 32 x 32
        self.upsample3 = upBlock(ngf//4, ngf//8)
        # -> ngf/16 x 64 x 64
        self.upsample4 = upBlock(ngf//8, ngf//16)
        # -> 3 x 64 x 64
        self.img = nn.Sequential(
            conv3x3(ngf//16, 3),
            nn.Tanh()
        )
    def forward(self, text_embedding, noise):
        c_code, mu, log_var = self.ca_net(text_embedding)
        z_c_code = torch.cat((noise, c_code), dim = 1)
        img_code = self.fc(z_c_code)
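        img_code = img_code.view(-1, self.gf_dim, 4, 4) # reshape the flat FC output to (b, ngf, 4, 4); do not skip this step
        fake_img = self.upsample1(img_code)
        fake_img = self.upsample2(fake_img)
        fake_img = self.upsample3(fake_img)
        fake_img = self.upsample4(fake_img)
        # state size 3 x 64 x 64
        fake_img = self.img(fake_img)
        # None keeps the same (stage1_img, fake_img, mu, log_var) interface as STAGE2_G;
        # returning the bare name _ only worked by accident inside Jupyter
        return None, fake_img, mu, log_var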
class STAGE1_D(nn.Module):
    def __init__(self):
        super(STAGE1_D,self).__init__()
        self.df_dim = opt.df_dim # default 64
        self.c_dim = opt.condition_dim # default 128
        self.define_module()
        
    def define_module(self):
        ndf, nc = self.df_dim, self.c_dim
        # this part maps the image to [b, 512, 4, 4]
        self.encode_img = nn.Sequential( # input shape [b, 3, 64, 64]
            nn.Conv2d(3, ndf, 4, 2, 1, bias = False), # --> [b, 64, 32, 32]
            nn.LeakyReLU(0.2, True),
            
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias = False), # --> [b, 128, 16, 16]
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, True),
            
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias = False), # --> [b, 256, 8, 8]
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, True),
            
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias = False), # --> [b, 512, 4, 4]
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, True), 
        )
        self.get_cont_logits = D_GET_LOGITS(ndf, nc)
    
    def forward(self, img):
        img_code = self.encode_img(img)
        
        return img_code
    
# ############# Networks for stageII GAN #############
class STAGE2_G(nn.Module):
    def __init__(self, STAGE1_G):
        super(STAGE2_G, self).__init__()
        self.gf_dim = opt.gf_dim
        self.c_dim = opt.condition_dim
        self.STAGE1_G = STAGE1_G
        # fix parameters of stageI GAN
        for param in self.STAGE1_G.parameters():
            param.requires_grad = False
        self.define_module()
        
    def _make_residual_net(self, block, channel_num):
        layers = []
        for i in range(opt.residual_num):
            layers.append(block(channel_num))
        return nn.Sequential(*layers)
            
    def define_module(self):
        ngf = self.gf_dim # default 128
        # TEXT.DIMENSION -> GAN.CONDITION_DIM
        self.ca_net = CA_NET()
        # encodes the (3, 64, 64) Stage-I image to 4*ngf x 16 x 16, i.e. (512, 16, 16)
        self.encoder = nn.Sequential(
            conv3x3(3, ngf), # --> [128, 64, 64]
            nn.ReLU(True),
            nn.Conv2d(ngf, ngf * 2, 4, 2, 1, bias = False), # --> [256, 32, 32]
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1, bias = False), # --> [512, 16, 16]
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
        )
        self.hr_joint = nn.Sequential(
            conv3x3(self.c_dim + ngf * 4, ngf * 4),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
        )
        self.residual = self._make_residual_net(ResBlock, ngf * 4)
        # 输入 [b, 512, 16, 16]
        self.upsample1 = upBlock(ngf * 4, ngf * 2) # --> [b, 256, 32, 32]
        self.upsample2 = upBlock(ngf * 2, ngf) # --> [b, 128, 64, 64]
        self.upsample3 = upBlock(ngf, ngf//2) # --> [b, 64, 128, 128]
        self.upsample4 = upBlock(ngf//2, ngf//4) # --> [b, 32, 256, 256]
        self.img = nn.Sequential(
            conv3x3(ngf//4, 3), # --> [b, 3, 256, 256]
            nn.Tanh()
        )
    
    def forward(self, text_embedding, noise):
        _, stage1_img, _, _ = self.STAGE1_G(text_embedding, noise)
        
        stage1_img = stage1_img.detach()
        encode_img = self.encoder(stage1_img)
        
        c_code, mu, log_var = self.ca_net(text_embedding)
        c_code = c_code.view(-1, self.c_dim, 1, 1)
        c_code = c_code.repeat(1, 1, 16, 16)
        img_c_code = torch.cat((encode_img, c_code), dim = 1)
        
        h_code = self.hr_joint(img_c_code)
        
        h_code = self.residual(h_code)
        
        h_code = self.upsample1(h_code)
        h_code = self.upsample2(h_code)
        h_code = self.upsample3(h_code)
        h_code = self.upsample4(h_code)
        
        fake_img = self.img(h_code)
        
        return stage1_img, fake_img, mu, log_var
    
class STAGE2_D(nn.Module):
    def __init__(self):
        super(STAGE2_D, self).__init__()
        self.df_dim = opt.df_dim # default 64
        self.c_dim = opt.condition_dim # default 128
        self.define_module()
        
    def define_module(self):
        ndf, nc = self.df_dim, self.c_dim
        self.encode_img = nn.Sequential(
            # 输入 [b, 3, 256, 256]
            nn.Conv2d(3, ndf, 4, 2, 1, bias = False), # --> [b, 64, 128, 128]
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias = False), # --> [b, 128, 64, 64]
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias = False), # --> [b, 256, 32, 32]
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias = False), # --> [b, 512, 16, 16]
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias = False), # --> [b, 1024, 8, 8]
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 16, ndf * 32, 4, 2, 1, bias = False), # --> [b, 2048, 4, 4]
            nn.BatchNorm2d(ndf * 32),
            nn.LeakyReLU(0.2, True),
            
            conv3x3(ndf * 32, ndf * 16), # --> [b, 1024, 4, 4]
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, True),
            conv3x3(ndf * 16, ndf * 8), # --> [b, 512, 4, 4]
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, True)
        ) 
        self.get_cont_logits = D_GET_LOGITS(ndf, nc)
    
    def forward(self, img):
        img_code = self.encode_img(img)
        return img_code

Dataset processing:

'''dataset'''
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader

class birdDataset(Dataset):
    def __init__(self, dataDir, split = 'train', imgsize = 64, transform = None):
        super(birdDataset, self).__init__()
        self.dataDir = dataDir
        self.imgsize = imgsize
        self.transform = transform
        self.bbox = self.load_bbox()
        self.filenames = self.load_filenames(dataDir, split)
        self.embeddings = self.load_embedding(dataDir, split)
        
    def load_bbox(self): 
        '''Returns a dict {image name: bbox}, with bbox = [x, y, width, height], e.g.:
        {'002.Laysan_Albatross/Laysan_Albatross_0017_614': [82.0, 75.0, 308.0, 260.0],
         '002.Laysan_Albatross/Laysan_Albatross_0018_492': [89.0, 172.0, 108.0, 146.0],
         '002.Laysan_Albatross/Laysan_Albatross_0021_737': [83.0, 124.0, 211.0, 250.0],
         ...}'''
        # in my project, dataDir is '../dataset'
        pth = os.path.join(self.dataDir, 'CUB_200_2011', 'bounding_boxes.txt')
        # sep=r'\s+' splits on whitespace (delim_whitespace is deprecated in recent pandas);
        # header=None means the first line is data, not a header
        bbox_data = pd.read_csv(pth, sep=r'\s+', header=None) 
        file_path = os.path.join(self.dataDir, 'CUB_200_2011', 'images.txt')
        df_filenames = pd.read_csv(file_path, sep=r'\s+', header=None)
        # keep the file order: row i of bounding_boxes.txt belongs to row i of images.txt
        # (both are listed by image id), so the names must not be re-sorted
        filenames = df_filenames[1].tolist()
        # filename[:-4] strips the trailing '.jpg'
        fname_bbox_dict = {filename[:-4]:[] for filename in filenames}
        for i in range(len(filenames)):
            data = bbox_data.iloc[i][1:].tolist()
            k = filenames[i][:-4]
            fname_bbox_dict[k] = data
        return fname_bbox_dict 
    
    def get_img(self, img_path, bbox):
        img = Image.open(img_path).convert('RGB') # load the image; convert in case it is not RGB (e.g. grayscale)
        width, height = img.size
        if bbox is not None:
            R = int(np.maximum(bbox[2], bbox[3]) * 0.75)
            center_x = int((2 * bbox[0] + bbox[2]) / 2)
            center_y = int((2 * bbox[1] + bbox[3]) / 2)
            y1 = np.maximum(0, center_y - R)
            y2 = np.minimum(height, center_y + R)
            x1 = np.maximum(0, center_x - R)
            x2 = np.minimum(width, center_x + R)
            img = img.crop([x1, y1, x2, y2])
        
        load_size = int(self.imgsize * 76 / 64)
        img = img.resize((load_size, load_size), Image.BILINEAR)
        if self.transform is not None:
            img = self.transform(img)
        return img
    
    def load_embedding(self, dataDir, split):
        embedding_filename = os.path.join(dataDir, 'birds', split, 'char-CNN-RNN-embeddings.pickle')
        with open(embedding_filename, 'rb') as f:
            embeddings = pickle.load(f, encoding='bytes') 
            embeddings = np.array(embeddings) # pickle.load returns a list; convert it to an array
#             print('embeddings:', embeddings.shape) # embeddings: (8855, 10, 1024)
        return embeddings
    
    def load_filenames(self, dataDir, split):
        '''Returns the image names, e.g.:
        ['002.Laysan_Albatross/Laysan_Albatross_0002_1027',
         '002.Laysan_Albatross/Laysan_Albatross_0003_1033',
         '002.Laysan_Albatross/Laysan_Albatross_0082_524',
         '002.Laysan_Albatross/Laysan_Albatross_0044_784',
         ...]
        The names stored in filenames.pickle carry no file extension.'''
        file_path = os.path.join(dataDir, 'birds', split, 'filenames.pickle')
        with open(file_path, 'rb') as f:
            filenames = pickle.load(f)
        return filenames
    
    def __getitem__(self, index):
        key = self.filenames[index] # image name, without .jpg
        if self.bbox is not None:
            bbox = self.bbox[key] # look up this image's bbox in fname_bbox_dict
        else:
            bbox = None
        embeddings = self.embeddings[index, :, :] # all caption embeddings of this image
        img_path = os.path.join(self.dataDir, 'CUB_200_2011', 'images', key+'.jpg') # full image path
        image = self.get_img(img_path, bbox) # load, crop, resize and transform the image
        
        # randomly select one caption; np.random.randint excludes the upper bound,
        # so pass shape[0] (not shape[0]-1) to make every caption eligible
        sample = np.random.randint(0, embeddings.shape[0])
        embedding = embeddings[sample, :]
        return image, embedding
    
    def __len__(self):
        return len(self.filenames)
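The crop in get_img is easy to check by hand. This pure-Python sketch (hypothetical helper names) reproduces its arithmetic: a square of half-side $R = 0.75 \cdot \max(w, h)$ centred on the bounding box and clipped to the image, followed by a resize to 76/64 of the final size so that RandomCrop has room to move. The bbox comes from the load_bbox docstring; the 500×375 image size is an assumption for illustration.

```python
def crop_box(bbox, width, height):
    """Square crop around a CUB bbox [x, y, w, h], clipped to the image.
    Returns (left, upper, right, lower) as used by PIL's Image.crop."""
    x, y, w, h = bbox
    r = int(max(w, h) * 0.75)
    cx = int((2 * x + w) / 2)
    cy = int((2 * y + h) / 2)
    return (max(0, cx - r), max(0, cy - r), min(width, cx + r), min(height, cy + r))

def load_size(imgsize):
    """Size the crop is resized to before RandomCrop(imgsize)."""
    return int(imgsize * 76 / 64)

print(crop_box([82.0, 75.0, 308.0, 260.0], width=500, height=375))
print(load_size(64), load_size(256))
```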

Stage-I training:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
'''Train Stage-I'''
def KL_loss(mean, log_var): # takes mu and log(sigma^2)
    # -0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2), averaged over all elements
    temp = -mean**2 - torch.exp(log_var) + 1 + log_var
    KLD = torch.mean(temp) * (-0.5)
    return KLD
'''Weight initialization'''
def weight_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
            
'''Generator loss'''
def cal_G_loss(netD, fake_imgs, real_labels, c_code):
    criterion = nn.BCELoss()
    img_code = netD(fake_imgs) # the discriminator's encoder maps the images to (512, 4, 4) --> img_code
    img_cond_output = netD.get_cont_logits(img_code, c_code) # score img_code together with c_code
    err_G = criterion(img_cond_output, real_labels)
    return err_G
'''Discriminator loss'''
def cal_D_loss(netD, real_imgs, fake_imgs, real_labels, fake_labels, c_code):
    criterion = nn.BCELoss()
    batch_size = real_imgs.size(0)
    c_code = c_code.detach()
    fake_imgs = fake_imgs.detach()
    real_imgs_code = netD(real_imgs)
    fake_imgs_code = netD(fake_imgs)
    
    real_output = netD.get_cont_logits(real_imgs_code, c_code) # real images with their matching text
    err_D_real = criterion(real_output, real_labels)
    
    unmatched_output = netD.get_cont_logits(real_imgs_code[:(batch_size-1)], c_code[1:]) # real images with mismatched text
    err_D_unmatched = criterion(unmatched_output, fake_labels[1:]) # mismatched pairs are negatives
    
    fake_output = netD.get_cont_logits(fake_imgs_code, c_code) # fake images with their matching text
    err_D_fake = criterion(fake_output, fake_labels)
    
    err_D = err_D_real + (err_D_fake + err_D_unmatched) * 0.5
    
    return err_D, err_D_real.item(), err_D_unmatched.item(), err_D_fake.item()

def main():

    transform = transforms.Compose([
        transforms.RandomCrop(64),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = birdDataset(dataDir = opt.dataDir, split='train', imgsize=64, transform = transform)
    train_loader = DataLoader(train_dataset, batch_size = opt.batch_size, shuffle=True, num_workers=0)
    
    test_dataset = birdDataset(dataDir= opt.dataDir, split= 'test', imgsize=64, transform = transform)
    test_loader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0)
    
    netG = STAGE1_G().to(device)
    netD = STAGE1_D().to(device)
    
    if opt.s1_start_epoch != 1 :
        # Load pretrained models
        netG.load_state_dict(torch.load('model/stage-I/netG_epoch_{}.pth'.format(opt.s1_start_epoch)))
        netD.load_state_dict(torch.load('model/stage-I/netD_epoch_{}.pth'.format(opt.s1_start_epoch)))
    else: # Initialize weights
        netG.apply(weight_init_normal)
        netD.apply(weight_init_normal)
    
    optim_D = torch.optim.Adam(netD.parameters(), lr = opt.lr, betas=(opt.b1, opt.b2))
    optim_G = torch.optim.Adam(netG.parameters(), lr = opt.lr, betas=(opt.b1, opt.b2))
    
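    fixed_noise = torch.randn(opt.batch_size, opt.z_dim, device = device) # z ~ N(0, I)
    os.makedirs('model/stage-I', exist_ok=True)
    os.makedirs('picture/stage-I', exist_ok=True)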
    
    for epoch in range(opt.s1_start_epoch, opt.num_epochs+1):
        # halve the learning rate every 50 epochs
        if epoch % 50 == 1 and epoch > 1:
            opt.lr = opt.lr * 0.5 
            for param_group in optim_G.param_groups:
                param_group['lr'] = opt.lr
            for param_group in optim_D.param_groups:
                param_group['lr'] = opt.lr
        for i, data in enumerate(train_loader, 0):
            real_imgs, text_embedding = data
            real_imgs = real_imgs.to(device)
            text_embedding = text_embedding.to(device)
            b_size = real_imgs.shape[0]
            
            real_labels = torch.full(size = (b_size,), fill_value = 1.0, device = device)
            fake_labels = torch.full(size = (b_size,), fill_value = 0.0, device = device)
            
            '''update discriminator'''
            netD.zero_grad()
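            # generate fake images
            noise = torch.randn(b_size, opt.z_dim, device = device) # z ~ N(0, I)
            _, fake_imgs, mu, log_var = netG(text_embedding, noise)
            '''Open question: why does the original implementation pass the mean mu, rather than the sampled c_code, as the condition when computing the losses?'''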
            err_D, err_D_real, err_D_unmatched, err_D_fake = cal_D_loss(netD, real_imgs, fake_imgs, real_labels, fake_labels,mu)
            err_D.backward()
            optim_D.step()
            
            '''update generator'''
            netG.zero_grad()
            err_G = cal_G_loss(netD, fake_imgs, real_labels, mu)
            err_G += KL_loss(mu, log_var)
            err_G.backward()
            optim_G.step()
            
            if i % 20 == 0: # print the losses every 20 batches
                print('[Epoch %d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tLoss_D_R: %.4f\tLoss_D_Unmatch: %.4f\tLoss_D_F %.4f'
                      % (epoch, opt.num_epochs, i, len(train_loader),
                         err_D.item(), err_G.item(), err_D_real, err_D_unmatched, err_D_fake))
                
        if epoch % 10 == 0 : # sample test images and save checkpoints every 10 epochs
            with torch.no_grad():
                real_imgs, text_embedding = next(iter(test_loader))
                real_imgs = real_imgs.to(device)
                text_embedding = text_embedding.to(device)
                # generate fake images
                _, fake_imgs, _, _ = netG(text_embedding, fixed_noise)
                real_imgs = torchvision.utils.make_grid(real_imgs.cpu(), nrow=8, normalize=True)
                fake_imgs = torchvision.utils.make_grid(fake_imgs.detach().cpu(), nrow=8, normalize=True)
                
                img_grid = torch.cat((real_imgs, fake_imgs), dim = 2)
                torchvision.utils.save_image(img_grid, 'picture/stage-I/{}.jpg'.format(epoch), nrow = 16, normalize=False)
            
            torch.save(netG.state_dict(), 'model/stage-I/netG_epoch_{}.pth'.format(epoch))
            torch.save(netD.state_dict(), 'model/stage-I/netD_epoch_{}.pth'.format(epoch))

if __name__ == '__main__':
    main()
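How the three discriminator terms in cal_D_loss combine can be seen with a pure-Python sketch (the function names are hypothetical; nn.BCELoss with its default mean reduction computes the same per-batch average). The mismatched pairs count as negatives, so they are scored against the "fake" label 0, and the total is $\text{real} + 0.5\,(\text{fake} + \text{mismatched})$:

```python
import math

def bce(p, y):
    """Binary cross-entropy averaged over a batch (like nn.BCELoss, mean reduction)."""
    return -sum(yi * math.log(pi) + (1 - yi) * math.log(1 - pi) for pi, yi in zip(p, y)) / len(p)

def discriminator_loss(p_real, p_wrong, p_fake):
    """err_D = BCE(real, 1) + 0.5 * (BCE(fake, 0) + BCE(mismatched, 0))."""
    err_real = bce(p_real, [1.0] * len(p_real))
    err_wrong = bce(p_wrong, [0.0] * len(p_wrong))   # real image + wrong text is a negative pair
    err_fake = bce(p_fake, [0.0] * len(p_fake))
    return err_real + 0.5 * (err_fake + err_wrong)

# An undecided discriminator (every score 0.5) gives 2 * ln 2, about 1.386.
print(discriminator_loss([0.5] * 4, [0.5] * 3, [0.5] * 4))
```

A discriminator that accepts mismatched pairs (scores near 1 on p_wrong) is penalised heavily, which is exactly the pressure that teaches it to check image-text alignment.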

Stage-II training:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def KL_loss(mean, log_var): # takes mu and log(sigma^2)
    # -0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2), averaged over all elements
    temp = -mean**2 - torch.exp(log_var) + 1 + log_var
    KLD = torch.mean(temp) * (-0.5)
    return KLD
'''Weight initialization'''
def weight_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
            
'''Generator loss'''
def cal_G_loss(netD, fake_imgs, real_labels, c_code):
    criterion = nn.BCELoss()
    img_code = netD(fake_imgs) # the discriminator's encoder maps the images to (512, 4, 4) --> img_code
    img_cond_output = netD.get_cont_logits(img_code, c_code) # score img_code together with c_code
    err_G = criterion(img_cond_output, real_labels)
    return err_G
'''Discriminator loss'''
def cal_D_loss(netD, real_imgs, fake_imgs, real_labels, fake_labels, c_code):
    criterion = nn.BCELoss()
    batch_size = real_imgs.size(0)
    c_code = c_code.detach()
    fake_imgs = fake_imgs.detach()
    real_imgs_code = netD(real_imgs)
    fake_imgs_code = netD(fake_imgs)
    
    real_output = netD.get_cont_logits(real_imgs_code, c_code) # real images with their matching text
    err_D_real = criterion(real_output, real_labels)
    
    unmatched_output = netD.get_cont_logits(real_imgs_code[:(batch_size-1)], c_code[1:]) # real images with mismatched text
    err_D_unmatched = criterion(unmatched_output, fake_labels[1:]) # mismatched pairs are negatives
    
    fake_output = netD.get_cont_logits(fake_imgs_code, c_code) # fake images with their matching text
    err_D_fake = criterion(fake_output, fake_labels)
    
    err_D = err_D_real + (err_D_fake + err_D_unmatched) * 0.5
    
    return err_D, err_D_real.item(), err_D_unmatched.item(), err_D_fake.item()

def main():

    transform = transforms.Compose([
        transforms.RandomCrop(256),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = birdDataset(dataDir = opt.dataDir, split='train', imgsize=256, transform = transform)
    train_loader = DataLoader(train_dataset, batch_size = opt.batch_size, shuffle=True, num_workers=0)
    
    test_dataset = birdDataset(dataDir= opt.dataDir, split= 'test', imgsize=256, transform = transform)
    test_loader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=0)
    
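    netG = STAGE2_G(STAGE1_G()).to(device)
    netD = STAGE2_D().to(device)
    if opt.s2_start_epoch != 1 :
        # resume from Stage-II checkpoints (netG's state dict also holds the frozen Stage-I weights)
        netG.load_state_dict(torch.load('model/stage-II/netG2_epoch_{}.pth'.format(opt.s2_start_epoch), map_location=device))
        netD.load_state_dict(torch.load('model/stage-II/netD2_epoch_{}.pth'.format(opt.s2_start_epoch), map_location=device))
    else:
        netG.apply(weight_init_normal)
        netD.apply(weight_init_normal)
        # load the pretrained Stage-I generator AFTER apply(): STAGE1_G is a submodule
        # of netG, so apply() would otherwise re-initialize its pretrained weights
        netG.STAGE1_G.load_state_dict(torch.load('model/stage-I/netG_epoch_200.pth', map_location=device))
    netG.STAGE1_G.eval() # keep the Stage-I generator in evaluation mode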
    
    optim_D = torch.optim.Adam(netD.parameters(), lr = opt.lr, betas=(opt.b1, opt.b2))
    
    # remove the parameter from Stage-I generator
    netG_param = []
    for p in netG.parameters():
        if p.requires_grad:
            netG_param.append(p)
    optim_G = torch.optim.Adam(netG_param, lr = opt.lr, betas=(opt.b1, opt.b2))
    
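    fixed_noise = torch.randn(opt.batch_size, opt.z_dim, device = device) # z ~ N(0, I)
    os.makedirs('model/stage-II', exist_ok=True)
    os.makedirs('picture/stage-II', exist_ok=True)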
    
    for epoch in range(opt.s2_start_epoch, opt.num_epochs+1):
        # halve the learning rate every 50 epochs (this differs from the paper, which halves it
        # every 100 epochs; I train for far fewer epochs than the paper, so I decay every 50)
        if epoch % 50 == 1 and epoch > 1:
            opt.lr = opt.lr * 0.5 
            for param_group in optim_G.param_groups:
                param_group['lr'] = opt.lr
            for param_group in optim_D.param_groups:
                param_group['lr'] = opt.lr
        for i, data in enumerate(train_loader, 0):
            real_imgs, text_embedding = data
            real_imgs = real_imgs.to(device)
            text_embedding = text_embedding.to(device)
            b_size = real_imgs.shape[0]
            
            real_labels = torch.full(size = (b_size,), fill_value = 1.0, device = device)
            fake_labels = torch.full(size = (b_size,), fill_value = 0.0, device = device)
            
            '''update discriminator'''
            netD.zero_grad()
            # generate fake images
            noise = torch.randn(b_size, opt.z_dim, device = device) # z ~ N(0, I)
            stageI_img, fake_imgs, mu, log_var = netG(text_embedding, noise)
            '''Open question: why does the original implementation pass the mean mu, rather than the sampled c_code, as the condition when computing the losses?'''
            err_D, err_D_real, err_D_unmatched, err_D_fake = cal_D_loss(netD, real_imgs, fake_imgs, real_labels, fake_labels,mu)
            err_D.backward()
            optim_D.step()
            
            '''update generator'''
            netG.zero_grad()
            err_G = cal_G_loss(netD, fake_imgs, real_labels, mu)
            err_G += KL_loss(mu, log_var)
            err_G.backward()
            optim_G.step()
            
            if i % 20 == 0: # print the losses every 20 batches
                print('[Epoch %d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tLoss_D_R: %.4f\tLoss_D_Unmatch: %.4f\tLoss_D_F %.4f'
                      % (epoch, opt.num_epochs, i, len(train_loader),
                         err_D.item(), err_G.item(), err_D_real, err_D_unmatched, err_D_fake))
                
        if epoch % 10 == 0 : # sample test images and save checkpoints every 10 epochs
            with torch.no_grad():
                real_imgs, text_embedding = next(iter(test_loader))
                real_imgs = real_imgs.to(device)
                text_embedding = text_embedding.to(device)
                # generate fake images
                stageI_img, fake_imgs, _, _ = netG(text_embedding, fixed_noise)
                
                real_imgs = torchvision.utils.make_grid(real_imgs.cpu(), nrow=8, normalize=True)
                fake_imgs = torchvision.utils.make_grid(fake_imgs.cpu().detach(), nrow=8, normalize=True)
                
                img_grid = torch.cat((real_imgs, fake_imgs), dim = 2)
                torchvision.utils.save_image(img_grid, 'picture/stage-II/stage2_{}.jpg'.format(epoch), nrow = 8, normalize=False)
            
            torch.save(netG.state_dict(), 'model/stage-II/netG2_epoch_{}.pth'.format(epoch))
            torch.save(netD.state_dict(), 'model/stage-II/netD2_epoch_{}.pth'.format(epoch))

if __name__ == '__main__':
    main()
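The loop above relies on three helpers: cal_D_loss, cal_G_loss and KL_loss. For reference, here is a minimal sketch of what they compute. It follows the paper's matching-aware discriminator loss (real image with matching text, real image with mismatched text, generated image) and the closed-form KL term of conditioning augmentation; it may differ in detail from the author's own definitions. ToyD is a hypothetical stand-in discriminator that takes (images, condition) and returns one probability per sample, and the 0.5 weighting of the mismatched and fake terms is an assumption that mirrors the official StackGAN implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def KL_loss(mu, log_var):
    # closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ),
    # averaged over batch and latent dimensions
    kld = 1 + log_var - mu.pow(2) - log_var.exp()
    return -0.5 * torch.mean(kld)


def cal_D_loss(netD, real_imgs, fake_imgs, real_labels, fake_labels, c_code):
    # detach so the D update does not backpropagate into G or the CA net
    fake_imgs, c_code = fake_imgs.detach(), c_code.detach()
    # real image + matching text -> "real"
    err_real = F.binary_cross_entropy(netD(real_imgs, c_code), real_labels)
    # real image + mismatched text (conditions shifted by one sample) -> "fake"
    err_wrong = F.binary_cross_entropy(netD(real_imgs[:-1], c_code[1:]), fake_labels[1:])
    # generated image + its own text -> "fake"
    err_fake = F.binary_cross_entropy(netD(fake_imgs, c_code), fake_labels)
    err_D = err_real + 0.5 * (err_wrong + err_fake)  # assumed weighting
    return err_D, err_real.item(), err_wrong.item(), err_fake.item()


def cal_G_loss(netD, fake_imgs, real_labels, c_code):
    # G wants D to label its images (paired with their text) as "real"
    return F.binary_cross_entropy(netD(fake_imgs, c_code.detach()), real_labels)


class ToyD(nn.Module):
    # hypothetical stand-in discriminator: flatten image, append condition,
    # output one probability per sample
    def __init__(self, img_numel, c_dim):
        super().__init__()
        self.fc = nn.Linear(img_numel + c_dim, 1)

    def forward(self, imgs, c_code):
        x = torch.cat((imgs.flatten(1), c_code), dim=1)
        return torch.sigmoid(self.fc(x)).view(-1)


torch.manual_seed(0)
B, c_dim = 4, 8
netD = ToyD(3 * 16 * 16, c_dim)
real_imgs = torch.randn(B, 3, 16, 16)
fake_imgs = torch.randn(B, 3, 16, 16, requires_grad=True)  # pretend G output
mu, log_var = torch.randn(B, c_dim), torch.zeros(B, c_dim)
real_labels, fake_labels = torch.ones(B), torch.zeros(B)

err_D, err_D_real, err_D_unmatched, err_D_fake = cal_D_loss(
    netD, real_imgs, fake_imgs, real_labels, fake_labels, mu)
err_G = cal_G_loss(netD, fake_imgs, real_labels, mu) + KL_loss(mu, log_var)
print(err_D.item(), err_G.item())
```

Detaching fake_imgs and c_code inside cal_D_loss is what lets the training loop call err_D.backward() and then reuse the same fake_imgs for err_G.backward() without the generator receiving gradients from the discriminator step.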
VI. Training Results

Part of the Stage-II training process:

[Figure: Stage-II training process]

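Stage-I results (100 epochs on a 1080 Ti, 11 GB):
epoch = 100: the left 8 × 8 = 64 images are the real images, and the right 8 × 8 grid shows the generated images.
[Figure: Stage-I samples at epoch 100, real (left) vs. generated (right)]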
Stage-II results (100 epochs on a Tesla V100, 16 GB):
epoch = 100:
[Figure: Stage-II samples at epoch 100]

VII. Problems Encountered and Solutions

1. Error: RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_mm
This error was raised every time training was about to finish an epoch:

[Epoch 0/150][0/139]	Loss_D: 1.5227	Loss_G: 0.0484	Loss_D_R: 0.7135	Loss_D_Unmatch: 0.7130	Loss_D_F 0.9054
[Epoch 0/150][20/139]	Loss_D: 0.4132	Loss_G: 4.0370	Loss_D_R: 0.1938	Loss_D_Unmatch: 0.1941	Loss_D_F 0.2447
[Epoch 0/150][40/139]	Loss_D: 0.3920	Loss_G: 5.0670	Loss_D_R: 0.0967	Loss_D_Unmatch: 0.1042	Loss_D_F 0.4863
[Epoch 0/150][60/139]	Loss_D: 2.6354	Loss_G: 2.6068	Loss_D_R: 0.0025	Loss_D_Unmatch: 0.0023	Loss_D_F 5.2634
[Epoch 0/150][80/139]	Loss_D: 0.2268	Loss_G: 3.0667	Loss_D_R: 0.1089	Loss_D_Unmatch: 0.1176	Loss_D_F 0.1183
[Epoch 0/150][100/139]	Loss_D: 0.5236	Loss_G: 3.2538	Loss_D_R: 0.1222	Loss_D_Unmatch: 0.1221	Loss_D_F 0.6808
[Epoch 0/150][120/139]	Loss_D: 0.3369	Loss_G: 2.6883	Loss_D_R: 0.1911	Loss_D_Unmatch: 0.2010	Loss_D_F 0.0905
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-330-4e7aecf0f178> in <module>
    124 
    125 if __name__ == '__main__':
--> 126     main()

<ipython-input-330-4e7aecf0f178> in main()
    113                 text_embedding = text_embedding.to(device)
    114                 # genetate fake_image
--> 115                 _, fake_imgs, _, _ = netG(text_embeddings, fixed_noise)
    116                 real_imgs = torchvision.utils.make_grid(real_imgs.cpu(), nrow=8, normalize=True)
    117                 fake_imgs = torchvision.utils.make_grid(fake_imgs.detach().cpu(), nrow=8, normalize=True)

D:\myUtils\anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

<ipython-input-328-930c20581050> in forward(self, text_embedding, noise)
    117         )
    118     def forward(self, text_embedding, noise):
--> 119         c_code, mu, log_var = self.ca_net(text_embedding)
    120         z_c_code = torch.cat((noise, c_code), dim = 1)
    121         img_code = self.fc(z_c_code)

D:\myUtils\anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

<ipython-input-328-930c20581050> in forward(self, text_embedding)
     55         return mu + torch.mul(std, eps)
     56     def forward(self, text_embedding):
---> 57         mu, log_var = self.encode(text_embedding)
     58         c_code = self.reparameter(mu, log_var)
     59         return c_code, mu, log_var

<ipython-input-328-930c20581050> in encode(self, text_embedding)
     45 
     46     def encode(self, text_embedding):
---> 47         x = self.relu(self.fc(text_embedding))
     48         mu = x[:, :self.c_dim] # 前面c_dim用作均值输出
     49         log_var = x[:, self.c_dim:] # 后面c_dim用作log方差输出

D:\myUtils\anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

D:\myUtils\anaconda3\lib\site-packages\torch\nn\modules\linear.py in forward(self, input)
     85 
     86     def forward(self, input):
---> 87         return F.linear(input, self.weight, self.bias)
     88 
     89     def extra_repr(self):

D:\myUtils\anaconda3\lib\site-packages\torch\nn\functional.py in linear(input, weight, bias)
   1610         ret = torch.addmm(bias, input, weight.t())
   1611     else:
-> 1612         output = input.matmul(weight.t())
   1613         if bias is not None:
   1614             output += bias

RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_mm

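Solution: change text_embedding to text_embedding.cuda() at the corresponding call sites. I was puzzled about why .cuda() was needed there, since the text_embedding passed in during training had already been moved to the GPU.
Later, the problem did not appear when running on a server, so I concluded it was specific to my own machine. Looking at the traceback again, though, a more likely cause shows up: line 113 moves text_embedding to the device, but line 115 calls netG(text_embeddings, fixed_noise) with the plural name. In a Jupyter notebook, text_embeddings may be a stale CPU tensor left over from another cell, which would explain both the error and why a fresh run on the server did not hit it.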
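Independently of the root cause, a defensive pattern is to move each input to whatever device the model's parameters live on, instead of hard-coding .cuda(). A minimal sketch; run_on_model_device is a hypothetical helper name, and a small nn.Linear stands in for netG:

```python
import torch
import torch.nn as nn


def run_on_model_device(model, *inputs):
    # look up the device of the model's first parameter (cuda:0, cpu, ...)
    device = next(model.parameters()).device
    # move every input tensor there before the forward pass, so a stray
    # CPU tensor cannot trigger a device-mismatch RuntimeError
    return model(*(x.to(device) for x in inputs))


model = nn.Linear(4, 2)
if torch.cuda.is_available():
    model = model.cuda()

text_embeddings = torch.randn(3, 4)  # deliberately left on the CPU
out = run_on_model_device(model, text_embeddings)
print(out.device)  # same device as the model's weights
```

With this helper the evaluation call would read run_on_model_device(netG, text_embedding, fixed_noise). It assumes the model has at least one parameter and that all of its parameters sit on a single device.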
Error when reading a file with pickle.load()
