看了以后感觉还行的关于生成对抗网络的一个GitHub项目,安利给大家
链接:https://github.com/kwotsin/mimicry
文档:https://mimicry.readthedocs.io/en/latest/guides/introduction.html
比较不错的参考链接:https://www.cnblogs.com/wanghui-garcia/p/10785579.html
这个项目时以python包的形式发布了的,直接可以用pip安装pip install torch-mimicry
安装完了就可以用下面的代码开始训练了,简直不要太容易。默认是训练cifar10数据的
代码官网有,感觉很简单,先放一下,后面更新可能会改
train.py
import torch
import torch.optim as optim
import torch_mimicry as mmc
from torch_mimicry.nets import sngan
# Data handling objects
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=64, shuffle=True, num_workers=0)
# Define models and optimizers
netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
# Start training
trainer = mmc.training.Trainer(
netD=netD,
netG=netG,
optD=optD,
optG=optG,
n_dis=5,
num_steps=100000,
lr_decay='linear',
dataloader=dataloader,
log_dir='./log/example',
device=device)
trainer.train()
# Evaluate fid
mmc.metrics.evaluate(
metric='fid',
log_dir='./log/example',
netG=netG,
dataset_name='cifar10',
num_real_samples=50000,
num_fake_samples=50000,
evaluate_step=100000,
device=device)
# Evaluate kid
mmc.metrics.evaluate(
metric='kid',
log_dir='./log/example',
netG=netG,
dataset_name='cifar10',
num_subsets=50,
subset_size=1000,
evaluate_step=100000,
device=device)
# Evaluate inception score
mmc.metrics.evaluate(
metric='inception_score',
log_dir='./log/example',
netG=netG,
num_samples=50000,
evaluate_step=100000,
device=device)
这里会有一个bug,就是训练到50次的时候回提示一个缺少路径的bug,缺少红线标的errD还是errG我忘了,最好还是手动新建一下吧
接下来根据以上第三个链接给出的练习,我试着训练链接里提供的卡通头像,训练脚本需要改的地方其实不多,只需要改数据读入的地方就行了,更改后的代码如下:
import torch
import torch.optim as optim
import torch_mimicry as mmc
from torch_mimicry.nets import sngan
import torchvision as tv
# Data handling objects
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
image_size = 96
data_path = "./own_data" #这里注意一下,里面放的是一个文件夹face,即./own_data/face,face里面就是所有图像了
batch_size = 256
num_workers = 0
transforms = tv.transforms.Compose([
tv.transforms.Resize(image_size),
tv.transforms.CenterCrop(image_size),
tv.transforms.ToTensor(),
tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = tv.datasets.ImageFolder(data_path, transform=transforms)
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True
)
# Define models and optimizers
netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
# Start training
trainer = mmc.training.Trainer(
netD=netD,
netG=netG,
optD=optD,
optG=optG,
n_dis=5,
num_steps=100000,
lr_decay='linear',
dataloader=dataloader,
log_dir='./log/example',
device=device)
trainer.train()
# Evaluate fid
mmc.metrics.evaluate(
metric='fid',
log_dir='./log/example',
netG=netG,
dataset_name='cifar10',
num_real_samples=50000,
num_fake_samples=50000,
evaluate_step=100000,
device=device)
# Evaluate kid
mmc.metrics.evaluate(
metric='kid',
log_dir='./log/example',
netG=netG,
dataset_name='cifar10',
num_subsets=50,
subset_size=1000,
evaluate_step=100000,
device=device)
# Evaluate inception score
mmc.metrics.evaluate(
metric='inception_score',
log_dir='./log/example',
netG=netG,
num_samples=50000,
evaluate_step=100000,
device=device)
感觉效果不会很好的样子,未完待续。。。
中间训练过程出现了问题,有参数一直不变,有参数就一直在很小范围内波动,求助啊,希望有人指导一下啊,我训练放了两个类别,一个是人脸,一个是卡通脸
数据存放目录