入坑生成对抗网络(GAN)

37 篇文章 31 订阅
3 篇文章 1 订阅

看了以后感觉还行的关于生成对抗网络的一个GitHub项目,安利给大家
链接:https://github.com/kwotsin/mimicry
文档:https://mimicry.readthedocs.io/en/latest/guides/introduction.html
比较不错的参考链接:https://www.cnblogs.com/wanghui-garcia/p/10785579.html
这个项目时以python包的形式发布了的,直接可以用pip安装pip install torch-mimicry
安装完了就可以用下面的代码开始训练了,简直不要太容易。默认是训练cifar10数据的
代码官网有,感觉很简单,先放一下,后面更新可能会改
train.py

import torch
import torch.optim as optim
import torch_mimicry as mmc
from torch_mimicry.nets import sngan

# Data handling objects
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True, num_workers=0)

# Define models and optimizers
netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))

# Start training
trainer = mmc.training.Trainer(
    netD=netD,
    netG=netG,
    optD=optD,
    optG=optG,
    n_dis=5,
    num_steps=100000,
    lr_decay='linear',
    dataloader=dataloader,
    log_dir='./log/example',
    device=device)
trainer.train()

# Evaluate fid
mmc.metrics.evaluate(
    metric='fid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_real_samples=50000,
    num_fake_samples=50000,
    evaluate_step=100000,
    device=device)

# Evaluate kid
mmc.metrics.evaluate(
    metric='kid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_subsets=50,
    subset_size=1000,
    evaluate_step=100000,        
    device=device)

# Evaluate inception score
mmc.metrics.evaluate(
    metric='inception_score',
    log_dir='./log/example',
    netG=netG,
    num_samples=50000,
    evaluate_step=100000,        
    device=device)

这里会有一个bug,就是训练到50次的时候回提示一个缺少路径的bug,缺少红线标的errD还是errG我忘了,最好还是手动新建一下吧
错误说明

接下来根据以上第三个链接给出的练习,我试着训练链接里提供的卡通头像,训练脚本需要改的地方其实不多,只需要改数据读入的地方就行了,更改后的代码如下:

import torch
import torch.optim as optim
import torch_mimicry as mmc
from torch_mimicry.nets import sngan
import torchvision as tv

# Data handling objects
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
image_size = 96
data_path = "./own_data"  #这里注意一下,里面放的是一个文件夹face,即./own_data/face,face里面就是所有图像了
batch_size = 256
num_workers = 0
transforms = tv.transforms.Compose([
        tv.transforms.Resize(image_size),
        tv.transforms.CenterCrop(image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

dataset = tv.datasets.ImageFolder(data_path, transform=transforms)
dataloader = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=num_workers,
                                     drop_last=True
                                     )


# Define models and optimizers
netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))

# Start training
trainer = mmc.training.Trainer(
    netD=netD,
    netG=netG,
    optD=optD,
    optG=optG,
    n_dis=5,
    num_steps=100000,
    lr_decay='linear',
    dataloader=dataloader,
    log_dir='./log/example',
    device=device)
trainer.train()

# Evaluate fid
mmc.metrics.evaluate(
    metric='fid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_real_samples=50000,
    num_fake_samples=50000,
    evaluate_step=100000,
    device=device)

# Evaluate kid
mmc.metrics.evaluate(
    metric='kid',
    log_dir='./log/example',
    netG=netG,
    dataset_name='cifar10',
    num_subsets=50,
    subset_size=1000,
    evaluate_step=100000,        
    device=device)

# Evaluate inception score
mmc.metrics.evaluate(
    metric='inception_score',
    log_dir='./log/example',
    netG=netG,
    num_samples=50000,
    evaluate_step=100000,        
    device=device)

感觉效果不会很好的样子,未完待续。。。

中间训练过程出现了问题,有参数一直不变,有参数就一直在很小范围内波动,求助啊,希望有人指导一下啊,我训练放了两个类别,一个是人脸,一个是卡通脸
训练过程
数据存放目录
数据结构

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

如雾如电

随缘

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值