Deep Learning in Practice: Generative Adversarial Networks (GANs)

Goal of a GAN: generate fake data that is very close to the real data.

Generator: produces the fake data.

Discriminator: tells real data apart from fake data.

Training idea: keep having the generator produce fake data, but label that fake data as real when updating the generator, so the generator learns to pass its fakes off as the real thing. Meanwhile, keep feeding the discriminator real data labeled as real and fake data labeled as fake, so the discriminator learns to call real data real and fake data fake.
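
In loss terms (my own summary of the idea above, using the same BCE criterion as the code below): the discriminator is trained to minimize

    loss_D = BCE(D(x_real), 1) + BCE(D(G(z)), 0)

and the generator is trained to minimize

    loss_G = BCE(D(G(z)), 1)

i.e. the generator's fake images are scored against the "real" label, which is exactly the "label the fakes as real" trick described above.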

Enough talk, here is the code. Read through it carefully first and then revisit the explanation above; if anything is unclear, feel free to ask in the comments.

import numpy as np
import torch as t
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torch.autograd import Variable
from torchvision import transforms
import ipdb
import tqdm
import fire
import os
import visdom
from PIL import Image
from torchnet.meter import AverageValueMeter

#first step: build your own dataset class and Config class
class Config(object):
        path = '/home/szh/DCGAN/data/faces'
        imgsize = 96
        batch_size = 2048
        max_epoch = 200
        drop_last=True
        num_workers = 4
        generator_model_path = '/home/szh/checkpoints/mygenerator_epoch180.pth'
        discriminator_model_path = '/home/szh/checkpoints/mydiscriminator_epoch180.pth'

opt = Config()

class MyDataset(Dataset):
        def __init__(self,root):
                self.imgs = [os.path.join(root,img) for img in os.listdir(root)]
                #build the transform once here instead of on every __getitem__ call
                self.transforms = transforms.Compose([
                        transforms.Resize(opt.imgsize),
                        transforms.CenterCrop(opt.imgsize),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
                        ])
        def __getitem__(self,index):
                img = Image.open(self.imgs[index])
                return self.transforms(img)
        def __len__(self):
                return len(self.imgs)
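
#quick sanity check (my addition, not in the original post): with RGB face images the
#dataset should yield 3x96x96 tensors normalized to [-1,1], e.g.
#    ds = MyDataset(opt.path)
#    print(len(ds), ds[0].shape)   #expect torch.Size([3, 96, 96])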




#define your generator class
class Generator(nn.Module):
        def __init__(self):
                super(Generator,self).__init__()
                #Hout = (Hin-1)*stride-2*padding+kernel_size
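                #with that formula, a 100x1x1 noise vector grows through the layers below as
                #(my annotation): 100x1x1 -> 512x4x4 -> 256x8x8 -> 128x16x16 -> 64x32x32 -> 3x96x96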
                self.net = nn.Sequential(
                        nn.ConvTranspose2d(100,64*8,4,1,0,bias=False),
                        nn.BatchNorm2d(64*8),
                        nn.ReLU(inplace=True),

                        nn.ConvTranspose2d(64*8,64*4,4,2,1,bias=False),
                        nn.BatchNorm2d(64*4),
                        nn.ReLU(inplace=True),

                        nn.ConvTranspose2d(64*4,64*2,4,2,1,bias=False),
                        nn.BatchNorm2d(64*2),
                        nn.ReLU(inplace=True),

                        nn.ConvTranspose2d(64*2,64,4,2,1,bias=False),
                        nn.BatchNorm2d(64),
                        nn.ReLU(inplace=True),

                        nn.ConvTranspose2d(64,3,5,3,1,bias=False),
                        nn.Tanh()
                        )

        def forward(self,x):
                return self.net(x)

mygenerator = nn.DataParallel(Generator().cuda(),device_ids=[0,1,2,3])

#define your discriminator class
class Discriminator(nn.Module):
        def __init__(self):
                super(Discriminator,self).__init__()
                #Hout = floor((Hin+2*padding-kernel_size)/stride)+1
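                #with that formula, a 3x96x96 image shrinks through the layers below as
                #(my annotation): 3x96x96 -> 64x32x32 -> 128x16x16 -> 256x8x8 -> 512x4x4 -> 1x1x1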
                self.net = nn.Sequential(
                        nn.Conv2d(3,64,5,3,1,bias=False),
                        nn.BatchNorm2d(64),
                        nn.ReLU(inplace=True),

                        nn.Conv2d(64,64*2,4,2,1,bias=False),
                        nn.BatchNorm2d(64*2),
                        nn.ReLU(inplace=True),

                        nn.Conv2d(64*2,64*4,4,2,1,bias=False),
                        nn.BatchNorm2d(64*4),
                        nn.ReLU(inplace=True),

                        nn.Conv2d(64*4,64*8,4,2,1,bias=False),
                        nn.BatchNorm2d(64*8),
                        nn.ReLU(inplace=True),

                        nn.Conv2d(64*8,1,4,1,0,bias=False),
                        nn.Sigmoid())

        def forward(self,x):
                return self.net(x).view(-1)

mydiscriminator = nn.DataParallel(Discriminator().cuda(),device_ids=[0,1,2,3])
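
#note (my addition, not in the original post): device_ids=[0,1,2,3] assumes a 4-GPU machine.
#On a single GPU you could simply use
#    mygenerator = Generator().cuda()
#    mydiscriminator = Discriminator().cuda()
#but then the saved checkpoints would lack the 'module.' key prefix that nn.DataParallel adds,
#so generate() below would also have to load them without the DataParallel wrapper.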

#train
def train(**kwargs):
        #parse your hyperparameters
        for k_,v_ in kwargs.items():
                setattr(opt,k_,v_)
        #load your data
        mydataset = MyDataset(opt.path)
        mydataloader = DataLoader(mydataset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.num_workers,drop_last=opt.drop_last)
        #initialization of visualization
        vis = visdom.Visdom(env='szh')
        loss_gmeter = AverageValueMeter()
        loss_dmeter = AverageValueMeter()
        x_value = 0
        #real labels, fake labels and the noise input
        true_labels = Variable(t.ones(opt.batch_size))
        false_labels = Variable(t.zeros(opt.batch_size))
        noises = Variable(t.randn(opt.batch_size,100,1,1))
        #define your optimizer and loss function
        generator_optimizer = t.optim.Adam(mygenerator.parameters(),lr=2e-4,betas=(0.5,0.999))
        discriminator_optimizer = t.optim.Adam(mydiscriminator.parameters(),lr=2e-4,betas=(0.5,0.999))
        criterion = nn.BCELoss()
        #use gpu
        if t.cuda.is_available():
                mygenerator.cuda()
                mydiscriminator.cuda()
                criterion.cuda()
                true_labels,false_labels = true_labels.cuda(),false_labels.cuda()
                noises = noises.cuda()
        #start training
        for i,epoch in enumerate(tqdm.tqdm(range(opt.max_epoch))):
                for ii,x in enumerate(tqdm.tqdm(mydataloader)):
                        #train the discriminator on every batch
                        discriminator_optimizer.zero_grad()
                        output = mydiscriminator(Variable(x))
                        loss_real = criterion(output,true_labels)
                        loss_real.backward()
                        gen_img = mygenerator(Variable(t.randn(opt.batch_size,100,1,1).cuda()))
                        output = mydiscriminator(gen_img)
                        loss_false = criterion(output,false_labels)
                        loss_false.backward()
                        discriminator_optimizer.step()
                        loss = loss_real + loss_false
                        loss_dmeter.add(loss.data[0])

                        #train the generator once every 5 batches
                        if ii%5==0:
                                generator_optimizer.zero_grad()
                                gen_img = mygenerator(Variable(t.randn(opt.batch_size,100,1,1).cuda()))
                                output = mydiscriminator(gen_img)
                                loss_ = criterion(output,true_labels)
                                loss_.backward()
                                generator_optimizer.step()
                                loss_gmeter.add(loss_.data[0])
                        if ii%20==0:
                                vis.line(Y=np.array([loss_gmeter.value()[0]]), X=np.array([x_value]),
                                    win=('g_loss'),
                                    opts=dict(title='g_loss'),
                                    update=None if x_value == 0 else 'append'
                                    )
                                vis.line(Y=np.array([loss_dmeter.value()[0]]), X=np.array([x_value]),
                                    win=('d_loss'),
                                    opts=dict(title='d_loss'),
                                    update=None if x_value == 0 else 'append'
                                    )
                                x_value += 1
                #every 20 epochs, visualize the latest fake and real images and save the models
                if i%20 == 0:
                        vis.images(gen_img.data.cpu().numpy()[:64]*0.5+0.5,win='fake')
                        vis.images(x.cpu().numpy()[:64]*0.5+0.5,win='real')
                        if not os.path.exists('checkpoints'):   #make sure the save directory exists
                                os.makedirs('checkpoints')
                        t.save(mygenerator.state_dict(),'checkpoints/mygenerator_epoch%s.pth'%epoch)
                        t.save(mydiscriminator.state_dict(),'checkpoints/mydiscriminator_epoch%s.pth'%epoch)

def generate():
        vis = visdom.Visdom(env='szh')
        map_location = lambda storage,loc:storage
        #the models were saved while wrapped in nn.DataParallel, so wrap them the same way here
        #before calling load_state_dict, otherwise the 'module.' key prefix won't match and an exception is thrown
        testgenerator = nn.DataParallel(Generator().eval().cuda(),device_ids=[0,1,2,3])
        testdiscriminator = nn.DataParallel(Discriminator().eval().cuda(),device_ids=[0,1,2,3])
        #load model state
        testgenerator.load_state_dict(t.load(opt.generator_model_path,map_location=map_location))
        testdiscriminator.load_state_dict(t.load(opt.discriminator_model_path,map_location=map_location))
        #generate 100 images from 100 random noise vectors
        noises = Variable(t.randn(100,100,1,1)).cuda()
        gen_img = testgenerator(noises)
        output = testdiscriminator(gen_img)
        #get the indices of the 10 images the discriminator scores highest
        indexs = output.data.topk(10)[1]
        results = []
        for index in indexs:
                results.append(gen_img.data[index])
        vis.images(t.stack(results).cpu().numpy()*0.5+0.5,win='fake')
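        #optional (my addition, not in the original post): instead of pushing the picked images
        #to visdom, they could also be written to disk with torchvision's save_image, e.g.
        #    from torchvision.utils import save_image
        #    save_image(t.stack(results).cpu()*0.5+0.5, 'top10.png', nrow=5)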


if __name__ == '__main__':
        fire.Fire()


#how to run
'''
execute in the current window:
python -m visdom.server
execute in another window if you want to train:
python mygan.py train --path=[your image folder] --batch-size=[your batch_size]
execute in another window if you want to test:
python mygan.py generate
'''

