其实也是EBGAN论文中提到的模型,只不过因为篇幅问题,所以拆分开来了。
CNN之父Yann LeCun(还有另外两位)提出的一个GAN模型。虽然说是Energy-Based,但是实际上用EB的方式理解,有点困难。反而从模型的设计上看的话,从另外一个角度来说会更加好理解。
注意看之前的EBGAN模型,会发现生成的图片数字不一定将所有的数字种类都囊括住了。或者是说,某些数字的图片出现的概率非常低。原因是,在Encoder-Decoder这样的映射当中,因为只是考虑到两个域之间的联系,所以完全可以只需要映射到输出空间的部分域内比如(数字1数字9等等)。
但是这样的结果并不能满足需求。EBGAN-PT中引入PT的概念:
PT就是为了保证不同的输入,Encoder之后的内容上显得更加正交化(互质)。换句话说就是,数值不一样的数字(这里假设了在mini-batch下各个数字应该大概率互不相同,现实中可能并不是这样的,算是这里算法的槽点吧),Encoder的编码结果应该有点区别(如果没有区别,那输出可能就是相同的了)。
直观来看,这样会使得输出分布更加均匀,如果分得更加细,那么学习到的特征就更加明显,所以图像效果会更加好。
实验
但是正如我之前所说,由于PT设计本来就存在的假设是有问题的,所以效果就存在一定的问题。
但是确实数字的分布更加分散了。
main.py
import torch
import torch.nn as nn


def PT(code):
    """Pull-away term (EBGAN-PT).

    Penalizes pairwise cosine similarity between the encoder codes of a
    mini-batch, pushing the generator towards more diverse outputs.

    Args:
        code: 2-D tensor of shape (N, d) — one flattened encoder code per
            sample. N must be >= 2.

    Returns:
        Scalar tensor: average cosine similarity over the N*(N-1)
        off-diagonal pairs of the batch.

    NOTE(review): the EBGAN paper uses the *squared* cosine similarity here;
    this implementation sums the raw cosine, so anti-aligned codes are
    rewarded. Kept as-is to preserve the original behavior — confirm intent.
    """
    norm = torch.norm(code, dim=1, p=2).unsqueeze(1)
    code_norm = code / norm
    # Gram matrix of L2-normalized codes: entry (i, j) = cos(code_i, code_j).
    mul_code = torch.matmul(code_norm, code_norm.transpose(1, 0))
    N = code.shape[0]
    # The diag of the normalized matrix is 1s; subtract it and average the
    # remaining N*(N-1) off-diagonal entries.
    pt_loss = (torch.sum(mul_code) - N) / (N * (N - 1))
    return pt_loss


if __name__ == '__main__':
    # Script-only dependencies live under the guard so that importing this
    # module (e.g. for PT) does not pull in matplotlib/torchvision/model.
    import os

    import matplotlib.pyplot as plt
    import torchvision
    from torch.utils.data import DataLoader

    from model import Generator, Discriminator

    LR = 0.0002
    EPOCH = 20  # 50
    BATCH_SIZE = 100
    N_IDEAS = 100           # dimensionality of the generator's noise input
    DOWNLOAD_MNIST = False
    m = 1                   # EBGAN margin for the fake-energy hinge term

    mnist_root = '../Conditional-GAN/mnist/'
    if not (os.path.exists(mnist_root)) or not os.listdir(mnist_root):
        # no mnist dir, or mnist is an empty dir
        DOWNLOAD_MNIST = True
    train_data = torchvision.datasets.MNIST(
        root=mnist_root,
        train=True,  # this is training data
        transform=torchvision.transforms.ToTensor(),  # (C, H, W) floats in [0.0, 1.0]
        download=DOWNLOAD_MNIST,
    )
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

    torch.cuda.empty_cache()
    # Fall back to CPU when no GPU is available (was hard-coded .cuda()).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    G = Generator(N_IDEAS).to(device)
    D = Discriminator(1).to(device)

    optimizerG = torch.optim.Adam(G.parameters(), lr=LR)
    optimizerD = torch.optim.Adam(D.parameters(), lr=LR)

    mse = nn.MSELoss().to(device)
    for epoch in range(EPOCH):
        tmpD, tmpG = 0, 0
        for step, (x, y) in enumerate(train_loader):
            x = x.to(device)
            rand_noise = torch.randn((x.shape[0], N_IDEAS, 1, 1), device=device)
            G_imgs = G(rand_noise)

            # --- D step: detach the fakes so D's backward pass does not run
            # through G and no retain_graph is needed.
            fake_detached = G_imgs.detach()
            D_fake_decode, _ = D(fake_detached)
            D_real_decode, _ = D(x)
            D_fake = mse(fake_detached, D_fake_decode)
            D_real = mse(x, D_real_decode)
            # EBGAN energy loss: lower real energy, raise fake energy up to margin m.
            D_loss = D_real + torch.clamp(m - D_fake, min=0)

            optimizerD.zero_grad()
            D_loss.backward()
            optimizerD.step()

            # --- G step: fresh forward pass through the *updated* D.
            # (The original called G_loss.backward() on a graph built before
            # optimizerD.step() mutated D's weights in place, which produces
            # stale generator gradients and required retain_graph=True.)
            G_fake_decode, fake_code = D(G_imgs)
            G_loss = mse(G_imgs, G_fake_decode) + 0.1 * PT(fake_code)

            optimizerG.zero_grad()
            G_loss.backward()
            optimizerG.step()

            # .item() detaches and moves to CPU in one call.
            tmpD += D_loss.item()
            tmpG += G_loss.item()
        tmpD /= (step + 1)
        tmpG /= (step + 1)
        print(
            'epoch %d avg of loss: D: %.6f, G: %.6f' % (epoch, tmpD, tmpG)
        )
        if epoch % 2 == 0:
            plt.imshow(torch.squeeze(G_imgs[0].cpu().detach()).numpy())
            plt.show()
    torch.save(G, 'G.pkl')
    torch.save(D, 'D.pkl')
model.py
import torch
import torch.nn as nn


class Generator(nn.Module):
    """DCGAN-style generator.

    Maps noise of shape (N, input_size, 1, 1) to images of shape
    (N, 1, 28, 28) with values in [-1, 1] (final Tanh).
    """

    def __init__(self, input_size):
        super(Generator, self).__init__()
        strides = [1, 2, 2, 2]
        padding = [0, 1, 1, 1]
        channels = [input_size, 256, 128, 64, 1]  # final 1 = grayscale output
        kernels = [4, 3, 4, 4]
        model = []
        for i, stride in enumerate(strides):
            model.append(
                nn.ConvTranspose2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    stride=stride,
                    kernel_size=kernels[i],
                    padding=padding[i],
                )
            )
            if i != len(strides) - 1:
                model.append(nn.BatchNorm2d(channels[i + 1]))
                model.append(nn.ReLU())
            else:
                # Last layer: Tanh squashes pixels to [-1, 1]; no BatchNorm.
                model.append(nn.Tanh())
        self.main = nn.Sequential(*model)

    def forward(self, x):
        return self.main(x)


class Discriminator(nn.Module):
    """EBGAN auto-encoder discriminator.

    forward(x) returns (reconstruction, code):
      * reconstruction — decoder output, same shape as x;
      * code — flattened encoder features of shape (N, 128*7*7) for 28x28
        inputs, consumed by the PT (pull-away) term.
    """

    def __init__(self, input_size):
        super(Discriminator, self).__init__()
        # Encoder: two stride-2 convs, 28x28 -> 14x14 -> 7x7.
        strides = [2, 2]
        padding = [1, 1]
        channels = [input_size, 64, 128]
        kernels = [4, 4]
        model = []
        for i, stride in enumerate(strides):
            model.append(
                nn.Conv2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    stride=stride,
                    kernel_size=kernels[i],
                    padding=padding[i],
                )
            )
            model.append(nn.BatchNorm2d(channels[i + 1]))
            model.append(nn.LeakyReLU(0.2))
        self.Encoder = nn.Sequential(*model)

        # Decoder mirrors the encoder: 7x7 -> 14x14 -> 28x28.
        # NOTE(review): unlike the Generator there is no final Tanh/Sigmoid,
        # so the reconstruction is unbounded — confirm this is intended.
        strides = [2, 2]
        padding = [1, 1]
        channels = [128, 64, input_size]
        kernels = [4, 4]
        model = []
        for i, stride in enumerate(strides):
            model.append(
                nn.ConvTranspose2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    stride=stride,
                    kernel_size=kernels[i],
                    padding=padding[i],
                )
            )
            model.append(nn.BatchNorm2d(channels[i + 1]))
            model.append(nn.LeakyReLU(0.2))
        self.Decoder = nn.Sequential(*model)

    def forward(self, x):
        f_x = self.Encoder(x)
        x = self.Decoder(f_x)
        return x, f_x.view(x.shape[0], -1)


def PT(code):
    """Pull-away term: average cosine similarity over the off-diagonal pairs
    of the batch's L2-normalized codes. `code` is (N, d) with N >= 2."""
    norm = torch.norm(code, dim=1, p=2).unsqueeze(1)
    code_norm = code / norm
    mul_code = torch.matmul(code_norm, code_norm.transpose(1, 0))
    N = code.shape[0]
    # the diag of normalized matrix is 1s.
    pt_loss = (torch.sum(mul_code) - N) / (N * (N - 1))
    return pt_loss


if __name__ == '__main__':
    # Smoke-test only; these imports stay under the guard so that importing
    # `model` (as main.py does) does not pull in torchvision/the dataset.
    import os

    import torchvision
    import torch.utils.data as Data

    N_IDEAS = 100
    G = Generator(N_IDEAS)
    rand_noise = torch.randn((10, N_IDEAS, 1, 1))
    print(G(rand_noise).shape)

    DOWNLOAD_MNIST = False
    mnist_root = '../Conditional-GAN/mnist/'
    if not (os.path.exists(mnist_root)) or not os.listdir(mnist_root):
        # no mnist dir, or mnist is an empty dir
        DOWNLOAD_MNIST = True
    train_data = torchvision.datasets.MNIST(
        root=mnist_root,
        train=True,  # this is training data
        transform=torchvision.transforms.ToTensor(),  # (C, H, W) floats in [0.0, 1.0]
        download=DOWNLOAD_MNIST,
    )
    D = Discriminator(1)
    print(len(train_data))
    train_loader = Data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)
    for step, (x, y) in enumerate(train_loader):
        print(x.shape)
        print(x.max(), x.min())
        dx, dx_code = D(x)
        print(dx.shape, dx_code.shape)
        print(PT(dx_code))
        break
judge.py
import numpy as np
import torch
import matplotlib.pyplot as plt
from model import Generator, Discriminator
import torchvision.utils as vutils

if __name__ == '__main__':
    """Sample TIME grids of 100 images from the trained generator and save
    them as <t>-PT.png."""
    BATCH_SIZE = 100
    N_IDEAS = 100   # must match the noise dimension the generator was trained with
    TIME = 5

    # Fall back to CPU when no GPU is available (was hard-coded .cuda()).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): torch.load unpickles a whole module object — only load
    # checkpoints from a trusted source; a state_dict checkpoint is safer.
    G = torch.load("G.pkl").to(device)
    G.eval()  # inference mode: use BatchNorm running stats, not batch stats

    with torch.no_grad():  # sampling only — no autograd graph needed
        for t in range(TIME):
            rand_noise = torch.randn((BATCH_SIZE, N_IDEAS, 1, 1), device=device)
            G_imgs = G(rand_noise).cpu()
            plt.figure(figsize=(10, 10))
            plt.axis("off")
            # make_grid tiles the batch into one 10x10 image; normalize each
            # sample independently for display.
            plt.imshow(np.transpose(
                vutils.make_grid(G_imgs, nrow=10, padding=0,
                                 normalize=True, scale_each=True),
                (1, 2, 0)))
            plt.savefig(str(t) + '-PT.png')
            plt.show()