ds证据理论python实现_DualGAN模型理论以及Python实现

http://arxiv.org/abs/1704.02510

还记得前段时间请假了一天嘛,那个就是为了完成这个实验,一直都没有完成。甚至我在网上找到的模型都是失败的emmmmm我哭了,上传之前那些大哥们都没有检查过自己的代码的嘛???果然我还是个老实人。

下面就就让本憨憨老实人, 接着讲讲我读DualGAN的内容。

这是一篇2017年的ICCV。

当时选这个的原因很简单,D在文件夹中顺序靠前emmmm。因为当时已经完成了很多相关的GAN,意气风发。就想按照顺序来,结果,受到了DualGAN的毒打。(哭了)

我先给大家理清一下关系:

DualGAN是基于CycleGAN,pix2pix,coupledGAN的思想提出的。(最早在这个领域的应该是coupledGAN尝试用多个GAN协作的方式完成训练)

可以说是多个GAN协作(或者对抗)这个领域的一个作品。

同时,在训练的稳定性上,又使用了WGAN。当时我还没看WGAN,但是看到WGAN20+页的论文,外加上里面的数学极其多,就想着先实现再说,反正WGAN这些为了保证稳定性的创新,最终都会体现在设计上,蛮好解决的。(理论反正可以后面再补)

哦哦对了,DualGAN延伸到pix2pix之后,发现G的生成器结构,又是拿医学图像分割模型U-Net的(同样没看过)。

也就是说相关的论文,当时除了coupledGAN之外其他我都是没学过的。

但是,膨胀的我还是愿意接受挑战。接着,我就顺着线,大致的快读了一下这几篇论文,然后开始实现。

效果非常糟糕,生成的图像中基本都是噪声。我当时就蒙了(毒打一)。在疯狂试了一堆之后,发现还是很糟糕。

但是仔细看这篇论文,发现在DualGAN的整体结构设计上并没有太tricky的部分,实现也不难,只要是对应的位置的数据摆对就应该没什么事。但是较为难的部分应该是在WGAN的创新式的收敛方式,而这也是我最搞不懂的地方。

因为WGAN设计模型的数学思路还没有体会到,里面可能存在某些数学细节我在实现的时候没有考虑到。于是就去读WGAN论文并实现WGAN。

  • WGAN模型理论以及Python实现

就有了上面这篇文章。

看了论文之后,就发现WGAN存在一个Lipschitz约束,但是当时论文中也没细提,我实现的时候,全文都没找到那个clipping指的是啥,就跳过了。然后还有就是关于约束,只在D中使用,不然也会挂掉。

看过WGAN的都知道,Lipschitz约束基本上是保证了WGAN的损失会收敛的必要条件。

好的,我觉得我可能找到问题了。于是,又去试了一波。又是一次毒打(毒打二)又陷入了一波自我怀疑。

但后来看到还有一篇论文WGAN-gp是WGAN的改进,并在网上看到有人说WGAN发出来后,在reddit上被人怼指出问题。

我又觉得自己找到新希望。WGAN-gp走一波。

  • WGAN-gp模型理论以及Python实现

实现了WGAN-gp之后,效果感觉还行。我就又把WGAN-gp迁移到DualGAN上。

接着迎来是社会的毒打。这次的毒打可能会比较惨(毒打三)

接着我又在网上找别人的实现,我倒是要看看我跟别人的成功的实现差别在哪。仔细看了一波之后,发现emmmm。虽然模型的设计不太一样,但是直观来说模型的影响不会太大吧?但是为了验证,我把这个模型搬下来。

但是还是失败了。(社会的毒打四)

这次,我觉得可能是数据的问题。可能我测试的数据集上太大。于是,我又改模型,将目标处理的数据改小。结果就不说了(社会的毒打五)

这时候,我觉得可能是有两种可能,要么这个DualGAN是个假论文,要么就是我处理的是个假数据。但是大家都在宣传这个模型,想必更大的可能是数据不太好。

于是,我就不管了,先想办法提升模型的稳定性。也就是复现了WGAN-div。

  • WGAN-div模型理论以及Python实现

实现的时候,就发现了WGAN存在有model-sensitive的问题(论文中没有提到,但是我后期自己试了一大堆)

相对应的,我觉得我之前的WGAN和WGAN-gp的成功更大可能是巧合。

经过很多组的实验,后来检测到一些关键性的指标。BN,sigmoid,线性模型。

这些问题的检验,分别存在于:

  • 验证线性模型,BN,Sigmoid在WGAN-GP不同表现

  • 验证线性模型,BN,Sigmoid在WGAN不同表现

两篇文章当中。个人觉得虽然没有太多理论知识,也没什么太大创新性,但是觉得我自己做的这两篇文章中的实验,会很好的帮助使用WGAN这个GAN的大杀器(真的,从理论上看,WGAN超级强)。


插个广告。歇会。


至此,我发现到了WGAN上存在的各种各样的问题,也大致明白了为什么我之前实现的时候DualGAN按照论文中的走会失败。

虽然DualGAN本身的创新对我的知识储备来说没有太大变化,但是解决为什么DualGAN失败以及如何才能成功上,遇到了大量的问题,这些对我来说才是读论文并实现论文的当中的最大收获。

不过从实验的结果来看,作者能成功,我还是不太理解。毕竟WGAN在卷积上表现的实在是太糟糕。只有WGAN-gp在卷积上表现的相当不错。

真不太明白为什么他单单从WGAN的weight clipping上就能成功,可能存在有别的我不太知道的知识点吧?

总之,我这边是按照WGAN-gp的方式完成了这次的实验。

进入主题,讲讲这篇文章的主角DualGAN。

DualGAN

CoupledGAN也就是Co-GAN尝试多个GAN的协作完成任务,不同风格图片的生成(一部分学习共同特征,一部分学习差异性特征),展开了多GAN这条分线上的研究。

之后的pix2pix,CycleGAN都是其中的佼佼者。

这篇论文的设计和CycleGAN有点相似,但是却有自己的特点,效果也还不错,速度也非常快。

相比于其他的风格迁移,需要对应成对的数据。也就是一张图片以及这张图片对应的其他风格下的图片。但是这样在现实生活是很少能获得这样的数据。

一般解决这个问题有两种方法,一种就是像GAN系列的用来做风格迁移的模型一样,同时学习两种类型图片的总体分布,再在两个分布之间打通联系。另外一种,就是用某些方法来概括“风格”,代表有Gram矩阵。

这里的DualGAN本质上就是第一种。

模型框架

21c870b34cbdcdcc7d071cc3ebf08627.png

也就是说,这里的生成器扮演者从一种图片到另外一种图片的映射,其中这里用到了CGAN的思路,不过CGAN举的例子输入是label,这里换用成图片而已。

比较有创新点的是,这里在G中加入了cycle_loss的概念,也就是一个数据先通过某个损失跑到了另外一个域内,再通过另外一个生成器再跑回来。衡量跑回来的数据和初始的真实数据之间的差异。

另外,作者看了另外两篇论文之后,得出了使用L1范数来刻画这个距离会比较好。也就是下面这个,计算的时候使用的是1-范数。

efe637dc01eeb371fb1a5be9b5a5039d.png

也就是G的损失如下:同样的,用论文中的参数,貌似是实现不了的?

这里的两个λ都取10。

158feb250cd4898e09d6b0f4517d9d2e.png

D的损失函数,用的是:同样,按照这么写是不行的,要加gp。要么就不要用u-net换用MLP(线性映射)。之前检验的时候,发现MLP在纯粹的WGAN下效果是可以接受的,但是加了conv就不太行。

8457263dfa25bf40a3e8f2fabfb9b1f2.png

对了,同样结合这样的WGAN,所以也是有nc的,也就是每训练nc次D,再训练一次G。


理论结束,恰饭


实验

数据是按照论文中的去找到到的。

http://mmlab.ie.cuhk.edu.hk/archive/facesketch.html

这个地方下载对应图片数据。

总共分别有真实图片和素描图片各88张,是成对的。

但是我们用的时候,没有必要成对的用,也是可以成功的。(这个就是DualGAN比其他的风格迁移模型好的地方(以前的))

因为输入的图片两种是不一样的大小,这样需要保证两种生成器的设计要不一样,为了解决这个我统一裁剪了下(当然是手动设置不同的参数裁剪的)

效果

左边这一列表示真实数据所在列,右边表示对应的素描结果。

其中第一行的右边和第二行的左边是真实输入数据。

会发现,在数据被认为是背景,不一定是指边缘的背景,也包括人脸部没有体现明显细节的部分,会出现麻点。这可能是是因为L1-范数导致的,当然也有可能是因为图片像素精度的问题。(在之后会做补充实验中做详细验证吧~)

ee4ffff7766ccbd0a4050b710cebf767.png

4f143d7c2459f9ba27efafb275452e83.png

dd9b13c7e4759a74bacdbf28bd9b85e6.png

5c527687479d037c48bf66941a3bf858.png

6fe147e7c33cec97879e1c5769eb26a2.png

当然,我也加入了一些奇怪的真实数据来输出对应的图片。

因为是拍的真实图片(随意拍的),并没有训练数据中那么规整的位置设定会发现越有可能被判断为是背景(不单单是背后的,脸上没有细节的部分也可以算作)区域,出现灰色麻点的概率较高。

但是上面,素描图转真实图片中,这个问题好了很多。结合起来很有可能是学习到的素描可能会用灰色来描一下背景(素描中有这技术)。当然也有可能是范数或者图片精度的问题。

比如:

b2184a181ea8a1220418ddba9c94fdd5.png

e0f25c55ce496b59122a0379014baf9f.png

99d0d2df6f8a5b274078b9b35e4e1ce9.png

657fecfa79b76b1896fcf6467cc5bb00.png

ee96a34fe6b81045ba7666de87d9b245.png

代码

注意,为了提高效果,我用了网上大家都推荐的ResNet中设计的Residual Block。用残差图来提高U-net的精度。并且,我这个U-net比较小(但核心concat的思路都在)

main.py

import osimport torchfrom torch.utils.data import Dataset, DataLoaderimport torch.nn as nnfrom model import Generator, Discriminator, gp_loss# from model import gp_loss# from github_model import Generator, Discriminatorimport torchvisionfrom dataloader import MyDatasetimport matplotlib.pyplot as pltimport itertoolsimport numpy as npimport torchvision.utils as vutilsif __name__ == '__main__':    LR = 0.0002    EPOCH = 100  # 50    BATCH_SIZE = 4    # drop_rate = 0.7    nc = 2  # [2 - 4]    lamadv = 1    lamcycle = 10    lamgp = 10    TRAINED = False    path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'    dataset_photo = MyDataset(path=path, resize=96, Len=88, img_type='jpg', sketch=False)    path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'    dataset_sketch = MyDataset(path=path, resize=96, Len=88, img_type='jpg', sketch=True)    photo_loader = DataLoader(dataset=dataset_photo, batch_size=BATCH_SIZE, shuffle=True)    sketch_loader = DataLoader(dataset=dataset_sketch, batch_size=BATCH_SIZE, shuffle=True)    torch.cuda.empty_cache()    if not TRAINED:        # GA = Generator(1, drop_rate).cuda()        # GB = Generator(1, drop_rate).cuda()        GA = Generator(1).cuda()  # github version        GB = Generator(1).cuda()        DA = Discriminator(1).cuda()        DB = Discriminator(1).cuda()    else:        GA = torch.load("GA.pkl").cuda()        GB = torch.load("GB.pkl").cuda()        DA = torch.load("DA.pkl").cuda()        DB = torch.load("DB.pkl").cuda()    optimizerG = torch.optim.Adam(itertools.chain(GA.parameters(), GB.parameters()), lr=LR)    optimizerDA = torch.optim.Adam(DA.parameters(), lr=LR)    optimizerDB = torch.optim.Adam(DA.parameters(), lr=LR)    l1_c = nn.L1Loss()    for epoch in range(EPOCH):        tmpDA, tmpDB, tmpadv, tmpl1 = 0, 0, 0, 0        for step, (v, u) in enumerate(itertools.zip_longest(photo_loader, sketch_loader)):            u = u.cuda()            v = v.cuda()            ga_u = GA(u)            gb_v = GB(v)            optimizerDA.zero_grad()            optimizerDB.zero_grad()            da_ga_u = torch.squeeze(DA(ga_u))            da_v = torch.squeeze(DA(v))            # wgan            DA_loss = torch.mean(da_ga_u - da_v) + lamgp * gp_loss(DA, v, ga_u, cuda=True)            # lsgan            # DA_loss = torch.mean((da_v - 1) ** 2) + torch.mean(da_ga_u ** 2)            db_gb_v = torch.squeeze(DB(gb_v))            db_u = torch.squeeze(DB(u))            # wgan            DB_loss = torch.mean(db_gb_v - db_u) + lamgp * gp_loss(DB, u, gb_v, cuda=True)            # lsgan            # DB_loss = torch.mean((db_u - 1) ** 2) + torch.mean(db_gb_v ** 2)            D_loss = DA_loss + DB_loss            D_loss.backward(retain_graph=True)            optimizerDA.step()            optimizerDB.step()            if (step + 1) % nc == 0:                l1_U = l1_c(u, GB(ga_u))                l1_V = l1_c(v, GA(gb_v))                l1_loss = l1_U + l1_V                # wgan                adv_loss = -torch.mean(da_ga_u) - torch.mean(db_gb_v)                # lsgan                # adv_loss = torch.mean((da_ga_u - 1) ** 2) + torch.mean((db_gb_v - 1) ** 2)                G_loss = lamcycle * l1_loss + lamadv * adv_loss                optimizerG.zero_grad()                G_loss.backward(retain_graph=True)                optimizerG.step()                tmpl1_ = l1_loss.cpu().detach().data                tmpadv_ = adv_loss.cpu().detach().data                tmpl1 += tmpl1_                tmpadv += tmpadv_                tmpDA_ = DA_loss.cpu().detach().data                tmpDB_ = DB_loss.cpu().detach().data                tmpDA += tmpDA_                tmpDB += tmpDB_        tmpDA /= (step + 1)        tmpDB /= (step + 1)        tmpl1 /= (step + 1)        tmpadv /= (step + 1)        print(            'epoch %d avg of loss: DA: %.6f, DB: %.6f, G_l1: %.6f, G_adv: %.6f' % (epoch, tmpDA, tmpDB, tmpl1, tmpadv)        )        if (epoch + 1) % 5 == 0:            fig = plt.figure(figsize=(10, 10))            plt.axis("off")            plt.imshow(np.transpose(                vutils.make_grid(torch.stack([ga_u[0].cpu().detach(), u[0].cpu().detach(),                                              v[0].cpu().detach(), gb_v[0].cpu().detach()]), nrow=2, padding=0,                                 normalize=True, scale_each=True), (1, 2, 0)), cmap='gray')            plt.show()    torch.save(GA, 'GA.pkl')    torch.save(GB, 'GB.pkl')    torch.save(DA, 'DA.pkl')    torch.save(DB, 'DB.pkl')

model.py

import osimport torchimport torch.nn as nnimport torch.utils.data as Dataimport torchvisionfrom torch.utils.data import DataLoaderfrom dataloader import MyDatasetimport torch.autograd as autogradclass ResidualBlock(nn.Module):    def __init__(self, in_channel=1, out_channel=1, stride=1):        super(ResidualBlock, self).__init__()        self.weight_layer = nn.Sequential(            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1),            nn.BatchNorm2d(out_channel),            nn.ReLU(),            nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1),        )        self.active_layer = nn.Sequential(            nn.BatchNorm2d(out_channel),            nn.ReLU()        )    def forward(self, x):        residual = x        x = self.weight_layer(x)        x += residual        return self.active_layer(x)class Generator(nn.Module):    def __init__(self, input_channel=1, drop_rate=0.5):        super(Generator, self).__init__()        self.c_e1 = nn.Sequential(nn.Conv2d(in_channels=input_channel, out_channels=64, kernel_size=4, stride=2, padding=1),                                  nn.LeakyReLU(0.2),                                  ResidualBlock(in_channel=64, out_channel=64))        self.c_e2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2),                                  nn.BatchNorm2d(128), nn.LeakyReLU(0.2),                                  ResidualBlock(in_channel=128, out_channel=128))        self.c_e3 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2),                                  nn.BatchNorm2d(256), nn.LeakyReLU(0.2),                                  ResidualBlock(in_channel=256, out_channel=256))        self.c_e4 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=4, stride=2),                                  nn.BatchNorm2d(256), nn.LeakyReLU(0.2),                                  ResidualBlock(in_channel=256, out_channel=256))        self.d_e1 = nn.Sequential(            nn.ConvTranspose2d(in_channels=128, out_channels=input_channel, kernel_size=4, stride=2, padding=1), nn.Tanh())        self.d_e2 = nn.Sequential(nn.ConvTranspose2d(in_channels=256, out_channels=64, kernel_size=4, stride=2),                                  nn.BatchNorm2d(64), nn.ReLU(),)        self.d_e3 = nn.Sequential(nn.ConvTranspose2d(in_channels=512, out_channels=128, kernel_size=5, stride=2),                                  nn.BatchNorm2d(128), nn.Dropout2d(drop_rate))        self.d_e4 = nn.Sequential(nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2),                                  nn.BatchNorm2d(256), nn.Dropout2d(drop_rate))    def forward(self, x):        e1 = self.c_e1(x)        e2 = self.c_e2(e1)        e3 = self.c_e3(e2)        e4 = self.c_e4(e3)        d4 = self.d_e4(e4)        # print(d4.shape, e3.shape)        d4 = torch.cat([d4, e3], dim=1)        # d4 = d4 + e3        d3 = self.d_e3(d4)        # print(d3.shape, e2.shape)        d3 = torch.cat([d3, e2], dim=1)        # d3 = d3 + e2        d2 = self.d_e2(d3)        # print(d2.shape, e1.shape)        d2 = torch.cat([d2, e1], dim=1)        # d2 = d2 + e1        # print(d2.shape)        d1 = self.d_e1(d2)        # print(d1.shape)        return d1class Discriminator(nn.Module):    def __init__(self, input_size):        super(Discriminator, self).__init__()        strides = [2, 2, 2, 2, 1]        padding = [0, 0, 0, 0, 0]        channels = [input_size,                    64, 128, 256, 256, 1]  # 1表示一维        kernels = [5, 5, 5, 5, 3]        model = []        for i, stride in enumerate(strides):            model.append(                nn.Conv2d(                    in_channels=channels[i],                    out_channels=channels[i + 1],                    stride=stride,                    kernel_size=kernels[i],                    padding=padding[i]                )            )            model.append(                nn.LeakyReLU(0.2)            )        # self.fc = nn.Sequential(        #     nn.Linear(9, 1),        #     # nn.Sigmoid()        # )        self.main = nn.Sequential(*model)    def forward(self, x):        x = self.main(x)        return x #.view(x.shape[0], -1)        # return self.fc(x)def gp_loss(D, real_x, fake_x, cuda=False):    if cuda:        alpha = torch.rand((real_x.shape[0], 1, 1, 1)).cuda()    else:        alpha = torch.rand((real_x.shape[0], 1, 1, 1))    x_ = (alpha * real_x + (1 - alpha) * fake_x).requires_grad_(True)    y_ = D(x_)    # cal f'(x)    grad = autograd.grad(        outputs=y_,        inputs=x_,        grad_outputs=torch.ones_like(y_),        create_graph=True,        retain_graph=True,        only_inputs=True,    )[0]    grad = grad.view(x_.shape[0], -1)    gp = ((grad.norm(2, dim=1) - 1) ** 2).mean()    return gpif __name__ == '__main__':    drop_rate = 0.5    G = Generator(1, drop_rate)    path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'    dataset = MyDataset(path=path, resize=96, Len=10, img_type='jpg', sketch=False)    train_loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)    D = Discriminator(1)    rs = ResidualBlock(1, 1, stride=1)  # only for stride 1    for step, x in enumerate(train_loader):        print(x.shape)        print(G(x).shape)        print(D(x).shape)        print(rs(x).shape)        break

dataloader.py

import torch.utils.data as dataimport globimport osimport torchvision.transforms as transformsfrom PIL import Imageimport matplotlib.pyplot as pltimport numpy as npimport torchimport piexifimport imghdrimport numberstry:    import accimageexcept ImportError:    accimage = Noneclass MyCrop(object):    """Crops the given PIL Image at the center.    Args:        size (sequence or int): Desired output size of the crop. If size is an            int instead of sequence like (h, w), a square crop (size, size) is            made.    """    def __init__(self, i, j, size):        if isinstance(size, numbers.Number):            self.size = (int(size), int(size))        else:            self.size = size        self.i, self.j = i, j    def __call__(self, img):        """        Args:            img (PIL Image): Image to be cropped.        Returns:            PIL Image: Cropped image.        """        if not isinstance(img, Image.Image):            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))        th, tw = self.size        return img.crop((self.j, self.i, self.j + tw, self.i + th))    def __repr__(self):        return self.__class__.__name__ + '(size={0})'.format(self.size)class MyDataset(data.Dataset):    def __init__(self, path, Train=True, Len=-1, resize=-1, img_type='png', remove_exif=False, sketch=True, default=False):        if resize != -1:            if default:                transform = transforms.Compose([                    transforms.Resize(resize),                    transforms.CenterCrop(resize),                    transforms.ToTensor(),                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))                ])            elif sketch:                transform = transforms.Compose([                    transforms.Resize(resize),                    MyCrop(30, 0, resize),                    transforms.ToTensor(),                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))                    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))                ])            else:                transform = transforms.Compose([                    transforms.Resize(resize+20),                    MyCrop(15, 26, resize),                    transforms.ToTensor(),                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))                    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))                ])        else:            transform = transforms.Compose([                transforms.ToTensor(),            ])        img_format = '*.%s' % img_type        if remove_exif:            for name in glob.glob(os.path.join(path, img_format)):                try:                    piexif.remove(name)  # 去除exif                except Exception:                    continue        # imghdr.what(img_path) 判断是否为损坏图片        if Len == -1:            self.dataset = [np.array(transform(Image.open(name).convert("L"))) for name in                            glob.glob(os.path.join(path, img_format)) if imghdr.what(name)]        else:            self.dataset = [np.array(transform(Image.open(name).convert("L"))) for name in                            glob.glob(os.path.join(path, img_format))[:Len] if imghdr.what(name)]        self.dataset = np.array(self.dataset)        self.dataset = torch.Tensor(self.dataset)        self.Train = Train    def __len__(self):        return len(self.dataset)    def __getitem__(self, idx):        return self.dataset[idx]if __name__ == '__main__':    path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'    dataset = MyDataset(path=path, resize=96, Len=10, img_type='jpg', sketch=True)    print(len(dataset))    plt.imshow(np.squeeze(dataset[0].numpy()) * 0.5 + 0.5, cmap='gray')    plt.show()    print(dataset[0].max(), dataset[0].min())    path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'    dataset = MyDataset(path=path, resize=96, Len=10, img_type='jpg', sketch=False)    print(len(dataset))    plt.imshow(np.squeeze(dataset[0].numpy()) * 0.5 + 0.5, cmap='gray')    plt.show()    print(dataset[0].max(), dataset[0].min())

judge.py

基于现实的数据部分,如果你代码文件夹下没有其他的.jpg文件作为输入的话,可能没有结果哦~

import numpy as npimport torchimport matplotlib.pyplot as pltfrom model import Generator, Discriminatorfrom dataloader import MyDatasetfrom torch.utils.data import Dataset, DataLoaderimport itertoolsimport torchvision.utils as vutilsif __name__ == '__main__':    BATCH_SIZE = 5    N_IDEAS = 100    img_shape = (1, 28, 28)    GA = torch.load("GA.pkl").cuda()    GB = torch.load("GB.pkl").cuda()    path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_photo\photo'    dataset_photo = MyDataset(path=path, resize=96, Len=10, img_type='jpg', sketch=False)    path = r'D:\Software\DataSet\CUHK_Face_Sketch\CUHK_training_sketch\sketch'    dataset_sketch = MyDataset(path=path, resize=96, Len=10, img_type='jpg', sketch=True)    photo_loader = DataLoader(dataset=dataset_photo, batch_size=BATCH_SIZE, shuffle=True)    sketch_loader = DataLoader(dataset=dataset_sketch, batch_size=BATCH_SIZE, shuffle=True)    for step, (v, u) in enumerate(itertools.zip_longest(photo_loader, sketch_loader)):        u = u.cuda()        v = v.cuda()        ga_u = GA(u)        gb_v = GB(v)        for i in range(BATCH_SIZE):            fig = plt.figure(figsize=(10, 10))            plt.axis("off")            plt.imshow(np.transpose(                vutils.make_grid(torch.stack([ga_u[i].cpu().detach(), u[i].cpu().detach(),                                              v[i].cpu().detach(), gb_v[i].cpu().detach()]), nrow=2, padding=0,                                 normalize=True, scale_each=True), (1, 2, 0)), cmap='gray')            plt.savefig(str(i) + '.png', dpi=300)            plt.show()        break    path = r'.'    dataset_real = MyDataset(path=path, resize=96, Len=10, img_type='jpg', default=True)    real_img = DataLoader(dataset=dataset_real, batch_size=BATCH_SIZE, shuffle=True)    for step, v in enumerate(real_img):        v = v.cuda()        gb_v = GB(v).cpu().detach()        for i in range(gb_v.shape[0]):            plt.imshow(np.squeeze(gb_v[i].numpy()), cmap='gray')            plt.savefig('real%d.png' % i, dpi=300)            plt.show()

187de1cd494e80405e89e28e3be203ae.png

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值