EBGAN-PT: Model Theory and Python Implementation

This model is also described in the EBGAN paper; it was only split out into a separate post for reasons of length.

It is a GAN model proposed by Yann LeCun, the father of CNNs (together with two co-authors). Although the model is called Energy-Based, understanding it in EB terms is actually rather difficult; it is easier to understand from the perspective of the model's design instead.

Look back at the previous EBGAN model and you will notice that the generated digit images do not necessarily cover every digit class, or that images of certain digits appear with very low probability. The reason is that the Encoder-Decoder mapping only captures the relationship between the two domains, so nothing stops it from mapping onto just a subset of the output space (e.g., only digit 1, digit 9, and so on).

Such a result does not meet our needs, so EBGAN-PT introduces the concept of the PT (pulling-away term):

f_{PT}(S) = \frac{1}{N(N-1)} \sum_i \sum_{j \neq i} \left( \frac{S_i^\top S_j}{\|S_i\| \, \|S_j\|} \right)^2

where S is the mini-batch of encoder codes and N is the batch size.

The PT term is there to make the encoded representations of different inputs more nearly orthogonal to one another. In other words, digits with different values should get encodings that differ at least somewhat (if the encodings showed no difference, the decoded outputs could end up identical too). Note the assumption hidden here: within a mini-batch, the samples are taken to be mostly distinct digits, which need not hold in practice; that is a fair criticism of the algorithm.

Intuitively, this makes the output distribution more uniform: the finer the separation between codes, the more distinct the learned features, and the better the generated images turn out.
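To make the behavior of PT concrete, here is a small self-contained demo (not from the original post) using the same PT function that appears in main.py below. One caveat: the EBGAN paper defines PT with each cosine similarity squared, while this post's implementation uses the raw similarities; both penalize pairwise similarity within the batch.

import torch

def PT(code):
    # Mean pairwise cosine similarity across a mini-batch of code vectors,
    # excluding the diagonal (each vector's similarity with itself is 1).
    norm = torch.norm(code, dim=1, p=2).unsqueeze(1)
    code_norm = code / norm
    mul_code = torch.matmul(code_norm, code_norm.transpose(1, 0))
    N = code.shape[0]
    return (torch.sum(mul_code) - N) / (N * (N - 1))

print(PT(torch.ones(4, 8)))  # identical codes -> tensor(1.), maximally penalized
print(PT(torch.eye(4)))      # mutually orthogonal codes -> tensor(0.), no penalty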



Experiments

But as I said above, the assumption built into the design of PT is itself problematic, so the results also have certain issues.

Still, the distribution of the generated digits is indeed more spread out.
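To check this claim quantitatively, one option is to classify a large batch of generated samples and look at the class histogram. A minimal sketch, assuming a pretrained MNIST classifier saved as classifier.pkl (hypothetical; not part of this post's code):

import torch
from collections import Counter
from model import Generator  # needed so torch.load can unpickle G

# 'classifier.pkl' is a hypothetical pretrained MNIST classifier mapping
# (N, 1, 28, 28) images to (N, 10) logits.
G = torch.load('G.pkl').cpu().eval()
clf = torch.load('classifier.pkl').cpu().eval()

with torch.no_grad():
    noise = torch.randn(1000, 100, 1, 1)
    preds = clf(G(noise)).argmax(dim=1)

# The closer this histogram is to uniform over the 10 digits,
# the better the mode coverage.
print(Counter(preds.tolist()))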

(Figures: 10x10 grids of digits sampled from the trained EBGAN-PT generator.)

main.py

import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

from model import Generator, Discriminator


def PT(code):
    # Normalize each code vector to unit length, then take the mean of all
    # pairwise cosine similarities, excluding the diagonal of 1s.
    norm = torch.norm(code, dim=1, p=2).unsqueeze(1)
    code_norm = code / norm
    mul_code = torch.matmul(code_norm, code_norm.transpose(1, 0))
    N = code.shape[0]
    pt_loss = (torch.sum(mul_code) - N) / (N * (N - 1))  # the diag of the normalized matrix is 1s
    return pt_loss


if __name__ == '__main__':
    LR = 0.0002
    EPOCH = 20  # 50
    BATCH_SIZE = 100
    N_IDEAS = 100
    DOWNLOAD_MNIST = False
    m = 1  # margin for the discriminator's hinge term
    mnist_root = '../Conditional-GAN/mnist/'
    if not os.path.exists(mnist_root) or not os.listdir(mnist_root):
        # mnist dir is missing or empty
        DOWNLOAD_MNIST = True

    train_data = torchvision.datasets.MNIST(
        root=mnist_root,
        train=True,  # this is training data
        transform=torchvision.transforms.ToTensor(),  # converts a PIL.Image or numpy.ndarray to a
        # torch.FloatTensor of shape (C x H x W) normalized to the range [0.0, 1.0]
        download=DOWNLOAD_MNIST,
    )
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

    torch.cuda.empty_cache()
    G = Generator(N_IDEAS).cuda()
    D = Discriminator(1).cuda()
    optimizerG = torch.optim.Adam(G.parameters(), lr=LR)
    optimizerD = torch.optim.Adam(D.parameters(), lr=LR)
    mse = nn.MSELoss().cuda()

    for epoch in range(EPOCH):
        tmpD, tmpG = 0, 0
        for step, (x, y) in enumerate(train_loader):
            x = x.cuda()
            rand_noise = torch.randn((x.shape[0], N_IDEAS, 1, 1)).cuda()
            G_imgs = G(rand_noise)

            D_fake_decode, fake_code = D(G_imgs)
            D_real_decode, real_code = D(x)

            # Energies are the reconstruction errors of the auto-encoder discriminator.
            D_fake = mse(G_imgs, D_fake_decode)
            D_real = mse(x, D_real_decode)
            D_loss = D_real + torch.clamp(m - D_fake, min=0)
            G_loss = D_fake + 0.1 * PT(fake_code)  # pulling-away term, weighted by 0.1

            optimizerD.zero_grad()
            D_loss.backward(retain_graph=True)
            optimizerD.step()

            optimizerG.zero_grad()
            G_loss.backward(retain_graph=True)
            optimizerG.step()

            tmpD += D_loss.cpu().detach().data
            tmpG += G_loss.cpu().detach().data
        tmpD /= (step + 1)
        tmpG /= (step + 1)
        print(
            'epoch %d avg of loss: D: %.6f, G: %.6f' % (epoch, tmpD, tmpG)
        )
        if epoch % 2 == 0:
            plt.imshow(torch.squeeze(G_imgs[0].cpu().detach()).numpy())
            plt.show()

    torch.save(G, 'G.pkl')
    torch.save(D, 'D.pkl')
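One design note: the training step above calls backward(retain_graph=True) twice because D_loss and G_loss share a single computation graph, and D is stepped before G's backward pass. A common alternative (a sketch under the same variable names, not the post's original code) detaches the fake images for the D update and gives G its own forward pass, which avoids retaining the graph:

# Inside the inner training loop, replacing the update block above:
fake = G_imgs.detach()                       # block gradients into G for the D step
fake_decode, _ = D(fake)
real_decode, _ = D(x)
D_loss = mse(x, real_decode) + torch.clamp(m - mse(fake, fake_decode), min=0)
optimizerD.zero_grad()
D_loss.backward()                            # no retain_graph needed
optimizerD.step()

fake_decode, fake_code = D(G_imgs)           # fresh forward pass through the updated D
G_loss = mse(G_imgs, fake_decode) + 0.1 * PT(fake_code)
optimizerG.zero_grad()
G_loss.backward()
optimizerG.step()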

model.py

import os

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision


class Generator(nn.Module):
    def __init__(self, input_size):
        super(Generator, self).__init__()
        strides = [1, 2, 2, 2]
        padding = [0, 1, 1, 1]
        channels = [input_size, 256, 128, 64, 1]  # 1 output channel (grayscale)
        kernels = [4, 3, 4, 4]
        model = []
        for i, stride in enumerate(strides):
            model.append(
                nn.ConvTranspose2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    stride=stride,
                    kernel_size=kernels[i],
                    padding=padding[i]
                )
            )
            if i != len(strides) - 1:
                model.append(nn.BatchNorm2d(channels[i + 1]))
                model.append(nn.ReLU())
            else:
                # Tanh bounds outputs to [-1, 1]; note the MNIST inputs here are
                # in [0, 1], so the two ranges do not quite match.
                model.append(nn.Tanh())
        self.main = nn.Sequential(*model)

    def forward(self, x):
        x = self.main(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, input_size):
        super(Discriminator, self).__init__()
        # Encoder: two strided convolutions, 28x28 -> 14x14 -> 7x7
        strides = [2, 2]
        padding = [1, 1]
        channels = [input_size, 64, 128]
        kernels = [4, 4]
        model = []
        for i, stride in enumerate(strides):
            model.append(
                nn.Conv2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    stride=stride,
                    kernel_size=kernels[i],
                    padding=padding[i]
                )
            )
            model.append(nn.BatchNorm2d(channels[i + 1]))
            model.append(nn.LeakyReLU(0.2))
        self.Encoder = nn.Sequential(*model)

        # Decoder: mirror of the encoder, 7x7 -> 14x14 -> 28x28
        strides = [2, 2]
        padding = [1, 1]
        channels = [128, 64, input_size]
        kernels = [4, 4]
        model = []
        for i, stride in enumerate(strides):
            model.append(
                nn.ConvTranspose2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    stride=stride,
                    kernel_size=kernels[i],
                    padding=padding[i]
                )
            )
            model.append(nn.BatchNorm2d(channels[i + 1]))
            model.append(nn.LeakyReLU(0.2))
        self.Decoder = nn.Sequential(*model)

    def forward(self, x):
        f_x = self.Encoder(x)
        x = self.Decoder(f_x)
        # Return the reconstruction and the flattened encoder code (used by PT).
        return x, f_x.view(x.shape[0], -1)


def PT(code):
    norm = torch.norm(code, dim=1, p=2).unsqueeze(1)
    code_norm = code / norm
    mul_code = torch.matmul(code_norm, code_norm.transpose(1, 0))
    N = code.shape[0]
    pt_loss = (torch.sum(mul_code) - N) / (N * (N - 1))  # the diag of the normalized matrix is 1s
    return pt_loss


if __name__ == '__main__':
    N_IDEAS = 100
    G = Generator(N_IDEAS)
    rand_noise = torch.randn((10, N_IDEAS, 1, 1))
    print(G(rand_noise).shape)

    DOWNLOAD_MNIST = False
    mnist_root = '../Conditional-GAN/mnist/'
    if not os.path.exists(mnist_root) or not os.listdir(mnist_root):
        # mnist dir is missing or empty
        DOWNLOAD_MNIST = True
    train_data = torchvision.datasets.MNIST(
        root=mnist_root,
        train=True,  # this is training data
        transform=torchvision.transforms.ToTensor(),
        download=DOWNLOAD_MNIST,
    )
    D = Discriminator(1)
    print(len(train_data))
    train_loader = Data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)
    for step, (x, y) in enumerate(train_loader):
        print(x.shape)
        print(x.max(), x.min())
        dx, dx_code = D(x)
        print(dx.shape, dx_code.shape)
        print(PT(dx_code))
        break
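As a sanity check on the generator architecture, the spatial size produced by each nn.ConvTranspose2d layer (with dilation=1 and output_padding=0) follows H_out = (H_in - 1) * stride - 2 * padding + kernel_size, which carries the 1x1 noise input up to the 28x28 MNIST resolution:

# Verify the upsampling path of Generator: 1 -> 4 -> 7 -> 14 -> 28
def deconv_out(h, stride, padding, kernel):
    return (h - 1) * stride - 2 * padding + kernel

h = 1  # noise enters as a (N, 100, 1, 1) tensor
for stride, padding, kernel in zip([1, 2, 2, 2], [0, 1, 1, 1], [4, 3, 4, 4]):
    h = deconv_out(h, stride, padding, kernel)
    print(h)  # prints 4, 7, 14, 28 in turn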

judge.py

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.utils as vutils

from model import Generator, Discriminator  # needed so torch.load can unpickle the saved models

if __name__ == '__main__':
    BATCH_SIZE = 100
    N_IDEAS = 100
    img_shape = (1, 28, 28)
    TIME = 5
    G = torch.load("G.pkl").cuda()
    for t in range(TIME):
        rand_noise = torch.randn((BATCH_SIZE, N_IDEAS, 1, 1)).cuda()
        G_imgs = G(rand_noise).cpu().detach()
        fig = plt.figure(figsize=(10, 10))
        plt.axis("off")
        # Arrange the 100 samples into a 10x10 grid, normalizing each image for display.
        plt.imshow(np.transpose(
            vutils.make_grid(G_imgs, nrow=10, padding=0, normalize=True, scale_each=True),
            (1, 2, 0)))
        plt.savefig(str(t) + '-PT.png')
        plt.show()
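Since the discriminator is an auto-encoder, its reconstruction error can also be read as the energy that gives EBGAN its name. A sketch (not part of the original post) that scores generated samples with the trained D, where lower energy means a sample looks more "real" to the discriminator:

import torch
from model import Generator, Discriminator  # needed so torch.load can unpickle the saved models

G = torch.load('G.pkl').cuda().eval()
D = torch.load('D.pkl').cuda().eval()

with torch.no_grad():
    noise = torch.randn(100, 100, 1, 1).cuda()
    imgs = G(noise)
    recon, _ = D(imgs)
    # Per-sample MSE between each image and its reconstruction.
    energy = ((imgs - recon) ** 2).mean(dim=(1, 2, 3))

print(energy.mean(), energy.min(), energy.max())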

