Some notes on parts of the LXMERT code

This write-up took some effort; if you find it useful, please consider following~

Code link: https://github.com/airsplay/lxmert

LXMERT — transformers 3.2.0 documentation [the LXMERT model has already been added to the Hugging Face transformers library!]
The author writes in the README: "The logs and model snapshots will be saved under folder snap/vqa/vqa_lxr955. The validation result after training will be around 69.7% to 70.2%." The results are reproducible.

I. Visual Question Answering

1. Start from the command that trains VQA:

bash run/vqa_finetune.bash 0 vqa_lxr955 

In vqa_data, train.json is combined with the image features to form the dataset, and __getitem__ pulls out each example for batching.

"""

    A VQA data example in json file:
        {
            "answer_type": "other",
            "img_id": "COCO_train2014_000000458752",
            "label": {
                "net": 1
            },
            "question_id": 458752000,
            "question_type": "what is this",
            "sent": "What is this photo taken looking through?"
        }

An example in obj36 tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
              "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
FIELDNAMES would be keys in the dict returned by load_obj_tsv.
"""
class VQATorchDataset(Dataset):
    def __getitem__(self, item: int):
        datum = self.data[item]

        img_id = datum['img_id']
        ques_id = datum['question_id']
        ques = datum['sent']

        # Get image info
        img_info = self.imgid2img[img_id]
        obj_num = img_info['num_boxes']
        feats = img_info['features'].copy()
        boxes = img_info['boxes'].copy()
        assert obj_num == len(boxes) == len(feats)

        # Normalize the boxes (to 0 ~ 1)
        img_h, img_w = img_info['img_h'], img_info['img_w']
        boxes = boxes.copy()
        boxes[:, (0, 2)] /= img_w
        boxes[:, (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1+1e-5)
        np.testing.assert_array_less(-boxes, 0+1e-5)

        # Provide label (target)
        if 'label' in datum:
            label = datum['label']
            target = torch.zeros(self.raw_dataset.num_answers)
            for ans, score in label.items():
                target[self.raw_dataset.ans2label[ans]] = score
            return ques_id, feats, boxes, ques, target
        else:
            return ques_id, feats, boxes, ques

One training example corresponds to: a question id, the feature vectors of one image's object boxes, the box coordinates, the question sentence, and the target label.
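As a minimal sketch (not the repo's exact data-loading helper, and assuming the repo's args are set up as in src/tasks), the torch dataset is simply wrapped in a standard PyTorch DataLoader; the default collate function stacks the numpy feats/boxes into float tensors and keeps the question strings as a tuple of Python strings:

from torch.utils.data import DataLoader

dset = VQADataset('train')        # raw json entries + answer vocabulary
tset = VQATorchDataset(dset)      # joins in the image features
loader = DataLoader(tset, batch_size=32, shuffle=True, num_workers=4)

ques_id, feats, boxes, sent, target = next(iter(loader))
# feats:  (32, 36, 2048)  Faster R-CNN object features
# boxes:  (32, 36, 4)     normalized box coordinates
# sent:   tuple of 32 question strings
# target: (32, num_answers) soft answer scores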

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

This shows that the whole model is assembled in VQAModel.

2. Basic components of the model

Below you can see that self.model(feats, boxes, sent) is called to obtain the output of the whole LXMERT model; the loss against the ground truth is then computed and back-propagated.

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
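One detail worth noting is the line loss = loss * logit.size(1). In the repo self.bce_loss is nn.BCEWithLogitsLoss(), whose default mean reduction averages over both the batch and the answer dimension; multiplying by the number of answers turns this into a per-example sum over the answer vocabulary, averaged over the batch. A small sketch that verifies the equivalence (num_answers here is a stand-in value):

import torch
import torch.nn as nn

batch_size, num_answers = 4, 10
logit = torch.randn(batch_size, num_answers)
target = torch.rand(batch_size, num_answers)      # soft scores in [0, 1]

mean_loss = nn.BCEWithLogitsLoss()(logit, target)                                    # mean over batch * num_answers
per_example_sum = nn.BCEWithLogitsLoss(reduction='sum')(logit, target) / batch_size  # per-example sum, batch mean

print(torch.allclose(mean_loss * num_answers, per_example_sum))                      # True (up to float error)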

The VQAModel class is shown below. From the code, the model has two components: an LXRTEncoder and the logit_fc answer head. Let's look more closely at LXRTEncoder. Its forward returns a single value, which by default is the model's cross-modality (pooled) vector; the author also wrote code paths that return other features, expressed as a tuple (feat_seq, pooled_output).

# Max length including <bos> and <eos>
MAX_VQA_LENGTH = 20


class VQAModel(nn.Module):
    def __init__(self, num_answers):
        super().__init__()
        
        # Build LXRT encoder
        self.lxrt_encoder = LXRTEncoder(
            args,
            max_seq_length=MAX_VQA_LENGTH
        )
        hid_dim = self.lxrt_encoder.dim
        
        # VQA Answer heads
        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, num_answers)
        )
        self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)

    def forward(self, feat, pos, sent):
        """
        b -- batch_size, o -- object_number, f -- visual_feature_size
        :param feat: (b, o, f)
        :param pos:  (b, o, 4)
        :param sent: (b,) Type -- list of string
        :param leng: (b,) Type -- int numpy array
        :return: (b, num_answer) The logit of each answers.
        """
        x = self.lxrt_encoder(sent, (feat, pos))
        logit = self.logit_fc(x)

        return logit
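The answer head logit_fc is just a two-layer MLP on top of the 768-d cross-modality vector. A standalone sketch using plain PyTorch modules (nn.GELU and nn.LayerNorm in place of the repo's GeLU/BertLayerNorm wrappers, which compute essentially the same functions; num_answers is a stand-in):

import torch
import torch.nn as nn

hid_dim = 768            # LXRTEncoder.dim
num_answers = 3129       # stand-in; the real value comes from trainval_ans2label.json

logit_fc = nn.Sequential(
    nn.Linear(hid_dim, hid_dim * 2),
    nn.GELU(),
    nn.LayerNorm(hid_dim * 2, eps=1e-12),
    nn.Linear(hid_dim * 2, num_answers),
)

x = torch.randn(8, hid_dim)     # pooled cross-modality vectors for a batch of 8
print(logit_fc(x).shape)        # torch.Size([8, 3129])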

3. Details of the LXRTEncoder wrapper

In its forward you can see the convert_sents_to_features function, which tokenizes the question sentences.

import os

import torch
import torch.nn as nn

from lxrt.tokenization import BertTokenizer
from lxrt.modeling import LXRTFeatureExtraction as VisualBertForLXRFeature, VISUAL_CONFIG
class LXRTEncoder(nn.Module):
    def __init__(self, args, max_seq_length, mode='x'):
        super().__init__()
        self.max_seq_length = max_seq_length
        set_visual_config(args)

        # Using the bert tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        # Build LXRT Model
        self.model = VisualBertForLXRFeature.from_pretrained(
            "bert-base-uncased",
            mode=mode
        )

        if args.from_scratch:
            print("initializing all the weights")
            self.model.apply(self.model.init_bert_weights)

    def multi_gpu(self):
        self.model = nn.DataParallel(self.model)

    @property
    def dim(self):
        return 768

    def forward(self, sents, feats, visual_attention_mask=None):
        train_features = convert_sents_to_features(
            sents, self.max_seq_length, self.tokenizer)

        input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long).cuda()
        input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long).cuda()
        segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long).cuda()

        output = self.model(input_ids, segment_ids, input_mask,
                            visual_feats=feats,
                            visual_attention_mask=visual_attention_mask)
        return output

The convert_sents_to_features function tokenizes each question sentence into tokens_a and truncates it so that, together with [CLS] and [SEP], it fits within max_seq_length (20, i.e. at most 18 word-piece tokens). Each token's id is appended to input_ids, and input_ids, input_mask, and segment_ids are zero-padded up to max_seq_length.

def convert_sents_to_features(sents, max_seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""

    features = []
    for (i, sent) in enumerate(sents):
        tokens_a = tokenizer.tokenize(sent.strip())

        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[:(max_seq_length - 2)]
        
        # Keep segment id which allows loading BERT-weights.
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        segment_ids = [0] * len(tokens)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids))
    return features
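As a concrete example of what convert_sents_to_features produces for the sample question above, here is a small sketch that uses the Hugging Face transformers BertTokenizer as a stand-in for lxrt.tokenization.BertTokenizer (both do WordPiece tokenization with the bert-base-uncased vocabulary):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
max_seq_length = 20   # MAX_VQA_LENGTH

sent = "What is this photo taken looking through?"
tokens_a = tokenizer.tokenize(sent.strip())[:max_seq_length - 2]
tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
# ['[CLS]', 'what', 'is', 'this', 'photo', 'taken', 'looking', 'through', '?', '[SEP]']

input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_mask = [1] * len(input_ids)
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding          # 10 real ids followed by 10 zeros
input_mask += padding         # 10 ones followed by 10 zeros
segment_ids = [0] * max_seq_length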

Let's continue from the model itself: VisualBertForLXRFeature, whose original name is LXRTFeatureExtraction. As seen above, mode is passed in as "x", so by default forward returns pooled_output.

self.model = VisualBertForLXRFeature.from_pretrained(
            "bert-base-uncased",
            mode=mode
        )

class LXRTFeatureExtraction(BertPreTrainedModel):
    """
    BERT model for classification.
    """
    def __init__(self, config, mode='lxr'):
        """
        :param config:
        :param mode:  Number of visual layers
        """
        super().__init__(config)
        self.bert = LXRTModel(config)
        self.mode = mode
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, visual_feats=None,
                visual_attention_mask=None):
        feat_seq, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                            visual_feats=visual_feats,
                                            visual_attention_mask=visual_attention_mask)
        if 'x' == self.mode:
            return pooled_output
        elif 'x' in self.mode and ('l' in self.mode or 'r' in self.mode):
            return feat_seq, pooled_output
        elif 'l' in self.mode or 'r' in self.mode:
            return feat_seq

4. Details of LXRTModel

Drilling down to the core model, self.bert = LXRTModel(config): from its forward we can see that the cross-modality feature, pooled_output, is obtained by passing lang_feats through the pooler.

class LXRTModel(BertPreTrainedModel):
    """LXRT Model."""

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = LXRTEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                visual_feats=None, visual_attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Process the visual attention mask
        if visual_attention_mask is not None:
            extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
            extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
            extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
        else:
            extended_visual_attention_mask = None

        # Positional Word Embeddings
        embedding_output = self.embeddings(input_ids, token_type_ids)

        # Run LXRT backbone
        lang_feats, visn_feats = self.encoder(
            embedding_output,
            extended_attention_mask,
            visn_feats=visual_feats,
            visn_attention_mask=extended_visual_attention_mask)
        pooled_output = self.pooler(lang_feats)

        return (lang_feats, visn_feats), pooled_output
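The extended attention mask trick is easy to see in isolation: a 2D {0, 1} mask of shape (batch, seq_len) is broadcast to (batch, 1, 1, seq_len) and converted to additive form, 0 for real tokens and -10000 for padding, so adding it to the attention scores before the softmax effectively removes the padded positions. A minimal sketch:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],      # batch of 2, seq_len 5
                               [1, 1, 1, 1, 1]])

extended = attention_mask.unsqueeze(1).unsqueeze(2)  # (2, 1, 1, 5)
extended = extended.to(dtype=torch.float32)
extended = (1.0 - extended) * -10000.0
# extended[0, 0, 0] == tensor([-0., -0., -0., -10000., -10000.])

scores = torch.zeros(2, 1, 5, 5)                     # (batch, heads, query, key) raw attention scores
probs = torch.softmax(scores + extended, dim=-1)
print(probs[0, 0, 0])                                # ≈ [0.333, 0.333, 0.333, 0.0, 0.0]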

The BertPooler class:

class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
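So the pooler just takes the hidden state at position 0 of the language stream (the [CLS] token) and passes it through Linear + Tanh; this is the pooled_output that mode "x" eventually returns. An equivalent minimal sketch:

import torch
import torch.nn as nn

hidden_size = 768
pooler = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh())

lang_feats = torch.randn(8, 20, hidden_size)   # (batch, seq_len, hidden)
pooled_output = pooler(lang_feats[:, 0])       # take the [CLS] position
print(pooled_output.shape)                     # torch.Size([8, 768])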

5. Details of the LXRTEncoder backbone (lxrt.modeling)

This is where the layers are actually stacked: the two modalities are first processed by their own stacks (language layers, and relational layers for the visual stream), then several cross-modality layers are stacked on top, and finally lang_feats and visn_feats are returned.

class LXRTEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Obj-level image embedding layer
        self.visn_fc = VisualFeatEncoder(config)

        # Number of layers
        self.num_l_layers = VISUAL_CONFIG.l_layers
        self.num_x_layers = VISUAL_CONFIG.x_layers
        self.num_r_layers = VISUAL_CONFIG.r_layers
        print("LXRT encoder with %d l_layers, %d x_layers, and %d r_layers." %
              (self.num_l_layers, self.num_x_layers, self.num_r_layers))

        # Layers
        # Using self.layer instead of self.l_layer to support loading BERT weights.
        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(self.num_l_layers)]
        )
        self.x_layers = nn.ModuleList(
            [LXRTXLayer(config) for _ in range(self.num_x_layers)]
        )
        self.r_layers = nn.ModuleList(
            [BertLayer(config) for _ in range(self.num_r_layers)]
        )

    def forward(self, lang_feats, lang_attention_mask,
                visn_feats, visn_attention_mask=None):
        # Run visual embedding layer
        # Note: Word embedding layer was executed outside this module.
        #       Keep this design to allow loading BERT weights.
        visn_feats = self.visn_fc(visn_feats)

        # Run language layers
        for layer_module in self.layer:
            lang_feats = layer_module(lang_feats, lang_attention_mask)

        # Run relational layers
        for layer_module in self.r_layers:
            visn_feats = layer_module(visn_feats, visn_attention_mask)

        # Run cross-modality layers
        for layer_module in self.x_layers:
            lang_feats, visn_feats = layer_module(lang_feats, lang_attention_mask,
                                                  visn_feats, visn_attention_mask)

        return lang_feats, visn_feats
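With the vqa_lxr955 setting the encoder is built with 9 language layers, 5 cross-modality layers and 5 object-relationship layers (the lxr955 in the name). The stacking order, single-modality stacks first and cross-modality layers last, can be illustrated with a toy sketch using dummy layers (an illustration only, not the real BertLayer/LXRTXLayer):

import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    """Stand-in for BertLayer: processes one modality on its own."""
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Linear(dim, dim)
    def forward(self, x, mask=None):
        return torch.relu(self.ff(x))

class ToyXLayer(nn.Module):
    """Stand-in for LXRTXLayer: mixes information between the two streams."""
    def __init__(self, dim):
        super().__init__()
        self.l2v = nn.Linear(dim, dim)
        self.v2l = nn.Linear(dim, dim)
    def forward(self, lang, lang_mask, visn, visn_mask):
        new_lang = lang + self.v2l(visn.mean(1, keepdim=True))
        new_visn = visn + self.l2v(lang.mean(1, keepdim=True))
        return new_lang, new_visn

dim, n_l, n_x, n_r = 768, 9, 5, 5                  # the "lxr955" configuration
l_layers = nn.ModuleList([ToyLayer(dim) for _ in range(n_l)])
r_layers = nn.ModuleList([ToyLayer(dim) for _ in range(n_r)])
x_layers = nn.ModuleList([ToyXLayer(dim) for _ in range(n_x)])

lang = torch.randn(2, 20, dim)                     # (batch, num_tokens, dim)
visn = torch.randn(2, 36, dim)                     # (batch, num_objects, dim)

for layer in l_layers:                             # language-only stack
    lang = layer(lang)
for layer in r_layers:                             # vision-only ("relational") stack
    visn = layer(visn)
for layer in x_layers:                             # cross-modality stack
    lang, visn = layer(lang, None, visn, None)

print(lang.shape, visn.shape)                      # torch.Size([2, 20, 768]) torch.Size([2, 36, 768])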

6. A concrete look at the __init__ of the LXMERT VQA dataset

The key point here is how each question is matched with its image, via the img_id field of the question entry.

class VQADataset:
    """
    A VQA data example in json file:
        {
            "answer_type": "other",
            "img_id": "COCO_train2014_000000458752",
            "label": {
                "net": 1
            },
            "question_id": 458752000,
            "question_type": "what is this",
            "sent": "What is this photo taken looking through?"
        }
    """
    def __init__(self, splits: str):
        self.name = splits
        self.splits = splits.split(',')

        # Loading datasets
        self.data = []
        for split in self.splits:
            self.data.extend(json.load(open("data/vqa/%s.json" % split)))
        print("Load %d data from split(s) %s." % (len(self.data), self.name))

        # Convert list to dict (for evaluation)
        self.id2datum = {
            datum['question_id']: datum
            for datum in self.data
        }

        # Answers
        self.ans2label = json.load(open("data/vqa/trainval_ans2label.json"))
        self.label2ans = json.load(open("data/vqa/trainval_label2ans.json"))
        assert len(self.ans2label) == len(self.label2ans)

    @property
    def num_answers(self):
        return len(self.ans2label)

    def __len__(self):
        return len(self.data)


"""
An example in obj36 tsv:
FIELDNAMES = ["img_id", "img_h", "img_w", "objects_id", "objects_conf",
              "attrs_id", "attrs_conf", "num_boxes", "boxes", "features"]
FIELDNAMES would be keys in the dict returned by load_obj_tsv.
"""
class VQATorchDataset(Dataset):
    def __init__(self, dataset: VQADataset):
        super().__init__()
        self.raw_dataset = dataset

        if args.tiny:
            topk = TINY_IMG_NUM
        elif args.fast:
            topk = FAST_IMG_NUM
        else:
            topk = None

        # Loading detection features to img_data
        img_data = []
        for split in dataset.splits:
            # Minival is 5K images in MS COCO, which is used in evaluating VQA/LXMERT-pre-training.
            # It is saved as the top 5K features in val2014_***.tsv
            load_topk = 5000 if (split == 'minival' and topk is None) else topk
            img_data.extend(load_obj_tsv(
                os.path.join(MSCOCO_IMGFEAT_ROOT, '%s_obj36.tsv' % (SPLIT2NAME[split])),
                topk=load_topk))

        # Convert img list to dict
        self.imgid2img = {}
        for img_datum in img_data:
            self.imgid2img[img_datum['img_id']] = img_datum

        # Only kept the data with loaded image features
        self.data = []
        for datum in self.raw_dataset.data:
            if datum['img_id'] in self.imgid2img:
                self.data.append(datum)
        print("Use %d data in torch dataset" % (len(self.data)))
        print()
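The id-based join here is a common pattern: build a dict from img_id to the loaded features, then keep only the questions whose image is present. A self-contained toy illustration with made-up data:

# Toy illustration of the img_id join in VQATorchDataset.__init__ (made-up data)
questions = [
    {"question_id": 1, "img_id": "COCO_train2014_000000458752", "sent": "What is this?"},
    {"question_id": 2, "img_id": "COCO_train2014_000000999999", "sent": "How many dogs?"},
]
img_data = [
    {"img_id": "COCO_train2014_000000458752", "num_boxes": 36},   # features omitted
]

imgid2img = {img["img_id"]: img for img in img_data}
kept = [q for q in questions if q["img_id"] in imgid2img]
print(len(kept))   # 1 -- only questions whose image features were loaded are kept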

Full code: see the author's GitHub:

From the last part we can see that the author treats the questions as the primary unit: questions are assigned following the image assignment, i.e. according to the split their image belongs to, and the network is then trained on all of them together.

https://github.com/airsplay/lxmert/blob/master/src/pretrain/lxmert_pretrain.py

One item from the obj36 tsv, printed out:

Start to load Faster-RCNN detected objects from data/vg_gqa_imgfeat/vg_gqa_obj36.tsv
tesv: 2346776
OrderedDict([('img_id', '2346776'), ('img_h', 375), ('img_w', 500), ('objects_id', array([  72,  164,   50,  222,   98,   72,  100,  222,  781,   72,  222,
        222,  164,  164,  743,  177,  222,  100,   50,  100,  781,  960,
        743,  743,  181,  937, 1083,  781,   72,  743,   50,  743,  743,
        743,  743,  177])), ('objects_conf', array([0.84230465, 0.7640681 , 0.62371355, 0.5948998 , 0.5725832 ,
       0.4917991 , 0.48963746, 0.45057136, 0.44322845, 0.44004816,
       0.36180338, 0.36032832, 0.35463902, 0.3034201 , 0.2875118 ,
       0.28693637, 0.3090751 , 0.39253405, 0.45693383, 0.29296422,
       0.39178517, 0.2381179 , 0.23111628, 0.2177636 , 0.19504195,
       0.18673702, 0.17591234, 0.1691048 , 0.20077291, 0.1562024 ,
       0.5223265 , 0.21987833, 0.14224717, 0.14019318, 0.24196841,
       0.14741221], dtype=float32)), ('attrs_id', array([  6,   4, 209,  59,  11,   6,   8,   8,   7,   6,   8,   8,   7,
         6,   7, 310,   8,   8, 209,   8,   7,   7,   9,   7,  11,   7,
         9,   7,   7,   7, 209,   7,   7,  14,   7, 310])), ('attrs_conf', array([0.22624704, 0.32726702, 0.13061899, 0.15914878, 0.32727158,
       0.24480195, 0.1902841 , 0.15394422, 0.44051063, 0.1609596 ,
       0.15040548, 0.17329626, 0.12049368, 0.11614355, 0.29512778,
       0.12013788, 0.1064709 , 0.16638228, 0.13340168, 0.16281058,
       0.4279993 , 0.23218702, 0.16742417, 0.30998826, 0.08785141,
       0.25206673, 0.04508799, 0.32184234, 0.2762154 , 0.31993526,
       0.14825974, 0.3712262 , 0.37144113, 0.10767756, 0.13063323,
       0.09392384], dtype=float32)), ('num_boxes', 36), ('boxes', array([[  0.        ,   0.        , 318.25012   , 246.10194   ],
       [164.52708   ,  28.153622  , 255.60443   , 104.17139   ],
       [390.57724   , 207.20079   , 443.83505   , 295.64233   ],
       [ 63.46903   , 313.6256    , 480.49667   , 374.375     ],
       [405.05627   , 248.53212   , 437.54138   , 285.7874    ],
       [115.59644   ,   0.        , 473.01917   , 139.74986   ],
       [ 13.417034  , 264.20096   , 387.16727   , 344.8604    ],
       [100.42142   , 298.48996   , 266.08328   , 355.73306   ],
       [ 22.863993  ,  44.98969   , 165.55823   , 103.78797   ],
       [152.09888   ,  33.8815    , 499.375     , 320.8988    ],
       [  0.        , 290.75757   , 171.65611   , 374.375     ],
       [145.82584   , 263.8986    , 499.375     , 351.36533   ],
       [159.39641   ,   0.        , 403.09973   , 346.4812    ],
       [108.78353   ,  22.244396  , 361.98486   , 141.45567   ],
       [  4.474478  , 119.675385  , 275.90958   , 240.17114   ],
       [ 61.62403   , 264.07626   ,  78.333786  , 280.04562   ],
       [145.31396   , 203.91693   , 499.375     , 374.375     ],
       [  0.        , 218.89786   , 313.74792   , 374.375     ],
       [406.10648   , 223.15338   , 444.9298    , 272.0626    ],
       [117.48062   , 256.92874   , 499.375     , 327.50766   ],
       [  0.75047016,  31.080353  , 137.95702   ,  86.47684   ],
       [187.85034   ,  83.33661   , 467.14774   , 237.32494   ],
       [  5.7551193 , 185.9704    , 360.8667    , 275.5915    ],
       [227.82983   ,  55.33244   , 475.3305    , 184.35051   ],
       [413.00217   , 222.2033    , 442.94592   , 258.3501    ],
       [233.83167   ,  67.78333   , 404.4716    , 212.4943    ],
       [ 12.081871  , 221.60048   , 444.34238   , 289.43762   ],
       [  0.        ,  29.039692  , 283.83832   , 134.72557   ],
       [  1.9217873 ,  34.983826  , 236.93727   , 152.94627   ],
       [  0.        ,   0.        , 434.55768   ,  69.68215   ],
       [358.88797   , 230.3418    , 459.61545   , 292.2316    ],
       [  0.        ,  17.482538  , 154.04506   , 144.4485    ],
       [  6.5841866 ,   0.        , 444.4574    ,  46.168583  ],
       [ 74.011     , 192.00005   , 441.32007   , 278.66513   ],
       [ 27.20026   , 163.97757   , 365.32355   , 268.5604    ],
       [ 51.17173   , 265.21695   ,  83.57674   , 285.64877   ]],
      dtype=float32)), ('features', array([[3.2041940e-01, 0.0000000e+00, 1.9866660e-02, ..., 0.0000000e+00,
        3.5923147e-01, 8.0630213e-01],
       [1.7286272e+00, 1.6732953e-02, 0.0000000e+00, ..., 0.0000000e+00,
        8.6231306e-03, 3.6284101e+00],
       [2.0331763e-02, 0.0000000e+00, 0.0000000e+00, ..., 2.8744887e-03,
        1.7182362e-01, 1.5960330e-02],
       ...,
       [2.9783529e-01, 5.0786036e-01, 0.0000000e+00, ..., 4.3921422e-02,
        4.1568637e-01, 4.5798761e-01],
       [2.1963201e-01, 2.4200399e-01, 1.4872210e-04, ..., 0.0000000e+00,
        1.3755366e-01, 6.0404956e-02],
       [2.3011213e-02, 4.9133492e-01, 8.4014311e-02, ..., 0.0000000e+00,
        0.0000000e+00, 4.8337227e-01]], dtype=float32))])
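The boxes and features columns in the TSV are stored as base64-encoded binary buffers. Below is a rough sketch of how one row could be decoded, written here as an illustration of what the repo's load_obj_tsv does (treat the exact field handling as an assumption). Note that np.frombuffer on a bytes object yields a read-only array, which is presumably why __getitem__ calls .copy() on features and boxes before normalizing them.

import base64
import numpy as np

def decode_row(item):
    """Sketch: decode the base64-encoded fields of one TSV row (a dict with the FIELDNAMES keys)."""
    num_boxes = int(item['num_boxes'])
    item['boxes'] = np.frombuffer(
        base64.b64decode(item['boxes']), dtype=np.float32).reshape(num_boxes, 4)
    item['features'] = np.frombuffer(
        base64.b64decode(item['features']), dtype=np.float32).reshape(num_boxes, -1)
    return item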
 

Usage of np.testing.assert_array_less

np.testing.assert_array_less(x, y) checks element-wise that x < y; if any element of x is not smaller than the corresponding element of y, it raises an AssertionError.

In LXMERT it is used right after the boxes are normalized, to check that every normalized coordinate is below 1 + 1e-5 (and, with the second call on -boxes, above roughly -1e-5).

When the assertion fails, the "Mismatched elements" line of the error message tells you how many elements violate the condition.
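A small example (the boxes here are made up):

import numpy as np

boxes = np.array([[0.0, 0.1, 0.5, 0.9]])
np.testing.assert_array_less(boxes, 1 + 1e-5)    # passes: every element < 1.00001
np.testing.assert_array_less(-boxes, 0 + 1e-5)   # passes: every element > -0.00001

bad = np.array([[0.0, 0.1, 1.2, 0.9]])
np.testing.assert_array_less(bad, 1 + 1e-5)
# AssertionError: ...
# Mismatched elements: 1 / 4 (25%)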
