Training pix2pixHD on Your Own Dataset (Windows 10)

I. Environment Requirements

  • Linux or macOS or Windows 10 (mine: Windows 10 + Anaconda)
  • Python 2 or 3 (mine: Python 3.6)
  • NVIDIA GPU (11 GB of memory or more) + CUDA cuDNN (mine: CUDA 9.2; a 12 GB GPU can train on 1024x512 images but cannot train or test on 2048x1024 images)

II. Environment Setup

  • Install CUDA and PyTorch
  • Install dominate: pip install dominate
  • Install any other Python packages as they turn out to be missing; a typical sequence is sketched below.
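For reference, a minimal setup under Anaconda might look like the following (a sketch only; the cudatoolkit build must match your driver, and the exact versions are assumptions):

conda install pytorch torchvision cudatoolkit=9.2 -c pytorch
pip install dominate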

III. Source Code Download
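The code is the official NVIDIA implementation; clone or download it from https://github.com/NVIDIA/pix2pixHD.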

IV. Preparing Your Own Dataset

1. Image input size:

  At first my images were of inconsistent sizes and training failed with a tensor-shape mismatch, so I resized the whole dataset to 1024x512 before training. The preprocessing option resize_or_crop defaults to scale_width, and loadSize defaults to 1024, i.e. every training image is scaled to a width of 1024 while its aspect ratio is preserved. To change this behavior, use the --resize_or_crop option. For example, scale_width_and_crop first resizes the image to width opt.loadSize and then randomly crops a patch of size (opt.fineSize, opt.fineSize), while crop performs only the random crop. If you want no preprocessing at all, specify none, which does nothing except ensure the image dimensions are divisible by 32.
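For example, a minimal sketch of this kind of pre-resize step using Pillow (the folder names are placeholders, not part of pix2pixHD):

import glob
import os
from PIL import Image

src_dir = 'datasets/project/raw'      # hypothetical folder of original images
dst_dir = 'datasets/project/train_A'  # hypothetical output folder
os.makedirs(dst_dir, exist_ok=True)

for path in glob.glob(os.path.join(src_dir, '*.jpg')):
    img = Image.open(path).convert('RGB')
    img = img.resize((1024, 512), Image.BICUBIC)  # (width, height)
    img.save(os.path.join(dst_dir, os.path.basename(path)))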

2. Datasets with instance maps:

  A dataset with instance maps also needs label maps: single-channel images whose pixel values are the object labels (i.e. 0, 1, ..., N-1, where N is the number of labels). This is required because a one-hot vector is generated from the label map. For a concrete example of the layout, see the datasets/cityscapes sample dataset shipped with the code.
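For intuition, the conversion is roughly the following (a sketch of the idea; pix2pixHD performs the equivalent scatter internally when label_nc > 0):

import torch

label_nc = 35  # e.g. the number of Cityscapes labels
# a fake N x 1 x H x W integer label map standing in for a loaded image
label_map = torch.randint(0, label_nc, (1, 1, 512, 1024))
one_hot = torch.zeros(1, label_nc, 512, 1024)
# write a 1 into the channel selected by each pixel's label value
one_hot.scatter_(1, label_map.long(), 1.0)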

3. Datasets without instance maps:

  Simply use RGB images directly as input. The dataset should contain four folders: train_A, train_B, test_A, and test_B; the network learns a mapping from A to B. File names in A and B must correspond one to one, e.g. the label image 000000.jpg in train_A pairs with the ground-truth image 000000.jpg in train_B.
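Since the loader pairs files in sorted order, a naming mismatch does not raise a clear error; a quick check such as this sketch (the paths are placeholders) catches it early:

import os

a_dir = 'datasets/project/train_A'  # hypothetical paths
b_dir = 'datasets/project/train_B'
a_files = sorted(os.listdir(a_dir))
b_files = sorted(os.listdir(b_dir))
assert a_files == b_files, 'file names in train_A and train_B do not match one to one'
print('%d paired images' % len(a_files))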

V. Training and Testing (mine: no_instance, Windows 10)

  Windows has no shell (.sh) scripts, so run training and testing directly with Python.
  The official train.py and test.py start multiprocessing (the DataLoader worker processes) at module level, but on Windows, Python multiprocessing must be launched from inside the if __name__ == '__main__': block, so both scripts need to be modified; otherwise the following error is raised:

RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
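The fix is to move the script's top-level code into a function and invoke it under the main guard, which is exactly what the modified train.py and test.py below do:

# Windows starts DataLoader workers with spawn instead of fork, so this
# module gets re-imported in every worker process; code that launches
# processes must therefore only run inside this guard:
if __name__ == '__main__':
    train()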

1. Training script train.py

import time
import os
import numpy as np
import torch
from torch.autograd import Variable
from collections import OrderedDict
from subprocess import call
import fractions
# least common multiple; fractions.gcd works on Python 3.6 but was removed in Python 3.9 (use math.gcd there)
def lcm(a, b): return abs(a * b) // fractions.gcd(a, b) if a and b else 0

from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer

def train():
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
        except:
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.niter = 1
        opt.niter_decay = 0
        opt.max_dataset_size = 10

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)
    if opt.fp16:
        from apex import amp
        model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.optimizer_G, model.optimizer_D],
                                                           opt_level='O1')
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    else:
        optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                      Variable(data['image']), Variable(data['feat']), infer=save_fake)

            # sum per device losses
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            optimizer_G.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                # call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ### display output images
            if save_fake:
                visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                       ('synthesized_image', util.tensor2im(generated.data[0])),
                                       ('real_image', util.tensor2im(data['image'][0]))])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        iter_end_time = time.time()
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()

if __name__ == '__main__':
    train()

2. Training command

python train.py --name project1024 --label_nc 0 --no_instance --gpu_ids 0
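  Here --label_nc 0 makes the model take the RGB image itself as input rather than a one-hot label map, --no_instance disables instance maps, and --name chooses the subfolder under checkpoints/ where models and logs are stored; add --continue_train to resume from the latest saved checkpoint.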

3. Testing script test.py

import os
from collections import OrderedDict
from torch.autograd import Variable
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
import torch

def test():
    opt = TestOptions().parse(save=False)
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

    # test
    if not opt.engine and not opt.onnx:
        model = create_model(opt)
        if opt.data_type == 16:
            model.half()
        elif opt.data_type == 8:
            model.type(torch.uint8)

        if opt.verbose:
            print(model)
    else:
        from run_engine import run_trt_engine, run_onnx

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        if opt.data_type == 16:
            data['label'] = data['label'].half()
            data['inst'] = data['inst'].half()
        elif opt.data_type == 8:
            data['label'] = data['label'].byte()  # tensors use .byte() for uint8; the original .uint8() call raises AttributeError
            data['inst'] = data['inst'].byte()
        if opt.export_onnx:
            print("Exporting to ONNX: ", opt.export_onnx)
            assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
            torch.onnx.export(model, [data['label'], data['inst']],
                              opt.export_onnx, verbose=True)
            exit(0)
        minibatch = 1
        if opt.engine:
            generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
        elif opt.onnx:
            generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
        else:
            generated = model.inference(data['label'], data['inst'], data['image'])

        visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                               ('synthesized_image', util.tensor2im(generated.data[0]))])
        img_path = data['path']
        print('process image... %s' % img_path)
        visualizer.save_images(webpage, visuals, img_path)

    webpage.save()

if __name__ == '__main__':
    test()

4. Testing command:

python test.py --name project1024 --ngf 64 --label_nc 0 --no_instance --how_many 464

  --ngf (the number of generator filters) must match the value the model was trained with (the default is 64), and --how_many is the number of test images to process (the default is 50).

5. Displaying test results

  The official test script creates an images folder and an index.html inside the web folder: the images folder stores each test input together with its generated image, and index.html displays them as a table. Unfortunately the ground-truth images are missing from it. After testing, the script below builds an HTML file that shows each result as a label / synthesized / real triplet.
  The script Img2Html.py:

import dominate
from dominate.tags import *
import os
import argparse
import glob

parse = argparse.ArgumentParser()
parse.add_argument('--test_B', type=str, default='web/test_B', help='Path to the test_B folder')
parse.add_argument('--ima', dest='ima', type=str, default='web/images', help='Path to the images folder')
parse.add_argument('--outDir', type=str, default='./', help='Path to the output folder')
args = parse.parse_args()

class HTML:
    def __init__(self, title, refresh=0):
        self.title = title
        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def add_header(self, text):  # renamed from 'str' to avoid shadowing the builtin
        with self.doc:
            h3(text)

    def add_table(self, border=1):
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=512):
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            if txt=='real':
                                with a(href=os.path.join(args.test_B, link)):
                                    img(style="width:%dpx" % (width), src=os.path.join(args.test_B, im))
                            else:
                                with a(href=os.path.join(args.ima, link)):
                                    img(style="width:%dpx" % (width), src=os.path.join(args.ima, im))
                            br()
                            p(txt)

    def save(self, outDir):
        html_file = os.path.join(outDir, 'testResult.html')
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())


def getJpgFile(pathDir):
    if os.path.exists(pathDir):
        jpgPattern = os.path.join(pathDir, '*.jpg')
        return glob.glob(jpgPattern)
    else:
        print('Jpg image folder {0} does not exist'.format(pathDir))
        return []

def Img2HtmlForm(images):
    html = HTML('test_html')
    html.add_header('test result')
    # collect the list of all jpg files under the given path
    jpgFiles = getJpgFile(images)
    if len(jpgFiles) > 0:
        for jpgFile in jpgFiles:
            ims = []
            txts = []
            links = []
            if 'input' in jpgFile:
                jpgFile = os.path.split(jpgFile)[1]
                label=jpgFile
                ims.append(label)
                txts.append('label')
                links.append(label)
                syn=jpgFile.split('_')[0]+'_synthesized_image'+'.jpg'
                ims.append(syn)
                txts.append('synthesized')
                links.append(syn)
                real=jpgFile.split('_')[0]+'.jpg'
                ims.append(real)
                txts.append('real')
                links.append(real)
                html.add_images(ims, txts, links)
    else:
        print('No jpg images found under {0}'.format(images))
    html.save(args.outDir)


if __name__ == '__main__':
    Img2HtmlForm(args.ima)
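To run it, first copy your ground-truth test_B folder into the web folder (the test script does not put it there), then, from the directory containing web and using the default paths from the argparse definitions above:

python Img2Html.py --test_B web/test_B --ima web/images --outDir ./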

6. Other training and testing parameters

  See the parameter definitions in the .py files under the options folder.
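  Frequently used ones that already appear in the scripts above include --batchSize, --loadSize and --fineSize (input sizing), --niter and --niter_decay (epochs at the initial and at the linearly decaying learning rate), --save_epoch_freq and --save_latest_freq (checkpoint frequency), and --print_freq / --display_freq (logging and display frequency).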
