上两期:
Pix2PixHD代码小白注释(1)——train.pyhttps://blog.csdn.net/qq_73991479/article/details/134757142?spm=1001.2014.3001.5501
Pix2PixHD代码小白注释(2)—— BaseModel.pyhttps://blog.csdn.net/qq_73991479/article/details/134762184?spm=1001.2014.3001.5502
在train.py中,我们注意到以下两行代码:
# 建立模型对象
model = create_model(opt)
# 完成正向传播
losses, generated = model(Variable(data['label']), Variable(data['inst']),
Variable(data['image']), Variable(data['feat']), infer=save_fake)
这两行代码显然是作者高度封装的类和用法,前者我们在上一期了解到是在建立一个模型类,并且使用的是Pix2PixHDModel,其中Pix2PixHDModel类继承了BaseModel的用法,那么本期就深入了解Pix2PixHDModel。
import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class Pix2PixHDModel(BaseModel):
def name(self):
return 'Pix2PixHDModel'
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
flags = (True, use_gan_feat_loss, use_vgg_loss, True, True) # 记录了每个损失是否use的元组
def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f]
# zip建立了flag和loss名组成的元组,返回一个只包含我们use的loss名的列表
return loss_filter # 返回函数用法
def initialize(self, opt):
BaseModel.initialize(self, opt) # basemodel的初始化,载入了权重文件保存位置、显卡数目等等
if opt.resize_or_crop != 'none' or not opt.isTrain:
# 如果我们输入了全分辨率的图像而不进行任何的裁剪
# 容易造成内存的溢出,我们这时候调用了cudnn的基准模式,优化卷积操作
torch.backends.cudnn.benchmark = True
self.isTrain = opt.isTrain
self.use_features = opt.instance_feat or opt.label_feat
# 是否使用instance-wise的实例编码特征或标签编码特征(该选项主要是用于控制输出或实现输出的多样化,见论文)
self.gen_features = self.use_features and not self.opt.load_features
# 使用instance-wise的编码特征并且不加载已经预处理的features map为true
input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
# input_nc=label_nc除非input_nc为0
##### define networks
# Generator network
netG_input_nc = input_nc
if not opt.no_instance:
# 如果使用boundary map就给输入加入一个维度,这是与论文所说对应的
netG_input_nc += 1
if self.use_features:
netG_input_nc += opt.feat_num # 如果使用编码特征图就再加入编码特征图的维度
# 建立生成器对象
self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
# Discriminator network
if self.isTrain:
use_sigmoid = opt.no_lsgan
# 如果no_lsgan,就表示使用sigmoid激活函数,否则不使用
netD_input_nc = input_nc + opt.output_nc
# 将生成器输出(或真实生成图像)和标签维度相加,输入鉴别器
if not opt.no_instance: # 如果使用boundary map就再加入一个维度
netD_input_nc += 1
# 建立鉴别器对象
self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
### Encoder network
# 如果使用了实例编码特征图并且不使用已经预处理的特征图,就建立一个feature encoder,其模型结构和生成器一样
if self.gen_features:
self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
if self.opt.verbose:
print('---------- Networks initialized -------------')
# load networks
# 我们在这里定义了一个network加载器,该加载器可以用在test中,continue_train中以及载入预训练模型进行训练中
if not self.isTrain or opt.continue_train or opt.load_pretrain:
# 如果不训练(test),或进行continue_train或加入预训练模型,就需要load networks
pretrained_path = '' if not self.isTrain else opt.load_pretrain # 如果不训练(测试)就为空,否则就是预训练模型
# load_network是作者自定义的函数,在base_model中有定义,这里我们进行简单介绍
# 函数的功能主要是将pretrained_path(权重文件或预训练模型)加载到netG(或D,E)中
# 若pretrained_path为空,就不加载(但如果为加载generator生成器模型,就会报错,
# 因为在使用该函数时,我们规定了not isTrain或者continue——train,这两个对生成器都是必须的)
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
if self.isTrain: # 如果进行训练,但是继续训练或者加入预训练模型,加载鉴别器
self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
if self.gen_features: # 如果使用feature encoder,就加载模型
self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)
# set loss functions and optimizers
if self.isTrain:
if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
self.fake_pool = ImagePool(opt.pool_size)
self.old_lr = opt.lr
# define loss functions
self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
# 返回一个用法,该用法输入各个loss的使用与否,输出一个只包含使用的loss的列表名
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
# 建立了一个GANLoss对象
# 输入的opt包括是否使用lsgan,以及使用的Tensor类型(取决于使用显卡还是cpu)
self.criterionFeat = torch.nn.L1Loss()
# 建立了一个计算L1 loss的用法
if not opt.no_vgg_loss:
# 建立计算感知损失的类
self.criterionVGG = networks.VGGLoss(self.gpu_ids)
# Names so we can breakout loss
self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')
# initialize optimizers
# optimizer G
if opt.niter_fix_global > 0:
# 是否只使用outmost的局部生成器(local hencer)
import sys
if sys.version_info >= (3,0): # 判断python的版本号是否大于3.0
finetune_list = set() # 如果大于3.0,就建立一个集合,接受列表作为参数
else:
from sets import Set
finetune_list = Set()
params_dict = dict(self.netG.named_parameters()) # 建立一个字典,字典键值为生成器网络的所有参数,键为各参数对应的名字
params = [] # 建立一个列表
for key, value in params_dict.items(): # 遍历字典
if key.startswith('model' + str(opt.n_local_enhancers)):
# 如果遍历中遇到键值为model+n.local_enhancers
params += [value] # params就在列表中加入键值,此时记录的全部是local_hancers的相关参数
finetune_list.add(key.split('.')[0]) # 在.处将键值打断并生成列表输出[0]
print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
print('The layers that are finetuned are ', sorted(finetune_list))
else: # 如果不只使用局部生成器,就直接生成params
params = list(self.netG.parameters())
if self.gen_features:
params += list(self.netE.parameters())
# 完成生成器优化器的定义
self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
# optimizer D
params = list(self.netD.parameters())
self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
# 编码部分,既包含boundary map的编码,也包含feature的编码
if self.opt.label_nc == 0: # 如果label_nc = 0,这也就表示我们的boundary map中没有对象
input_label = label_map.data.cuda()
else:
# create one-hot vector for label map
size = label_map.size()
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
if self.opt.data_type == 16:
input_label = input_label.half()
# 使用边缘图结合label进行输入时
# get edges from instance map
if not self.opt.no_instance:
inst_map = inst_map.data.cuda()
edge_map = self.get_edges(inst_map)
input_label = torch.cat((input_label, edge_map), dim=1)
# 在维度1上进行拼接,由于pytorch中tensor为NCHW,那么也就是C加1,这与我们之前输入时给input_nc+1是一致的
input_label = Variable(input_label, volatile=infer) # 将tensor转换为variable
# real images for training
if real_image is not None:
real_image = Variable(real_image.data.cuda())
# instance map for feature encoding
if self.use_features:
# 如果使用instance-wise
# get precomputed feature maps
if self.opt.load_features:
# 如果使用已经预处理的特征图
feat_map = Variable(feat_map.data.cuda())
if self.opt.label_feat:
# 如果使用label编码特征图
inst_map = label_map.cuda()
return input_label, inst_map, real_image, feat_map
# 返回输入标签图,label编码图,真值图以及预处理的特征图(没有就返回None)
def discriminate(self, input_label, test_image, use_pool=False):
input_concat = torch.cat((input_label, test_image.detach()), dim=1)
# 将输入标签和测试图拼接一起输入网络
if use_pool:
fake_query = self.fake_pool.query(input_concat)
return self.netD.forward(fake_query)
else:
# 正向传播返回结果
return self.netD.forward(input_concat)
def forward(self, label, inst, image, feat, infer=False):
# Encode Inputs
input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)
# 首先定义了encoder网络的输入(同时定义了label_input)
# Fake Generation
if self.use_features:
if not self.opt.load_features:
# 如果使用encoder并不使用预处理特征图,就把真值图和特征图输入encoder,返回encoder得到的特征图
feat_map = self.netE.forward(real_image, inst_map)
input_concat = torch.cat((input_label, feat_map), dim=1)
# 将输入label(此时已经包含了boundary map)与特征图拼接(此时为真正的输入标签图)
else:
input_concat = input_label
fake_image = self.netG.forward(input_concat) # 完成生成器正向传播
# Fake Detection and Loss
pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
# 计算fake_loss
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
# Real Detection and Loss
pred_real = self.discriminate(input_label, real_image)
# 计算real_loss
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Passability Loss)
# 计算生成器GAN损失
pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
loss_G_GAN = self.criterionGAN(pred_fake, True)
# GAN feature matching loss
loss_G_GAN_Feat = 0
# 计算特征匹配损失
if not self.opt.no_ganFeat_loss:
feat_weights = 4.0 / (self.opt.n_layers_D + 1)
D_weights = 1.0 / self.opt.num_D
for i in range(self.opt.num_D):
for j in range(len(pred_fake[i])-1):
loss_G_GAN_Feat += D_weights * feat_weights * \
self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
# VGG feature matching loss
loss_G_VGG = 0
# 计算感知损失
if not self.opt.no_vgg_loss:
loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
# Only return the fake_B image if necessary to save BW
# 最后返回所有损失
return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ]
def inference(self, label, inst, image=None):
# 生成图片(test中实现图像翻译)的用法
# Encode Inputs
image = Variable(image) if image is not None else None
# 是否有真值图
input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)
# 定义了encoder网络的输入(同时定义了label_input)
# Fake Generation
# 与上面一样完成输入图片的拼接
if self.use_features:
if self.opt.use_encoded_image:
# encode the real image to get feature map
feat_map = self.netE.forward(real_image, inst_map)
else:
# sample clusters from precomputed features
feat_map = self.sample_features(inst_map)
input_concat = torch.cat((input_label, feat_map), dim=1)
else:
input_concat = input_label
# 正向传播
if torch.__version__.startswith('0.4'):
with torch.no_grad():
fake_image = self.netG.forward(input_concat)
else:
fake_image = self.netG.forward(input_concat)
return fake_image
def sample_features(self, inst):
# 读取预处理图片的特征生成特征图
# read precomputed feature clusters
cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
features_clustered = np.load(cluster_path, encoding='latin1').item()
# randomly sample from the feature clusters
inst_np = inst.cpu().numpy().astype(int)
feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
for i in np.unique(inst_np):
label = i if i < 1000 else i//1000
if label in features_clustered:
feat = features_clustered[label]
cluster_idx = np.random.randint(0, feat.shape[0])
idx = (inst == int(i)).nonzero()
for k in range(self.opt.feat_num):
feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
if self.opt.data_type==16:
feat_map = feat_map.half()
return feat_map
def encode_features(self, image, inst):
# 计算特征值
image = Variable(image.cuda(), volatile=True)
# 将图像放在GPU上,并不计算梯度
feat_num = self.opt.feat_num
h, w = inst.size()[2], inst.size()[3]
# 设置获取特征的数量和特征图的大小
block_num = 32
feat_map = self.netE.forward(image, inst.cuda())
# 向前传播,生成特征图
inst_np = inst.cpu().numpy().astype(int)
# 将inst图转换为numpy类型
feature = {}
for i in range(self.opt.label_nc):
feature[i] = np.zeros((0, feat_num+1))
# 循环,让feature集合中的每一个元素都是形状为(0,feat_num+1)的列表
for i in np.unique(inst_np):
# 只循环inst数组中不同的项
label = i if i < 1000 else i//1000
idx = (inst == int(i)).nonzero()
num = idx.size()[0]
idx = idx[num//2,:]
val = np.zeros((1, feat_num+1))
# 储存特征值
for k in range(feat_num):
val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
val[0, feat_num] = float(num) / (h * w // block_num)
feature[label] = np.append(feature[label], val, axis=0)
return feature
def get_edges(self, t):
# 获取边缘图boundary map
edge = torch.cuda.ByteTensor(t.size()).zero_()
edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
if self.opt.data_type==16:
return edge.half()
else:
return edge.float()
def save(self, which_epoch):
# 保存权重文件的用法
self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
if self.gen_features:
self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
def update_fixed_params(self):
# after fixing the global generator for a number of iterations, also start finetuning it
params = list(self.netG.parameters())
if self.gen_features:
params += list(self.netE.parameters())
self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
if self.opt.verbose:
print('------------ Now also finetuning global generator -----------')
def update_learning_rate(self):
# 更新学习率
lrd = self.opt.lr / self.opt.niter_decay # 确认每次学习率降低的多少
lr = self.old_lr - lrd
# 遍历,每一层的学习率都降低
for param_group in self.optimizer_D.param_groups:
param_group['lr'] = lr
for param_group in self.optimizer_G.param_groups:
param_group['lr'] = lr
if self.opt.verbose:
print('update learning rate: %f -> %f' % (self.old_lr, lr))
# 更新数据
self.old_lr = lr
class InferenceModel(Pix2PixHDModel):
def forward(self, inp):
label, inst = inp
return self.inference(label, inst)