class EMA():
    """Exponential moving average of named tensors.

    Keeps a CPU-resident "shadow" copy of each registered tensor and
    blends new values in with ``shadow = decay*shadow + (1-decay)*x``.
    """

    def __init__(self, decay=0.999):
        # decay: weight given to the existing shadow value on each update.
        self.decay = decay
        self.shadow = {}

    def register(self, name, val):
        """Store an initial shadow copy of *val* under *name*.

        ``clone()`` guards against aliasing: ``cpu()`` is a no-op for a
        tensor already on the CPU and ``detach()`` shares storage, so
        without the clone the shadow would silently track the live
        parameter instead of holding a snapshot.
        """
        self.shadow[name] = val.cpu().detach().clone()

    def get(self, name):
        """Return the current shadow tensor for *name*."""
        return self.shadow[name]

    def update(self, name, x):
        """Blend *x* into the running average for *name*.

        Raises:
            KeyError: if *name* was never registered.  (An explicit raise
            rather than ``assert`` so validation survives ``python -O``.)
        """
        if name not in self.shadow:
            raise KeyError('EMA.update: unregistered name %r' % name)
        new_average = (1.0 - self.decay) * x.cpu().detach() + self.decay * self.shadow[name]
        self.shadow[name] = new_average.clone()
# Track an exponential moving average of every trainable parameter.
ema = EMA(args.ema_decay)
for name, param in network.named_parameters():
    if param.requires_grad:
        ema.register(name, param.data)
def save_checkpoint(state, epoch, dst, is_best):
    """Serialize *state* to ``<dst>/<absolute_epoch>.pth.tar``.

    Args:
        state: checkpoint payload (state dicts, epoch, ...).
        epoch: epoch index relative to ``args.start_epoch``; the file is
            named after the absolute epoch ``args.start_epoch + epoch``.
        dst: destination directory for checkpoints.
        is_best: when True, also copy the file to
            ``<dst>/model_best/<epoch>.pth.tar``.

    NOTE(review): reads the module-level ``args`` — confirm it is defined
    before the first call.
    """
    filename = os.path.join(dst, str(args.start_epoch + epoch)) + '.pth.tar'
    torch.save(state, filename)
    if is_best:
        dst_best = os.path.join(dst, 'model_best', str(epoch)) + '.pth.tar'
        # shutil.copyfile raises FileNotFoundError if the target
        # directory is missing — create it first.
        os.makedirs(os.path.dirname(dst_best), exist_ok=True)
        shutil.copyfile(filename, dst_best)
# args.num_epoches is the total epoch count; args.start_epoch is the
# number of epochs already completed by pre-training.
for epoch in range(args.num_epoches - args.start_epoch):
    # Persist network, EMA shadow, optimizer, W and the absolute epoch
    # number into checkpoint_dir.
    state = {
        'network': network.state_dict(),
        'network_ema': ema.shadow,
        'optimizer': optimizer.state_dict(),
        'W': compute_loss.W,
        'epoch': args.start_epoch + epoch,
    }
    save_checkpoint(state, epoch, args.checkpoint_dir, False)
    logging.info('Epoch: [{}|{}], train_time: {:.3f}, train_loss:{:.3f}'.format(
        args.start_epoch + epoch, args.num_epoches, train_time, train_loss))
    logging.info('image_precision: {:.3f}, text_precision: {:.3f}'.format(
        image_precision, text_precision))
    adjust_lr(optimizer, args.start_epoch + epoch, args)
    scheduler.step()
    # Print the learning rate of the first parameter group only.
    for param in optimizer.param_groups:
        print('lr:{}'.format(param['lr']))
        break
logging.info('Train done')
logging.info(args.checkpoint_dir)
logging.info(args.log_dir)
# Build the network and (for training) its optimizer.
def network_config(args, split='train', param=None, resume=False, model_path=None, ema=False):
    """Construct the model, optionally restore weights, and build the optimizer.

    Args:
        args: experiment configuration namespace (lr, wd, image_model, ...).
            ``args.start_epoch`` is set here as a side effect.
        split: 'train' builds an Adam optimizer; 'test' returns
            ``optimizer=None``.
        param: optional extra trainable parameters appended to the
            non-CNN parameter group.
        resume: True — restore the whole model (and optimizer state) from
            *model_path*; ``args.start_epoch`` becomes the saved epoch + 1.
            When *ema* is also True, the saved EMA shadow weights
            overwrite the matching entries before loading.
            False — optionally initialise only the image backbone from a
            pretrained CNN checkpoint at *model_path*.
        model_path: checkpoint file when resuming, or a pretrained CNN
            weight file otherwise; may be None.
        ema: see *resume*.

    Returns:
        (network, optimizer): the DataParallel-wrapped model and the Adam
        optimizer (or None for split == 'test').
    """
    network = Model(args)
    # Wrap for multi-GPU execution.
    network = nn.DataParallel(network).cuda()
    # Let cuDNN autotune conv algorithms (helps for fixed input sizes).
    cudnn.benchmark = True
    args.start_epoch = 0
    if resume:
        # train_config's args.model_path is passed in as model_path.
        directory.check_file(model_path, 'model_file')
        checkpoint = torch.load(model_path)
        args.start_epoch = checkpoint['epoch'] + 1
        network_dict = checkpoint['network']
        if ema:
            logging.info('==> EMA Loading')
            # Replace live weights with their EMA shadow counterparts.
            network_dict.update(checkpoint['network_ema'])
        # strict=False: tolerate keys present in only one of the dicts.
        network.load_state_dict(network_dict, False)
        print('==> Loading checkpoint "{}"'.format(model_path))
    else:
        # Not resuming: start from network's own initialisation and,
        # when a path is given, copy in pretrained CNN backbone weights.
        if model_path is not None:
            print('==> Loading from pretrained models')
            network_dict = network.state_dict()
            if args.image_model == 'mobilenet_v1':
                cnn_pretrained = torch.load(model_path)['state_dict']
                # Strip the leading 'module.' (7 chars) from saved keys.
                start = 7
            else:
                cnn_pretrained = torch.load(model_path)
                start = 0
            # Re-key the pretrained weights under the image sub-module.
            prefix = 'module.image_model.'
            pretrained_dict = {prefix + k[start:]: v for k, v in cnn_pretrained.items()}
            # Drop keys that do not exist in the current model.
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in network_dict}
            network_dict.update(pretrained_dict)
            network.load_state_dict(network_dict)
    # Process optimizer params.
    if split == 'test':
        optimizer = None
    else:
        # Two parameter groups: the CNN backbone gets weight decay, the
        # remaining parameters (plus any extra *param*) do not.
        # A set makes the per-parameter membership test O(1).
        cnn_params = set(map(id, network.module.image_model.parameters()))
        other_params = [p for p in network.parameters() if id(p) not in cnn_params]
        if param is not None:
            other_params.extend(list(param))
        param_groups = [
            {'params': other_params},
            {'params': network.module.image_model.parameters(), 'weight_decay': args.wd},
        ]
        optimizer = torch.optim.Adam(
            param_groups,
            lr=args.lr, betas=(args.adam_alpha, args.adam_beta), eps=args.epsilon)
        if resume:
            optimizer.load_state_dict(checkpoint['optimizer'])
    # '%.2fM' (two decimals); the original '%2.f' printed zero decimals.
    print('Total params: %.2fM' % (sum(p.numel() for p in network.parameters()) / 1000000.0))
    # Seed every RNG with a fresh random seed for this run.
    manualSeed = random.randint(1, 10000)
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)
    return network, optimizer