Pix2PixHD Code Walkthrough for Beginners (3): pix2pixHD_model.py

Previous posts:

Pix2PixHD Code Notes for Beginners (1): train.py - https://blog.csdn.net/qq_73991479/article/details/134757142?spm=1001.2014.3001.5501
Pix2PixHD Code Notes for Beginners (2): BaseModel.py - https://blog.csdn.net/qq_73991479/article/details/134762184?spm=1001.2014.3001.5502

        In train.py, we noticed the following two lines of code:

# create the model object
model = create_model(opt)
# run the forward pass
losses, generated = model(Variable(data['label']), Variable(data['inst']), 
            Variable(data['image']), Variable(data['feat']), infer=save_fake)

        These two lines clearly rely on the author's heavily encapsulated classes. From the previous post we know that the first line builds the model object, which at training time is a Pix2PixHDModel, a class that inherits from BaseModel. In this post we dig into Pix2PixHDModel itself.
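
For context, create_model lives in models/models.py. The sketch below condenses that factory function (a paraphrase with details trimmed, not the verbatim file): it instantiates Pix2PixHDModel for training or InferenceModel for testing, calls initialize(opt), and wraps the result in DataParallel for multi-GPU training.

# models/models.py, condensed (paraphrased; extra branches and logging trimmed)
import torch

def create_model(opt):
    if opt.model == 'pix2pixHD':
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
        model = Pix2PixHDModel() if opt.isTrain else InferenceModel()
    model.initialize(opt)                  # the method we walk through below
    if opt.isTrain and len(opt.gpu_ids):
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    return model

Now on to the file itself: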

import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks

class Pix2PixHDModel(BaseModel):
    def name(self):
        return 'Pix2PixHDModel'
    
    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True) # a tuple recording whether each loss is in use
        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f]
        # zip pairs each loss with its flag; the list comprehension keeps only the losses that are in use
        return loss_filter # return the filter itself as a function (a closure)
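    # A quick illustration of the closure (hypothetical flags, not part of the original file):
    #   loss_filter = self.init_loss_filter(True, False)   # keep GAN-feat loss, drop VGG loss
    #   loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake')
    #   # -> ['G_GAN', 'G_GAN_Feat', 'D_real', 'D_fake']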
    
    def initialize(self, opt):
        BaseModel.initialize(self, opt) # BaseModel's initialize loads the checkpoint directory, GPU ids, etc.
        if opt.resize_or_crop != 'none' or not opt.isTrain:
            # when the inputs are resized/cropped to a fixed size (or at test time), enable
            # cuDNN's benchmark mode, which auto-tunes the convolution algorithms;
            # it stays off when training at full resolution, where it can run out of memory
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        # whether to use instance-wise encoded features or label-encoded features
        # (this option mainly enables diverse, multi-modal outputs; see the paper)
        self.gen_features = self.use_features and not self.opt.load_features
        # true when instance-wise features are used but precomputed feature maps are not loaded
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
        # input_nc is label_nc, unless label_nc is 0, in which case opt.input_nc is used

        ##### define networks        
        # Generator network
        netG_input_nc = input_nc        
        if not opt.no_instance: 
        # if the instance boundary map is used, add one channel to the input, matching the paper
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num  # if encoded feature maps are used, add their channels too
        # build the generator
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)        

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan 
            # no_lsgan means a vanilla GAN loss, so the discriminator ends with a sigmoid; with LSGAN it does not
            netD_input_nc = input_nc + opt.output_nc 
            # the discriminator sees the label channels concatenated with the generated (or real) image channels
            if not opt.no_instance: # again add one channel for the boundary map
                netD_input_nc += 1
            # build the discriminator
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
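            # define_D builds opt.num_D PatchGAN discriminators that operate on
            # increasingly downsampled copies of the input (the paper's multi-scale
            # discriminator); passing not opt.no_ganFeat_loss asks them to also
            # return intermediate features, which the feature matching loss needs later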

        ### Encoder network
        # if instance-wise features are in use and precomputed feature maps are not,
        # build the feature encoder E; it is created through the same define_G factory,
        # with netG='encoder' selecting the encoder architecture
        if self.gen_features:          
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', 
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)  
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        # here networks are loaded from disk; this path is used for testing, for
        # continue_train, and for starting training from a pretrained model
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            # if not training (test), or continuing training, or loading a pretrained model, load the networks
            pretrained_path = '' if not self.isTrain else opt.load_pretrain # empty at test time, otherwise the pretrained-model directory
            # load_network is the author's helper, defined in base_model.py; briefly,
            # it loads the weights found under pretrained_path into netG (or netD, netE);
            # a missing weight file is skipped with a warning, except for the generator,
            # whose checkpoint must exist (both testing and continue_train require it),
            # so a missing generator file raises an error
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)            
            if self.isTrain: # when training (continue_train or a pretrained start), also load the discriminator
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features: # if the feature encoder is used, load it as well
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)              

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions

            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            # the closure built above: given all five losses, it keeps only the ones in use
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            # a GANLoss object, configured by whether LSGAN is used and by the
            # tensor type (which depends on GPU vs. CPU)
            self.criterionFeat = torch.nn.L1Loss()
            # an L1 loss, used for feature matching
            if not opt.no_vgg_loss:
                # the perceptual (VGG) loss
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)

            # Names so we can break out each individual loss
            self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                # train only the outermost local enhancer for the first niter_fix_global
                # epochs, keeping the global generator frozen
                import sys
                if sys.version_info >= (3,0): # Python 3: the built-in set accepts an iterable
                    finetune_list = set()
                else:
                    from sets import Set # Python 2 fallback
                    finetune_list = Set()

                params_dict = dict(self.netG.named_parameters()) # a dict mapping each parameter's name to the parameter itself
                params = [] # the parameters that will actually be optimized
                for key, value in params_dict.items():  # iterate over the dict
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        # names starting with 'model' + n_local_enhancers belong to the outermost local enhancer
                        params += [value] # collect only the local enhancer's parameters
                        finetune_list.add(key.split('.')[0])  # record the top-level module name (the part before the first '.')
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))                         
            else: # otherwise optimize the whole generator
                params = list(self.netG.parameters())
            if self.gen_features:              
                params += list(self.netE.parameters())
            # define the generator's optimizer
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))                            

            # optimizer D                        
            params = list(self.netD.parameters())    
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
        # encodes the inputs: the one-hot label map, the boundary (edge) map, and the features
        if self.opt.label_nc == 0: # label_nc == 0 means the "label" is not a class-index map (e.g. it is already an ordinary image), so no one-hot encoding is needed
            input_label = label_map.data.cuda()
        else:
            # create one-hot vector for label map 
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
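            # scatter_ in one line: for every pixel, write 1.0 into the channel given
            # by that pixel's class index. A tiny hypothetical example:
            #   idx = torch.tensor([[[[0, 2]]]])               # shape (1,1,1,2): classes 0 and 2
            #   torch.zeros(1, 3, 1, 2).scatter_(1, idx, 1.0)  # one-hot over 3 channels
            #   # channel 0 -> [1, 0], channel 1 -> [0, 0], channel 2 -> [0, 1]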
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # when the edge map is combined with the label map as input
        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
            # concatenate along dim 1; PyTorch tensors are NCHW, so this adds one channel,
            # matching the input_nc + 1 we used when building the networks
        input_label = Variable(input_label, volatile=infer) # wrap the tensor in a Variable (volatile disables grads in old PyTorch)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # instance map for feature encoding
        if self.use_features:
            # when instance-wise features are in use
            # get precomputed feature maps
            if self.opt.load_features:
            # if precomputed feature maps are used
                feat_map = Variable(feat_map.data.cuda())
            if self.opt.label_feat:
            # if features are encoded per label rather than per instance
                inst_map = label_map.cuda()

        return input_label, inst_map, real_image, feat_map
    # returns the input label map, the instance map, the real image, and the precomputed feature map (None when absent)

    def discriminate(self, input_label, test_image, use_pool=False):
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        # concatenate the label map with the image under test and feed the pair to D;
        # detach() blocks gradients from flowing back into the generator here
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            # forward pass, return D's predictions
            return self.netD.forward(input_concat)
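    # Note on the pool: with the default --pool_size 0, ImagePool.query() returns its
    # input unchanged, so use_pool only has an effect when a nonzero pool size is set
    # (the pool then mixes in previously generated fakes to stabilize D's training).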

    def forward(self, label, inst, image, feat, infer=False):
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)  
        # first encode the inputs (this also builds the final input label map)

        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                # using the encoder without precomputed features: feed the real image
                # and instance map to E, which returns the encoded feature map
                feat_map = self.netE.forward(real_image, inst_map)                     
            input_concat = torch.cat((input_label, feat_map), dim=1)
            # concatenate the input label (which already contains the boundary map)
            # with the feature map; this is the generator's actual input
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat) # generator forward pass

        # Fake Detection and Loss
        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        # D's loss on fake images
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)        

        # Real Detection and Loss        
        pred_real = self.discriminate(input_label, real_image)
        # D's loss on real images
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss)
        # the generator's GAN loss: G wants D to label its fakes as real
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))        
        loss_G_GAN = self.criterionGAN(pred_fake, True)               
        
        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        # the feature matching loss, computed on D's intermediate activations
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i])-1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
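                    # with the defaults (num_D=2, n_layers_D=3): D_weights=0.5 and
                    # feat_weights=1.0, so the L1 terms are averaged over the two
                    # discriminators and over the (n_layers_D+1) feature levels,
                    # then scaled by lambda_feat (10.0 by default)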
                   
        # VGG feature matching loss
        loss_G_VGG = 0
        # the perceptual (VGG) loss between fake and real images
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
        
        # Only return the fake_B image if necessary to save BW
        # finally return all the (filtered) losses
        return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]
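    # Back in train.py, the returned list is unpacked against the matching loss_names
    # (quoting the repo's train.py from memory, so treat the exact lines as approximate):
    #   loss_dict = dict(zip(model.module.loss_names, losses))
    #   loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
    #   loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0)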

    def inference(self, label, inst, image=None):
        # generates images at test time (the actual image translation)
        # Encode Inputs        
        image = Variable(image) if image is not None else None
        # the real image is optional here
        input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)
        # encode the inputs (this also builds the final input label map)

        # Fake Generation
        # assemble the generator input exactly as in forward()
        if self.use_features:
            if self.opt.use_encoded_image:
                # encode the real image to get feature map
                feat_map = self.netE.forward(real_image, inst_map)
            else:
                # sample clusters from precomputed features             
                feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)                        
        else:
            input_concat = input_label        

        # forward pass (wrapped in no_grad on PyTorch 0.4 to skip gradient tracking)
        if torch.__version__.startswith('0.4'):
            with torch.no_grad():
                fake_image = self.netG.forward(input_concat)
        else:
            fake_image = self.netG.forward(input_concat)
        return fake_image

    def sample_features(self, inst):
        # builds a feature map by sampling from precomputed feature clusters
        # read precomputed feature clusters 
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)        
        features_clustered = np.load(cluster_path, encoding='latin1').item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)                                      
        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
        for i in np.unique(inst_np):    
            label = i if i < 1000 else i//1000 # instance IDs encode the semantic class as id // 1000 (Cityscapes convention)
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0]) 
                                            
                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):                                    
                    feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
        if self.opt.data_type==16:
            feat_map = feat_map.half()
        return feat_map
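    # The cluster file loaded above is produced offline: encode_features (below) collects
    # one feature vector per instance, and a separate clustering step (see encode_features.py
    # in the repo root) reduces them to a dict mapping each semantic label to an array of
    # cluster centers; one row of that array is one "style" that can be sampled here.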

    def encode_features(self, image, inst):
        # computes one feature vector per instance in the image
        image = Variable(image.cuda(), volatile=True)
        # move the image to the GPU without tracking gradients
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        # the number of features per instance and the spatial size of the map
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        # forward pass through the encoder to get the feature map
        inst_np = inst.cpu().numpy().astype(int)
        # convert the instance map to a numpy array
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num+1))
            # one empty array of shape (0, feat_num+1) per semantic label; rows are appended below
        for i in np.unique(inst_np):
            # loop over the distinct instance IDs only
            label = i if i < 1000 else i//1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            idx = idx[num//2,:] # take the middle pixel of this instance as its representative location
            val = np.zeros((1, feat_num+1))
            # store this instance's feature vector
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]            
            val[0, feat_num] = float(num) / (h * w // block_num) # the last column records the normalized instance size
            feature[label] = np.append(feature[label], val, axis=0)
        return feature

    def get_edges(self, t):
        # computes the boundary (edge) map from the instance map
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
        edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
        edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
        edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
        if self.opt.data_type==16:
            return edge.half()
        else:
            return edge.float()
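    # Tiny illustration: for a row of instance IDs [5, 5, 8, 8], t[...,1:] != t[...,:-1]
    # is [0, 1, 0]; OR-ing it into both edge[...,1:] and edge[...,:-1] marks the pixels
    # on both sides of the 5|8 boundary, giving edge = [0, 1, 1, 0] for that row.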

    def save(self, which_epoch):
        # saves the network weights
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())           
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        # linearly decay the learning rate
        lrd = self.opt.lr / self.opt.niter_decay # how much the learning rate drops per epoch
        lr = self.old_lr - lrd
        # lower the learning rate of every param group in both optimizers
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        # remember the new value
        self.old_lr = lr
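    # Example with the repo's defaults (--lr 0.0002, --niter_decay 100): lrd = 2e-6,
    # so the rate falls linearly from 2e-4 to 0 over the final 100 epochs.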

class InferenceModel(Pix2PixHDModel):
    def forward(self, inp):
        label, inst = inp
        return self.inference(label, inst)
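
To close the loop, here is a minimal sketch of how InferenceModel ends up being used at test time. It paraphrases the repo's test.py (module paths follow the repo layout; treat the details as approximate rather than verbatim):

from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model

opt = TestOptions().parse(save=False)
opt.nThreads = 1          # single worker
opt.batchSize = 1         # one image at a time
opt.serial_batches = True # no shuffling
opt.no_flip = True        # no flip augmentation

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)  # isTrain is False, so this builds an InferenceModel

for data in dataset:
    generated = model.inference(data['label'], data['inst'], data['image'])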

        
