我们在上一期的代码中注意到了这样一行代码:
# 建立模型对象
model = create_model(opt)
create_model 是一个高度封装的用法或类,我们利用Pycharm的跳转可以找到它的源文件:
import torch
def create_model(opt):
    """Instantiate the model selected by ``opt.model`` and prepare it.

    For 'pix2pixHD' this returns a trainable ``Pix2PixHDModel`` (when
    ``opt.isTrain``) or an ``InferenceModel``; any other value falls back
    to the UI model (rarely used).  The model is initialized from ``opt``
    and, for non-fp16 GPU training, wrapped in ``DataParallel``.
    """
    if opt.model == 'pix2pixHD':
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
        # Training and inference use different wrappers around the same nets.
        model = Pix2PixHDModel() if opt.isTrain else InferenceModel()
    else:
        # Fallback: the pix2pixHD UI model (generally unused).
        from .ui_model import UIModel
        model = UIModel()

    # Loads options and builds/loads all sub-networks, optimizers, etc.
    model.initialize(opt)
    if opt.verbose:  # off by default; handy when juggling several datasets
        print("model [%s] was created" % (model.name()))

    # Plain (non-fp16) GPU training replicates the model across gpu_ids.
    if opt.isTrain and len(opt.gpu_ids) and not opt.fp16:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    return model
我们用到的便是Pix2PixHDModel,我们继续跳转,进入到pix2pixHD_model.py中,发现:
# 继承BaseModel类
class Pix2PixHDModel(BaseModel):
因此我们首先便去查看base_model.py中定义的类和用法,以下为其详细注释:
import os
import torch
import sys
class BaseModel(torch.nn.Module):
    """Common base class for pix2pixHD models.

    Provides the shared option bookkeeping plus checkpoint save/load
    helpers; subclasses (e.g. Pix2PixHDModel) override ``forward``,
    ``optimize_parameters`` and friends.
    """

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        """Store parsed options and derive device / checkpoint settings.

        opt: options namespace; this reads ``gpu_ids``, ``isTrain``,
        ``checkpoints_dir`` and ``name``.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # Use CUDA float tensors whenever at least one GPU id was requested.
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        # Directory where this experiment's weight files live.
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Save ``network``'s weights to <save_dir>/<epoch>_net_<label>.pth."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        # The state_dict is taken on CPU so checkpoints are device independent.
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda()  # move back to GPU after the CPU-side save

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label, save_dir=''):
        """Load weights <epoch>_net_<label>.pth into ``network``.

        Missing generator ('G') weights are fatal; a missing discriminator
        is tolerated.  On a state_dict mismatch, partially load whatever
        layers agree and report the rest.

        Raises:
            RuntimeError: when the generator checkpoint does not exist.
        """
        # e.g. '200_net_G.pth'
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        if not save_dir:
            # No explicit directory given: fall back to this run's save_dir.
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print('%s not exists yet!' % save_path)
            if network_label == 'G':
                # BUG FIX: the original `raise('Generator must exist!')`
                # raised a plain str, which itself fails with TypeError
                # ("exceptions must derive from BaseException").
                raise RuntimeError('Generator must exist!')
            # For the discriminator a missing checkpoint is simply skipped.
        else:
            try:
                network.load_state_dict(torch.load(save_path))
            except Exception:  # narrowed from a bare except; mismatch handling below
                pretrained_dict = torch.load(save_path)
                model_dict = network.state_dict()
                try:
                    # Checkpoint has extra layers: keep only the keys the
                    # current network actually defines.
                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                    network.load_state_dict(pretrained_dict)
                    if self.opt.verbose:
                        print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
                except Exception:
                    # Checkpoint is missing layers: copy the shapes that match
                    # and report the top-level modules left uninitialized.
                    print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
                    for k, v in pretrained_dict.items():
                        if v.size() == model_dict[k].size():
                            model_dict[k] = v
                    if sys.version_info >= (3, 0):
                        not_initialized = set()
                    else:
                        # Python 2 compatibility path.
                        from sets import Set
                        not_initialized = Set()
                    for k, v in model_dict.items():
                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                            not_initialized.add(k.split('.')[0])
                    print(sorted(not_initialized))
                    network.load_state_dict(model_dict)

    def update_learning_rate(self):
        # BUG FIX: the original signature omitted `self`, so any
        # `model.update_learning_rate()` call raised TypeError.
        pass