四、pix2pixHD代码解析(models搭建)

pix2pixHD代码解析

一、pix2pixHD代码解析(train.py + test.py)
二、pix2pixHD代码解析(options设置)
三、pix2pixHD代码解析(dataset处理)
四、pix2pixHD代码解析(models搭建)

四、pix2pixHD代码解析(models搭建)

models.py

import torch


# 创建模型,并返回模型
def create_model(opt):
    if opt.model == 'pix2pixHD':                                        # 选择pix2pixHD model
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
        if opt.isTrain:                                                 # 若是训练,则为True
            model = Pix2PixHDModel()
        else:                                                           # 否则,若仅仅是前向传播用来演示,则为False
            model = InferenceModel()
    else:                                                               # 选择 UIModel model
    	from .ui_model import UIModel
    	model = UIModel()
    model.initialize(opt)                                               # 模型初始化参数
    if opt.verbose:                                                     # 默认为false,表示之前并无模型保存
        print("model [%s] was created" % (model.name()))                # 打印label2city模型被创建

    if opt.isTrain and len(opt.gpu_ids) and not opt.fp16:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)    # 多GPU训练

    return model

pix2pixHD.py

import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks


################################################################################################
# 看到Pix2PixHDModel的类。这个类的内容非常多,有搭建模型,定义优化器和损失函数,导入模型等操作。
################################################################################################
class Pix2PixHDModel(BaseModel):                                                                                        # 继承自BaseModel类,里面主要有save和load模型函数
    def name(self):
        return 'Pix2PixHDModel'

    # loss滤波器:其中g_gan、d_real、d_fake三个loss值是肯定返回的
    # 至于g_gan_feat,g_vgg两个loss值根据train_options的opt.no_ganFeat_loss, not opt.no_vgg_loss而定
    # 备注:这个函数只是一个滤波器,不仅可以滤掉loss值,也可以滤掉loss name,主要看是谁在调用,输入什么,就可以滤掉什么。
    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)

        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            ### zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。
            # 如果各个迭代器的元素个数不一致,则返回列表长度与最短的对象相同,利用 * 号操作符,可以将元组解压为列表。
            # >>>a = [1,2,3]
            # >>> b = [4,5,6]
            # >>> c = [4,5,6,7,8]
            # >>> zipped = zip(a,b)     # 打包为元组的列表
            # [(1, 4), (2, 5), (3, 6)]
            # >>> zip(a,c)              # 元素个数与最短的列表一致
            # [(1, 4), (2, 5), (3, 6)]
            # >>> zip(*zipped)          # 与 zip 相反,*zipped 可理解为解压,返回二维矩阵式
            # [(1, 2, 3), (4, 5, 6)]
            return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, d_real, d_fake), flags) if f]                        # 当f为True时,返回对应的l,其中l表示loss值
        return loss_filter                                                                                              # 最后返回的是激活的loss值,False的loss值并不记录在内

    # 在initialize函数里面看看对Pix2PixHDModel的一些设置。
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc                                                  # 因为label_n=0,因此赋值为3

        ##### define networks
        # Generator network
        netG_input_nc = input_nc                                                                                        # 输入层数
        if not opt.no_instance:                                                                                         # 如果有实例标签,则通道加1
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num                  
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,                                  # 创建全局生成网络
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)        

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)                     # train_options里设置了opt.num_D=2

        ### Encoder network
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', 
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)  
        if self.opt.verbose:
                print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)            
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)  
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)              

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)                        # 训练G和D的损失函数定义
            self.criterionFeat = torch.nn.L1Loss()                                                                      # feature matching损失项的定义,使用的是L1 loss。
            if not opt.no_vgg_loss:             
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)                                                      # percetual loss的定义。这是可选项,对最终结果也有帮助。
                
        
            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')                         # 利用loss滤波器返回有用的loss名字

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:                
                import sys
                if sys.version_info >= (3,0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()

                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():       
                    if key.startswith('model' + str(opt.n_local_enhancers)):                    
                        params += [value]
                        finetune_list.add(key.split('.')[0])  
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))                         
            else:
                params = list(self.netG.parameters())
            if self.gen_features:              
                params += list(self.netE.parameters())         
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            

            # optimizer D                        
            params = list(self.netD.parameters())    
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):             
        # 1、label_map 数据类型转化
        if self.opt.label_nc == 0:                                                                                      # 如果label通道为0,那么直接转为cuda张量
            input_label = l
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值