Pix2PixHD代码小白注释(2)—— BaseModel.py

上一期:
Pix2PixHD代码小白注释(1)——train.pyicon-default.png?t=N7T8https://blog.csdn.net/qq_73991479/article/details/134757142?spm=1001.2014.3001.5501

        我们在上一期的代码中注意到了,这样一行代码:

# 建立模型对象
model = create_model(opt)

        create_model 是一个高度封装的用法或类,我们利用Pycharm的跳转可以找到它的源文件:

import torch

def create_model(opt):
    if opt.model == 'pix2pixHD':
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
        if opt.isTrain: # 如果进行训练
            model = Pix2PixHDModel()
        else:
            model = InferenceModel()
    else: # (这个是调用了pix2pixhd的ui界面,一般不使用)
    	from .ui_model import UIModel
    	model = UIModel()
    model.initialize(opt) # 载入options初始化模型,主要是完成了所有相关模型的初始化和加载(包括优化器等等)
    if opt.verbose: # 默认是false,我们可以设置成true,帮助辨别(因为我们可能在不同的数据集上进行训练)
        print("model [%s] was created" % (model.name()))

    if opt.isTrain and len(opt.gpu_ids) and not opt.fp16: # 如果在gpu上进行非混合精度训练就使用分布式训练
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

    return model

        我们用到的便是Pix2PixHDModel,我们继续跳转,进入到Pix2PixHD_model.py中,发现:

# 继承BaseModel类
class Pix2PixHDModel(BaseModel):

        因此我们首先便去查看base_model.py中定义的类和用法,以下为其详细注释:

import os
import torch
import sys

class BaseModel(torch.nn.Module):
    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        # 载入参数optins
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 
        # 如果有显卡就调用显卡使用cuda.floatTensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # 权重文件保存位置

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        # 保存权重文件
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda()

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label , save_dir=''):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 
        # 定义了load的权重文件名字,例如:200_net_G.pth
        if not save_dir: 
        # 如果save_dir是空,就使用optins来定义一个保存文件夹save_dir,如果非空就载入已有权重文件地址
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)  
        # 明确权重文件保存的位置
        if not os.path.isfile(save_path): # 检测save_path是否存在
            print('%s not exists yet!' % save_path)
            if network_label == 'G':
                raise('Generator must exist!')
                # 如果模型为Generator,那么必须有权重文件可加载
                # 而对于鉴别器,显然就可以不存在,如果不存在鉴别器就不会加载,也不会报错(raise)
        else:
            #network.load_state_dict(torch.load(save_path))
            try:
                network.load_state_dict(torch.load(save_path)) # 尝试进行载入权重文件
            except:   # 载入失败,下面的代码都是处理错误并打印错误原因的
                pretrained_dict = torch.load(save_path)                
                model_dict = network.state_dict()
                try: # 载入的预训练模型中层数太多
                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}                    
                    network.load_state_dict(pretrained_dict)
                    if self.opt.verbose:
                        print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
                except: # 载入的预训练模型层数太少,并打印出预训练模型中缺少的层数
                    print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
                    for k, v in pretrained_dict.items():                      
                        if v.size() == model_dict[k].size():
                            model_dict[k] = v

                    if sys.version_info >= (3,0):
                        not_initialized = set()
                    else:
                        from sets import Set
                        not_initialized = Set()                    

                    for k, v in model_dict.items():
                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                            not_initialized.add(k.split('.')[0])
                    
                    print(sorted(not_initialized))
                    network.load_state_dict(model_dict)

    def update_learning_rate():
        pass

  • 9
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

咖啡百怪

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值