A Walkthrough of HMN's train.py


import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle

from utils.loss import LanguageModelCriterion, CosineCriterion, SoftCriterion
from eval import eval_fn
from configs.settings import TotalConfigs
from models.hungary import HungarianMatcher


def _get_src_permutation_idx(indices):
    # permute predictions following indices
    batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx = torch.cat([src for (src, _) in indices])
    return batch_idx, src_idx


def _get_tgt_permutation_idx(indices):
    # permute targets following indices
    batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
    tgt_idx = torch.cat([tgt for (_, tgt) in indices])
    return batch_idx, tgt_idx
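# Illustrative example (not in the original file): given matcher output
#   indices = [(tensor([0, 2]), tensor([1, 0])),   # batch item 0
#              (tensor([1]),    tensor([0]))]      # batch item 1
# _get_src_permutation_idx returns (tensor([0, 0, 1]), tensor([0, 2, 1])),
# a (batch_idx, src_idx) pair that indexes a (bsz, max_objects, dim) tensor
# to pull out exactly the matched predictions.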


def train_fn(cfgs: TotalConfigs, model_name: str, model: nn.Module, matcher: HungarianMatcher, train_loader, valid_loader, device):
    optimizer = optim.Adam(model.parameters(), lr=cfgs.train.learning_rate)
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, cfgs.train.max_epochs, eta_min=0, last_epoch=-1)
    language_loss = LanguageModelCriterion()
    soft_loss = SoftCriterion()
    cos_loss = CosineCriterion()
    language_loss.to(device)
    best_score, cnt = None, 0  # best CIDEr so far; cnt counts training iterations
    loss_store, loss_entity, loss_predicate, loss_sentence, loss_xe, loss_soft_target = [], [], [], [], [], []

    with open(cfgs.data.idx2word_path, 'rb') as f:
        idx2word = pickle.load(f)
    with open(cfgs.data.vid2groundtruth_path, 'rb') as f:
        vid2groundtruth = pickle.load(f)

    print('===================Training begin====================')
    print(cfgs.train.save_checkpoints_path)
    for epoch in range(cfgs.train.max_epochs):
        print('\n{}[EPOCH {}]{}'.format('='*15, epoch, '='*15))

        for i, (feature2ds, feature3ds, objects, object_masks, vp_semantics, caption_semantics, numberic_caps, masks, captions, nouns, vids, vocab_ids, vocab_probs, fillmasks) in enumerate(train_loader):
            cnt += 1
            feature2ds = feature2ds.to(device)
            feature3ds = feature3ds.to(device)
            objects = objects.to(device)
            object_masks = object_masks.to(device)
            vp_semantics = vp_semantics.to(device)
            caption_semantics = caption_semantics.to(device)
            numberic_caps = numberic_caps.to(device)
            masks = masks.to(device)
            vocab_ids = vocab_ids.to(device) if vocab_ids is not None else None
            vocab_probs = vocab_probs.to(device) if vocab_probs is not None else None
            fillmasks = fillmasks.to(device) if fillmasks is not None else None

            # bsz, sample_numb, obj_numb, obj_dim = objects.shape
            # objects = objects.reshape([bsz, sample_numb * obj_numb, obj_dim])

            optimizer.zero_grad()

            preds, objects_pending, action_pending, video_pending = model(objects, object_masks, feature2ds, feature3ds, numberic_caps)
            xe_loss, s_loss, ent_loss, pred_loss, sent_loss = None, None, None, None, None
            
            # cross entropy loss
            loss_hard = language_loss(preds, numberic_caps, masks, cfgs.dict.eos_idx)
            loss = loss_hard
            xe_loss = loss_hard.detach().item()

            # soft loss
            if cfgs.train.lambda_soft > 0:
                loss_soft = soft_loss(preds, vocab_ids, vocab_probs, fillmasks)
                loss = loss + loss_soft * cfgs.train.lambda_soft
                s_loss = loss_soft.detach().item()

            # object module loss
            if cfgs.train.lambda_entity > 0:
                indices = matcher(objects_pending, nouns)
                src_idx = _get_src_permutation_idx(indices)
                objects = objects_pending[src_idx]  # matched object predictions (note: shadows the input `objects`)
                targets = torch.cat([t['vec'][i] for t, (_, i) in zip(nouns, indices)], dim=0).to(device)  # matched noun vectors
                if np.any(np.isnan(objects.detach().cpu().numpy())):
                    raise RuntimeError
                object_loss = cos_loss(objects, targets)
                loss = loss + object_loss * cfgs.train.lambda_entity
                ent_loss = object_loss.detach().item()

            # action module loss
            if cfgs.train.lambda_predicate > 0:
                action_loss = cos_loss(action_pending, vp_semantics)
                loss = loss + action_loss * cfgs.train.lambda_predicate
                pred_loss = action_loss.detach().item()

            # video module loss
            if cfgs.train.lambda_sentence > 0:
                sent_loss = cos_loss(video_pending, caption_semantics)
                loss = loss + sent_loss * cfgs.train.lambda_sentence
                sent_loss = sent_loss.detach().item()  # keep only the scalar for logging

            loss.backward()
            loss_store.append(loss.detach().item())
            loss_xe.append(xe_loss)
            loss_entity.append(ent_loss)
            loss_predicate.append(pred_loss)
            loss_sentence.append(sent_loss)
            loss_soft_target.append(s_loss)
            nn.utils.clip_grad_norm_(model.parameters(), cfgs.train.grad_clip)
            optimizer.step()

            if cnt % cfgs.train.visualize_every == 0:
                loss_store, loss_xe, loss_entity, loss_predicate, loss_sentence, loss_soft_target = \
                    loss_store[-10:], loss_xe[-10:], loss_entity[-10:], loss_predicate[-10:], loss_sentence[-10:], loss_soft_target[-10:]
                loss_value = np.array(loss_store).mean()
                xe_value = np.array(loss_xe).mean() if loss_xe[0] is not None else 0
                soft_value = np.array(loss_soft_target).mean() if loss_soft_target[0] is not None else 0
                entity_value = np.array(loss_entity).mean() if loss_entity[0] is not None else 0
                predicate_value = np.array(loss_predicate).mean() if loss_predicate[0] is not None else 0
                sentence_value = np.array(loss_sentence).mean() if loss_sentence[0] is not None else 0
                
                print('[EPOCH {};ITER {}]:loss[{:.3f}]=hard_loss[{:.3f}]*1+soft_loss[{:.3f}]*{:.2f}+entity[{:.3f}]*{:.2f}+predicate[{:.3f}]*{:.2f}+sentence[{:.3f}]*{:.2f}'
                .format(epoch, i, loss_value, xe_value, 
                        soft_value, cfgs.train.lambda_soft,
                        entity_value, cfgs.train.lambda_entity, 
                        predicate_value, cfgs.train.lambda_predicate, 
                        sentence_value, cfgs.train.lambda_sentence))

            if cnt % cfgs.train.save_checkpoints_every == 0:
                ckpt_path = cfgs.train.save_checkpoints_path
                scores = eval_fn(model=model, loader=valid_loader, device=device, 
                                idx2word=idx2word, save_on_disk=False, cfgs=cfgs, 
                                vid2groundtruth=vid2groundtruth)
                cider_score = scores['CIDEr']
                if best_score is None or cider_score > best_score:
                    best_score = cider_score
                    torch.save(model.state_dict(), ckpt_path)
                print('=' * 10,
                      '[EPOCH{epoch} iter{it}] :Best Cider is {bs}, Current Cider is {cs}'.
                      format(epoch=epoch, it=i, bs=best_score, cs=cider_score),
                      '=' * 10)

        lr_scheduler.step()
    print('===================Training is finished====================')
    return model

First, the training set is loaded: each batch from train_loader yields the 14 values described earlier. The idx2word and vid2groundtruth dictionaries are also loaded from their pickle files.
Look at the forward call: the model takes objects, object_masks, feature2ds, feature3ds, and numberic_caps as inputs, and outputs preds, objects_pending, action_pending, and video_pending.
preds has three dimensions: the first is the batch size 64 (64 sentences), the second is 22 words per sentence, and the third is 12800 per word, matching the vocabulary size.
Next, the first loss, the standard cross-entropy loss.
It corresponds to the following process: the target is taken from index 1 onward (dropping the first token) and padded with zeros at the end, so each prediction x is aligned with the token after <sos>, and the last z with <eos>.
The torch.gather(input, dim, index) function returns a tensor whose shape is index.shape.
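As a quick self-contained illustration (toy shapes, not HMN's real tensors), this is how gather pulls out the log-probability of each target token; note the output shape follows the index shape:

import torch

logprobs = torch.log_softmax(torch.randn(2, 3, 5), dim=-1)  # (bsz=2, seq=3, vocab=5)
target = torch.tensor([[1, 4, 0],
                       [2, 2, 3]])                          # (bsz, seq)
picked = logprobs.gather(2, target.unsqueeze(2))            # shape (2, 3, 1) == index.shape
print(picked.squeeze(2).shape)                              # torch.Size([2, 3])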
Next, the second loss, the soft-target loss.

Note that the mask here, i.e. fillmasks, only masks ordinary words; it does not include <sos> or <eos>.
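SoftCriterion itself is not shown in this post. As a minimal sketch, assuming vocab_ids and vocab_probs hold each position's top-k teacher token ids and probabilities and fillmasks marks the supervised positions (the argument names mirror the call above, but the body is my assumption, not HMN's actual implementation), such a criterion could look like:

import torch
import torch.nn as nn

class SoftCriterionSketch(nn.Module):
    """Weighted NLL against top-k soft targets (illustrative sketch)."""
    def forward(self, preds, vocab_ids, vocab_probs, fillmasks):
        # preds: (bsz, seq_len, vocab); vocab_ids, vocab_probs: (bsz, seq_len, k); fillmasks: (bsz, seq_len)
        logprobs = torch.log_softmax(preds, dim=-1)
        picked = logprobs.gather(2, vocab_ids)              # log-prob of each teacher candidate
        loss = -(picked * vocab_probs).sum(-1)              # expected NLL per position
        loss = loss * fillmasks                             # only masked-word positions contribute
        return loss.sum() / fillmasks.sum().clamp(min=1)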

Next, object_loss.
First comes matcher, an instance of the Hungarian matching algorithm, whose inputs are objects_pending and nouns:

from scipy.optimize import linear_sum_assignment

import torch
from torch import Tensor, nn


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self):
        super(HungarianMatcher, self).__init__()
        self.eps = 1e-6

    @torch.no_grad()
    def forward(self, salient_objects: Tensor, nouns_dict_list: list):
        """ Performs the matching
        Args:
            salient_objects: (bsz, max_objects, word_dim)
            nouns_dict_list: List[{'vec': nouns_vec, 'nouns': nouns}, ...]
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bsz, max_objects = salient_objects.shape[:2]
        device = salient_objects.device
        sizes = [len(item['nouns']) for item in nouns_dict_list]
        nouns_semantics = torch.cat([item['vec'][:len(item['nouns'])] for item in nouns_dict_list]).to(device)  # (\sigma nouns, word_dim)
        nouns_length = torch.norm(nouns_semantics, dim=-1, keepdim=True)  # (\sigma nouns, 1)
        salient_objects = salient_objects.flatten(0, 1)  # (bsz * max_objects, word_dim)
        salient_length = torch.norm(salient_objects, dim=-1, keepdim=True)  # (bsz * max_objects, 1)
        matrix_length = salient_length * nouns_length.permute([1, 0]) + self.eps  # (bsz * max_objects, \sigma nouns)

        cos_matrix = torch.mm(salient_objects, nouns_semantics.permute([1, 0]))  # (bsz * max_objects, \sigma nouns)
        cos_matrix = -cos_matrix / matrix_length  # negated cosine similarity: minimizing this cost maximizes similarity
        cos_matrix = cos_matrix.view([bsz, max_objects, -1])  # (bsz, max_objects, \sigma nouns)
        indices = [linear_sum_assignment(c[i].detach().cpu().numpy()) for i, c in enumerate(cos_matrix.split(sizes, -1))]

        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
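For reference, linear_sum_assignment from scipy.optimize, used above, solves the minimum-cost assignment problem (the Hungarian algorithm's job). A tiny standalone demo:

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[4, 1, 3],
                 [2, 0, 5],
                 [3, 2, 2]])
row_ind, col_ind = linear_sum_assignment(cost)
print(row_ind, col_ind)               # [0 1 2] [1 0 2]
print(cost[row_ind, col_ind].sum())   # 5 -- the minimum total cost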

The input shapes match the docstring: salient_objects is (bsz, max_objects, word_dim) and nouns_dict_list holds one {'vec', 'nouns'} dict per video.
(omitted…)
It finally returns the matched indices: in each tuple, the first tensor holds the object (prediction) indices and the second the corresponding noun (target) indices.
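A toy run with random data (illustrative shapes only) shows that output format:

import torch

matcher = HungarianMatcher()
salient_objects = torch.randn(2, 4, 8)  # (bsz=2, max_objects=4, word_dim=8)
nouns = [
    {'vec': torch.randn(3, 8), 'nouns': ['dog', 'ball', 'park']},
    {'vec': torch.randn(2, 8), 'nouns': ['man', 'guitar']},
]
indices = matcher(salient_objects, nouns)
print(indices)
# e.g. [(tensor([0, 1, 3]), tensor([2, 0, 1])), (tensor([1, 2]), tensor([0, 1]))]
# one (object indices, noun indices) pair per video; min(max_objects, num_nouns) matches each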
Finally, the cosine distance between each matched object vector and its noun vector serves as the loss; minimizing it maximizes the cosine similarity between the two.
action_loss and the sentence-level caption loss are computed the same way.
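CosineCriterion is likewise not listed here; a minimal sketch with that behavior (an assumption on my part, not necessarily HMN's exact code) is:

import torch
import torch.nn as nn

class CosineCriterionSketch(nn.Module):
    """1 - cosine similarity, averaged over the batch (illustrative sketch)."""
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        # pred, target: (N, dim)
        cos = torch.cosine_similarity(pred, target, dim=-1, eps=self.eps)
        return (1.0 - cos).mean()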
The remaining code handles backpropagation, checkpoint saving, and so on:

            loss.backward()
            loss_store.append(loss.detach().item())
            loss_xe.append(xe_loss)
            loss_entity.append(ent_loss)
            loss_predicate.append(pred_loss)
            loss_sentence.append(sent_loss)
            loss_soft_target.append(s_loss)
            nn.utils.clip_grad_norm_(model.parameters(), cfgs.train.grad_clip)
            optimizer.step()

            if cnt % cfgs.train.visualize_every == 0:
                loss_store, loss_xe, loss_entity, loss_predicate, loss_sentence, loss_soft_target = \
                    loss_store[-10:], loss_xe[-10:], loss_entity[-10:], loss_predicate[-10:], loss_sentence[-10:], loss_soft_target[-10:]
                loss_value = np.array(loss_store).mean()
                xe_value = np.array(loss_xe).mean() if loss_xe[0] is not None else 0
                soft_value = np.array(loss_soft_target).mean() if loss_soft_target[0] is not None else 0
                entity_value = np.array(loss_entity).mean() if loss_entity[0] is not None else 0
                predicate_value = np.array(loss_predicate).mean() if loss_predicate[0] is not None else 0
                sentence_value = np.array(loss_sentence).mean() if loss_sentence[0] is not None else 0
                
                print('[EPOCH {};ITER {}]:loss[{:.3f}]=hard_loss[{:.3f}]*1+soft_loss[{:.3f}]*{:.2f}+entity[{:.3f}]*{:.2f}+predicate[{:.3f}]*{:.2f}+sentence[{:.3f}]*{:.2f}'
                .format(epoch, i, loss_value, xe_value, 
                        soft_value, cfgs.train.lambda_soft,
                        entity_value, cfgs.train.lambda_entity, 
                        predicate_value, cfgs.train.lambda_predicate, 
                        sentence_value, cfgs.train.lambda_sentence))

            if cnt % cfgs.train.save_checkpoints_every == 0:
                ckpt_path = cfgs.train.save_checkpoints_path
                scores = eval_fn(model=model, loader=valid_loader, device=device, 
                                idx2word=idx2word, save_on_disk=False, cfgs=cfgs, 
                                vid2groundtruth=vid2groundtruth)
                cider_score = scores['CIDEr']
                # if best_score is None or cider_score > best_score:
                #     best_score = cider_score
                #     torch.save(model.state_dict(), ckpt_path)
                torch.save(model.state_dict(), ckpt_path)
                print('=' * 10,
                      '[EPOCH{epoch} iter{it}] :Best Cider is {bs}, Current Cider is {cs}'.
                      format(epoch=epoch, it=i, bs=best_score, cs=cider_score),
                      '=' * 10)
