目录
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
from utils.loss import LanguageModelCriterion, CosineCriterion, SoftCriterion
from eval import eval_fn
from configs.settings import TotalConfigs
from models.hungary import HungarianMatcher
def _get_src_permutation_idx(indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
def _get_tgt_permutation_idx(indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
def train_fn(cfgs: TotalConfigs, model_name: str, model: nn.Module, matcher: HungarianMatcher,
             train_loader, valid_loader, device):
    """Train ``model`` with the caption loss plus optional auxiliary losses.

    The total loss is the cross-entropy caption loss, plus each of the
    following terms when its weight in ``cfgs.train`` is positive:
    soft-target loss (``lambda_soft``), entity loss on Hungarian-matched
    objects (``lambda_entity``), predicate loss (``lambda_predicate``) and
    sentence loss (``lambda_sentence``).

    Every ``cfgs.train.visualize_every`` iterations the mean of the most
    recent (at most 10) loss values is printed. Every
    ``cfgs.train.save_checkpoints_every`` iterations the model is evaluated
    with ``eval_fn`` and its ``state_dict`` is written to
    ``cfgs.train.save_checkpoints_path`` when the CIDEr score improves.

    Args:
        cfgs: project configuration (``train``, ``data`` and ``dict`` sections are read).
        model_name: not used inside this function.
        model: network called as
            ``model(objects, object_masks, feature2ds, feature3ds, numberic_caps)``
            and returning ``(preds, objects_pending, action_pending, video_pending)``.
        matcher: matcher pairing predicted object vectors with noun vectors.
        train_loader: iterable yielding 14-field training batches.
        valid_loader: loader handed to ``eval_fn``.
        device: device the batch tensors and the language criterion are moved to.

    Returns:
        The trained ``model`` (the same object that was passed in).

    Raises:
        RuntimeError: if the matched object predictions contain NaN.
    """
    optimizer = optim.Adam(model.parameters(), lr=cfgs.train.learning_rate)
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, cfgs.train.max_epochs, eta_min=0, last_epoch=-1)
    language_loss = LanguageModelCriterion()
    soft_loss = SoftCriterion()
    cos_loss = CosineCriterion()
    language_loss.to(device)

    best_score, cnt = None, 0
    loss_store, loss_entity, loss_predicate, loss_sentence, loss_xe, loss_soft_target = [], [], [], [], [], []

    # NOTE: pickle must only be used on trusted files; these paths come from the config.
    with open(cfgs.data.idx2word_path, 'rb') as f:
        idx2word = pickle.load(f)
    with open(cfgs.data.vid2groundtruth_path, 'rb') as f:
        vid2groundtruth = pickle.load(f)

    def _mean_or_zero(window):
        # A disabled loss term is recorded as None on every step; report it as 0.
        return np.array(window).mean() if window[0] is not None else 0

    print('===================Training begin====================')
    print(cfgs.train.save_checkpoints_path)
    for epoch in range(cfgs.train.max_epochs):
        print('\n{}[EPOCH {}]{}'.format('='*15, epoch, '='*15))
        for i, (feature2ds, feature3ds, objects, object_masks, vp_semantics, caption_semantics,
                numberic_caps, masks, captions, nouns, vids, vocab_ids, vocab_probs,
                fillmasks) in enumerate(train_loader):
            cnt += 1
            feature2ds = feature2ds.to(device)
            feature3ds = feature3ds.to(device)
            objects = objects.to(device)
            object_masks = object_masks.to(device)
            vp_semantics = vp_semantics.to(device)
            caption_semantics = caption_semantics.to(device)
            numberic_caps = numberic_caps.to(device)
            masks = masks.to(device)
            # The soft-target fields may be None in a batch.
            vocab_ids = vocab_ids.to(device) if vocab_ids is not None else None
            vocab_probs = vocab_probs.to(device) if vocab_probs is not None else None
            fillmasks = fillmasks.to(device) if fillmasks is not None else None

            optimizer.zero_grad()
            preds, objects_pending, action_pending, video_pending = model(
                objects, object_masks, feature2ds, feature3ds, numberic_caps)

            # Scalar copies of each term for logging; None means the term is disabled.
            xe_loss, s_loss, ent_loss, pred_loss, sent_loss = None, None, None, None, None

            # cross entropy loss
            loss_hard = language_loss(preds, numberic_caps, masks, cfgs.dict.eos_idx)
            loss = loss_hard
            xe_loss = loss_hard.detach().item()

            # soft loss
            if cfgs.train.lambda_soft > 0:
                loss_soft = soft_loss(preds, vocab_ids, vocab_probs, fillmasks)
                loss = loss + loss_soft * cfgs.train.lambda_soft
                s_loss = loss_soft.detach().item()

            # object module loss
            if cfgs.train.lambda_entity > 0:
                indices = matcher(objects_pending, nouns)
                src_idx = _get_src_permutation_idx(indices)
                # Keep only the predictions that were matched to a noun.
                matched_objects = objects_pending[src_idx]
                targets = torch.cat([t['vec'][j] for t, (_, j) in zip(nouns, indices)], dim=0).to(device)
                if torch.isnan(matched_objects).any():
                    raise RuntimeError('NaN detected in the matched object predictions')
                object_loss = cos_loss(matched_objects, targets)
                loss = loss + object_loss * cfgs.train.lambda_entity
                ent_loss = object_loss.detach().item()

            # action module loss
            if cfgs.train.lambda_predicate > 0:
                action_loss = cos_loss(action_pending, vp_semantics)
                loss = loss + action_loss * cfgs.train.lambda_predicate
                pred_loss = action_loss.detach().item()

            # video module loss
            if cfgs.train.lambda_sentence > 0:
                sentence_loss = cos_loss(video_pending, caption_semantics)
                loss = loss + sentence_loss * cfgs.train.lambda_sentence
                sent_loss = sentence_loss.detach().item()

            loss.backward()
            loss_store.append(loss.detach().item())
            loss_xe.append(xe_loss)
            loss_entity.append(ent_loss)
            loss_predicate.append(pred_loss)
            loss_sentence.append(sent_loss)
            loss_soft_target.append(s_loss)
            nn.utils.clip_grad_norm_(model.parameters(), cfgs.train.grad_clip)
            optimizer.step()

            if cnt % cfgs.train.visualize_every == 0:
                # Report the mean over the 10 most recent iterations only.
                loss_store, loss_xe, loss_entity = loss_store[-10:], loss_xe[-10:], loss_entity[-10:]
                loss_predicate, loss_sentence = loss_predicate[-10:], loss_sentence[-10:]
                loss_soft_target = loss_soft_target[-10:]
                loss_value = np.array(loss_store).mean()
                xe_value = _mean_or_zero(loss_xe)
                soft_value = _mean_or_zero(loss_soft_target)
                entity_value = _mean_or_zero(loss_entity)
                predicate_value = _mean_or_zero(loss_predicate)
                sentence_value = _mean_or_zero(loss_sentence)
                print('[EPOCH {};ITER {}]:loss[{:.3f}]=hard_loss[{:.3f}]*1+soft_loss[{:.3f}]*{:.2f}+entity[{:.3f}]*{:.2f}+predicate[{:.3f}]*{:.2f}+sentence[{:.3f}]*{:.2f}'
                      .format(epoch, i, loss_value, xe_value,
                              soft_value, cfgs.train.lambda_soft,
                              entity_value, cfgs.train.lambda_entity,
                              predicate_value, cfgs.train.lambda_predicate,
                              sentence_value, cfgs.train.lambda_sentence))

            if cnt % cfgs.train.save_checkpoints_every == 0:
                ckpt_path = cfgs.train.save_checkpoints_path
                # NOTE(review): if eval_fn switches the model to eval mode it must
                # restore train mode before returning - confirm in eval.py.
                scores = eval_fn(model=model, loader=valid_loader, device=device,
                                 idx2word=idx2word, save_on_disk=False, cfgs=cfgs,
                                 vid2groundtruth=vid2groundtruth)
                cider_score = scores['CIDEr']
                # Only overwrite the checkpoint when validation CIDEr improves.
                if best_score is None or cider_score > best_score:
                    best_score = cider_score
                    torch.save(model.state_dict(), ckpt_path)

                print('=' * 10,
                      '[EPOCH{epoch} iter{it}] :Best Cider is {bs}, Current Cider is {cs}'.
                      format(epoch=epoch, it=i, bs=best_score, cs=cider_score),
                      '=' * 10)
        # The cosine schedule is advanced once per epoch.
        lr_scheduler.step()

    print('===================Training is finished====================')
    return model
首先加载训练集,每个batch会得到14个值(前面已经说过),同时加载idx2word和vid2groundtruth。
观察这行代码
模型输入为objects, object_masks, feature2ds, feature3ds, numberic_caps,
输出为preds, objects_pending, action_pending, video_pending。
可以看到preds有三个维度:第一维是batch大小64,表示64句话;第二维是22,表示每句话22个单词;第三维是12800,表示每个单词的预测向量维度,与词汇表对应。
接下来看第一个loss,也就是常规的交叉熵损失函数:
对应下列过程,target第一个索引也就是1全部填充为0,放入后面。x和sos对应,最后一个z和eos对应。
torch.gather(input,dim,index)函数,生成的tensor的shape为index.shape。
接下来看第二个loss
注意,这里的mask也就是fillmask,只mask了单词,不包含sos或者eos。
接下来看object_loss
首先看到matcher,实例化的匈牙利匹配算法,输入为objects_pending, nouns
class HungarianMatcher(nn.Module):
    """Compute a one-to-one assignment between predicted objects and caption nouns.

    For every batch element the cost of pairing a predicted object vector with
    a noun vector is their negative cosine similarity. The assignment that
    minimises the total cost is found with ``linear_sum_assignment``.
    Predictions (or nouns) left over after the matching stay un-matched.
    """

    def __init__(self):
        super(HungarianMatcher, self).__init__()
        # Added to the norm product so the cosine never divides by zero.
        self.eps = 1e-6

    @torch.no_grad()
    def forward(self, salient_objects: torch.Tensor, nouns_dict_list: list):
        """Perform the matching.

        Args:
            salient_objects: (bsz, max_objects, word_dim) predicted object vectors.
            nouns_dict_list: List[{'vec': nouns_vec, 'nouns': nouns}, ...], one
                dict per batch element. Only the first ``len(nouns)`` rows of
                ``nouns_vec`` are used.

        Returns:
            A list of size bsz, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected nouns
            Both are int64 tensors and, for each batch element,
                len(index_i) = len(index_j) = min(max_objects, number of nouns)
        """
        # Imported here so the class does not depend on a module-level import.
        from scipy.optimize import linear_sum_assignment

        bsz, max_objects = salient_objects.shape[:2]
        device = salient_objects.device

        # Number of real nouns per sample; used to split the joint cost matrix.
        sizes = [len(item['nouns']) for item in nouns_dict_list]
        nouns_semantics = torch.cat(
            [item['vec'][:len(item['nouns'])] for item in nouns_dict_list]
        ).to(device)  # (\sigma nouns, word_dim)
        nouns_length = torch.norm(nouns_semantics, dim=-1, keepdim=True)  # (\sigma nouns, 1)

        salient_objects = salient_objects.flatten(0, 1)  # (bsz * max_objects, word_dim)
        salient_length = torch.norm(salient_objects, dim=-1, keepdim=True)  # (bsz * max_objects, 1)

        # Cosine similarity of every prediction against every noun in the batch.
        matrix_length = salient_length * nouns_length.permute([1, 0]) + self.eps  # (bsz * max_objects, \sigma nouns)
        cos_matrix = torch.mm(salient_objects, nouns_semantics.permute([1, 0]))  # (bsz * max_objects, \sigma nouns)
        # Negate: the solver minimises cost, and we want maximal similarity.
        cos_matrix = -cos_matrix / matrix_length  # (bsz * max_objects, \sigma nouns)
        cos_matrix = cos_matrix.view([bsz, max_objects, -1])  # (bsz, max_objects, \sigma nouns)

        # Chunk i holds the columns of sample i's nouns; within that chunk only
        # row block i (sample i's own predictions) is relevant.
        indices = [
            linear_sum_assignment(c[i].detach().cpu().numpy())
            for i, c in enumerate(cos_matrix.split(sizes, -1))
        ]
        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]
输入的shape
省略…
最终返回索引值
元组第一个为object的索引,第二个为noun的索引
最终利用两个向量间的余弦距离作为loss,目的是最小化余弦距离,也就是最大化两个向量的余弦相似度
同理,action_loss和句子级别的loss(代码中的sent_loss)也是用同样的余弦损失计算的
剩下的代码就是反向传播,保存模型等等了
# Excerpt from the body of the training loop (indentation was lost in the paste).
# Back-propagate the combined loss and record every term for logging.
loss.backward()
loss_store.append(loss.detach().item())
loss_xe.append(xe_loss)
loss_entity.append(ent_loss)
loss_predicate.append(pred_loss)
loss_sentence.append(sent_loss)
loss_soft_target.append(s_loss)
# Clip the gradient norm before the optimizer update.
nn.utils.clip_grad_norm_(model.parameters(), cfgs.train.grad_clip)
optimizer.step()
# Periodic logging: keep only the 10 most recent values and print their means.
if cnt % cfgs.train.visualize_every == 0:
loss_store, loss_xe, loss_entity, loss_predicate, loss_sentence, loss_soft_target = \
loss_store[-10:], loss_xe[-10:], loss_entity[-10:], loss_predicate[-10:], loss_sentence[-10:], loss_soft_target[-10:]
loss_value = np.array(loss_store).mean()
# A term whose first recorded value is None is reported as 0.
xe_value = np.array(loss_xe).mean() if loss_xe[0] is not None else 0
soft_value = np.array(loss_soft_target).mean() if loss_soft_target[0] is not None else 0
entity_value = np.array(loss_entity).mean() if loss_entity[0] is not None else 0
predicate_value = np.array(loss_predicate).mean() if loss_predicate[0] is not None else 0
sentence_value = np.array(loss_sentence).mean() if loss_sentence[0] is not None else 0
print('[EPOCH {};ITER {}]:loss[{:.3f}]=hard_loss[{:.3f}]*1+soft_loss[{:.3f}]*{:.2f}+entity[{:.3f}]*{:.2f}+predicate[{:.3f}]*{:.2f}+sentence[{:.3f}]*{:.2f}'
.format(epoch, i, loss_value, xe_value,
soft_value, cfgs.train.lambda_soft,
entity_value, cfgs.train.lambda_entity,
predicate_value, cfgs.train.lambda_predicate,
sentence_value, cfgs.train.lambda_sentence))
# Periodic evaluation: score the model on the validation loader, then save it.
if cnt % cfgs.train.save_checkpoints_every == 0:
ckpt_path = cfgs.train.save_checkpoints_path
scores = eval_fn(model=model, loader=valid_loader, device=device,
idx2word=idx2word, save_on_disk=False, cfgs=cfgs,
vid2groundtruth=vid2groundtruth)
cider_score = scores['CIDEr']
# NOTE: the best-score guard below is disabled in this variant, so the
# checkpoint is overwritten at every evaluation and best_score is never
# updated in this excerpt (the message below reports its unchanged value).
# if best_score is None or cider_score > best_score:
# best_score = cider_score
# torch.save(model.state_dict(), ckpt_path)
torch.save(model.state_dict(), ckpt_path)
print('=' * 10,
'[EPOCH{epoch} iter{it}] :Best Cider is {bs}, Current Cider is {cs}'.
format(epoch=epoch, it=i, bs=best_score, cs=cider_score),
'=' * 10)