pix2pixHD训练自己的数据集(win10)
一、环境要求
- Linux or macOS or win10(me win10+anaconda)
- Python 2 or 3(me python3.6)
- NVIDIA GPU (11G memory or larger) + CUDA cuDNN(me cuda9.2,12G内存的GPU可以训练1024x512的图像,无法训练或测试2048x1024的图像)
二、环境配置
- 安装cuda和pytorch
- 安装dominate:pip install dominate
- 安装其他Python包,缺什么装什么。
三、源码下载
- 下载地址:https://github.com/NVIDIA/pix2pixHD
- 下载方法1:github直接下载zip,下载速度慢
- 下载方法2:借助gitee(https://gitee.com)平台下载(需要注册),可恢复到正常下载速度。做法是先通过gitee页面右上角的“+ → 从GitHub导入仓库”功能把代码仓库导入到gitee平台,再从gitee下载。
- 论文地址:High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs
四、制作自己的数据集
1.图像输入尺寸:
起初我的图像大小不统一,训练时出现张量不匹配的错误,之后,我的数据集在训练之前统一缩放到了1024x512。预处理resize_or_crop的默认设置为scale_width,loadSize默认设置为1024,也就是在保持宽高比的同时将所有训练图像的宽度缩放为(1024)。如果您想要其他设置,请使用--resize_or_crop选项进行更改。例如,将resize_or_crop设置为scale_width_and_crop,就是首先将图像调整为具有宽度opt.loadSize,然后随机裁剪size (opt.fineSize, opt.fineSize)。--resize_or_crop设置为crop时仅执行随机裁剪。如果您不希望进行任何预处理,请指定none,除了确保图像可以被32整除之外,它不会做任何其他事情。
2.带有instance的数据集:
带有instance的数据集需要生成标签映射,该标签映射是单通道的,其像素值对应于对象标签(即0,1,…,N-1,其中N是标签的数量)。这是因为需要从标签图生成one-hot(独热)向量。具体数据集可以参考代码中datasets/cityscapes示例数据集制作。
3.不带instance的数据集:
直接使用RGB图像作为输入的内容即可,数据集应包括train_A,train_B,test_A和test_B四个文件夹,训练从A到B的一种映射关系,A和B中的图像文件名要一一对应,如train_A标签图000000.jpg对应真值图train_B中的000000.jpg.
五、训练和测试(me ,no_instance,win10)
Windows系统没有sh文件,所以直接用Python训练测试。
官方程序在train.py和test.py文件的main函数外开启了多线程,而windows下Python开启多线程必须在内置main函数中开启,因此train.py和test.py需要修改,否则报如下错误:
RuntimeError:
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.
This probably means that you are not using fork to start your
child processes and you have forgotten to use the proper idiom
in the main module:
if __name__ == '__main__':
freeze_support()
...
The "freeze_support()" line can be omitted if the program
is not going to be frozen to produce an executable.
1.训练程序train.py
import time
import os
import numpy as np
import torch
from torch.autograd import Variable
from collections import OrderedDict
from subprocess import call
import fractions
def lcm(a,b): return abs(a * b)/fractions.gcd(a,b) if a and b else 0
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
def train():
    """Full pix2pixHD training loop.

    Wrapped in a function (with the ``__main__`` guard below) so that the
    DataLoader worker processes can be spawned safely on Windows, where
    multiprocessing re-imports the main module.
    """
    opt = TrainOptions().parse()
    # File that persists (epoch, iteration) so --continue_train can resume.
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
        except:  # NOTE(review): bare except — any read/parse failure silently restarts from scratch
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    # Make print_freq a multiple of the batch size so the modulo checks below line up.
    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    if opt.debug:
        # Tiny configuration for smoke-testing the pipeline.
        opt.display_freq = 1
        opt.print_freq = 1
        opt.niter = 1
        opt.niter_decay = 0
        opt.max_dataset_size = 10

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)
    if opt.fp16:
        # Mixed-precision training (requires NVIDIA apex to be installed).
        from apex import amp
        model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.optimizer_G, model.optimizer_D],
                                                           opt_level='O1')
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    else:
        optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # Phase offsets so periodic actions keep firing on schedule after a resume.
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                      Variable(data['image']), Variable(data['feat']), infer=save_fake)

            # sum per device losses (ints are placeholders for disabled loss terms)
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            optimizer_G.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                # call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ### display output images
            if save_fake:
                visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                       ('synthesized_image', util.tensor2im(generated.data[0])),
                                       ('real_image', util.tensor2im(data['image'][0]))])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        iter_end_time = time.time()
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()


if __name__=='__main__':
    # Entry point must sit under the __main__ guard so Windows can spawn
    # DataLoader workers without re-running the training code on import.
    train()
2.训练指令
python train.py --name project1024 --label_nc 0 --no_instance --gpu_ids 0
3.测试程序test.py
import os
from collections import OrderedDict
from torch.autograd import Variable
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
import torch
def test():
    """pix2pixHD inference: run the generator over the test set and write
    an HTML results page under opt.results_dir.

    Wrapped in a function (with the ``__main__`` guard below) so DataLoader
    workers can be spawned on Windows.
    """
    opt = TestOptions().parse(save=False)
    opt.nThreads = 1   # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)

    # create website
    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

    # test
    if not opt.engine and not opt.onnx:
        model = create_model(opt)
        if opt.data_type == 16:
            model.half()  # fp16 inference
        elif opt.data_type == 8:
            model.type(torch.uint8)
        if opt.verbose:
            print(model)
    else:
        # TensorRT / ONNX inference paths
        from run_engine import run_trt_engine, run_onnx

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        # cast the inputs to match the model precision chosen above
        if opt.data_type == 16:
            data['label'] = data['label'].half()
            data['inst'] = data['inst'].half()
        elif opt.data_type == 8:
            data['label'] = data['label'].uint8()
            data['inst'] = data['inst'].uint8()
        if opt.export_onnx:
            print("Exporting to ONNX: ", opt.export_onnx)
            assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
            torch.onnx.export(model, [data['label'], data['inst']],
                              opt.export_onnx, verbose=True)
            exit(0)
        minibatch = 1
        if opt.engine:
            generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
        elif opt.onnx:
            generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
        else:
            generated = model.inference(data['label'], data['inst'], data['image'])

        visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                               ('synthesized_image', util.tensor2im(generated.data[0]))])
        img_path = data['path']
        print('process image... %s' % img_path)
        visualizer.save_images(webpage, visuals, img_path)

    webpage.save()


if __name__=='__main__':
    test()
4.测试指令:
python test.py --name project1024 --ngf 64 --label_nc 0 --no_instance --how_many 464
图像输入尺寸不同时需要修改ngf大小,how_many是测试图像数量,默认是50。
5.测试结果显示
官方提供的测试程序会在web文件夹中生成images文件夹和index.html,在images文件夹中存了对应的测试图像和生成图像,在index.html文件中会将images文件夹中的图像以表格的形式显示,遗憾的是缺少真值图像。测试完成之后可以用如下程序在html文件以label-syn-real的形式显示,类似如下形式显示。
程序Img2Html.py:
import dominate
from dominate.tags import *
import os
import argparse
import glob
# Command-line configuration shared by the HTML class and Img2HtmlForm below.
# Paths are relative to the results directory produced by test.py.
parse=argparse.ArgumentParser()
parse.add_argument('--test_B',type=str,default='web/test_B',help='Path to the test_B folder')
parse.add_argument('--ima', dest='ima', type=str, default='web/images', help='Path to the images folder')
parse.add_argument('--outDir',type=str,default='./',help='Path to the output folder')
args=parse.parse_args()
class HTML:
    """Accumulate an HTML document of image tables via ``dominate`` and
    write it to disk.

    Images captioned 'real' are resolved against the ground-truth folder
    ``args.test_B``; all other images against the results folder ``args.ima``.
    """

    def __init__(self, title, refresh=0):
        """Create the document; refresh > 0 adds a meta auto-refresh tag
        with that many seconds."""
        self.title = title
        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def add_header(self, str):
        # NOTE(review): the parameter shadows the builtin ``str``; name kept
        # so any keyword callers remain compatible.
        """Append an <h3> header with the given text."""
        with self.doc:
            h3(str)

    def add_table(self, border=1):
        """Start a new fixed-layout table and remember it on self.t."""
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=512):
        """Append one row of images with captions and hyperlinks.

        ims/txts/links are parallel lists: file name, caption, link target.
        Each call creates its own one-row table.
        """
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            if txt == 'real':
                                # ground-truth images live in the test_B folder
                                with a(href=os.path.join(args.test_B, link)):
                                    img(style="width:%dpx" % (width), src=os.path.join(args.test_B, im))
                            else:
                                with a(href=os.path.join(args.ima, link)):
                                    img(style="width:%dpx" % (width), src=os.path.join(args.ima, im))
                            br()
                            p(txt)

    def save(self, outDir):
        """Render the document to <outDir>/testResult.html."""
        # html_file = '%s/index.html' % self.web_dir
        html_file = outDir + '/testResult.html'
        # Context manager guarantees the handle is closed even if
        # render()/write() raises (the original leaked it on exception).
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())
def getJpgFile(pathDir):
    """Return the list of *.jpg file paths directly under pathDir.

    Returns an empty list when the directory does not exist (the original
    returned the string '' here; [] keeps the same falsy / len()==0
    behavior for the caller while giving a consistent list return type).
    """
    if os.path.exists(pathDir):
        return glob.glob(os.path.join(pathDir, '*.jpg'))
    print('{0}Jpg图像文件夹不存在'.format(pathDir))
    return []
def Img2HtmlForm(images):
    """Build testResult.html showing each test case as a label /
    synthesized / real image triple, one row per input label image."""
    html = HTML('test_html')
    html.add_header('test result')
    # Collect every .jpg file under the given images folder
    jpgFiles = getJpgFile(images)
    if len(jpgFiles) > 0:
        for jpgFile in jpgFiles:
            ims = []
            txts = []
            links = []
            # Only the *_input_label images drive a row; the synthesized and
            # real file names are derived from the same numeric prefix.
            if 'input' in jpgFile:
                jpgFile = os.path.split(jpgFile)[1]
                label=jpgFile
                ims.append(label)
                txts.append('label')
                links.append(label)
                # e.g. 000000_input_label.jpg -> 000000_synthesized_image.jpg
                syn=jpgFile.split('_')[0]+'_synthesized_image'+'.jpg'
                ims.append(syn)
                txts.append('synthesized')
                links.append(syn)
                # e.g. 000000_input_label.jpg -> 000000.jpg (ground truth in test_B)
                real=jpgFile.split('_')[0]+'.jpg'
                ims.append(real)
                txts.append('real')
                links.append(real)
                html.add_images(ims, txts, links)
    else:
        print('{0}下无Jpg图像'.format(images))
    html.save(args.outDir)


if __name__ == '__main__':
    Img2HtmlForm(args.ima)
6.训练和测试的其他参数
参考options文件夹下的py文件中的参数设置。