Oscar/run_gqa.py at master · microsoft/Oscar · GitHub
git clone https://github.com/microsoft/Oscar.git
git submodule init
git submodule update
0.无力吐槽
该模型一个特点,就是训练贼慢,而且还占很大的显存。我是指VQA 和GQA
比LXMERT 的轻量版本来说,实在显得太漫长了。一次VQA微调至少个4天。GQA 还没有算过!!!
目前只是运行了,但是内部代码并没有细看,下面就按照笔记形式分析一下。
整体模块还是很好看的,因为是基于Pytorch框架
1. 输入数据的处理
Oscar/task_utils.py at master · microsoft/Oscar · GitHub
作者的注释还是非常清晰易懂,这里就是将数据文件tsv 形式的整理成模型可以输入的格式,整理成InputInstance 规范话, 像我们VQA 只需要输入一个问题即可。所以这里text_b 是不需要的。同样,test 数据的输入也不要包含label 信息。
2. 基类读取tsv
每个处理tsv 文件,比如GQA,VQA 的都要继承此类,对数据进行统一接口,更加清晰。
3. 小类们
class VQATextProcessor(DataProcessor)
目前感觉 line 里面有几个key q 是问题,o 可能是其他文本,q_id是问题的id,img_key 是图片的id, an 是问题的答案, s 是问题分数。具体的以后补充,这里设置输出语句就能看到。
_truncate_seq_pair尽可能保证均匀截断,长句子更容易截断,为了保证两个文本信息都足够保留,而不是只看总长度。
class VQATextAProcessor(DataProcessor)
这个类与上面的类区别只在于text_b 是否有值,其他类不再赘述,都长的差不多。
主函数有个processors就是这里的值,它是一种类的dict。
此函数主要是为了将text_a b 整合好,准备真正的输入到网络的 batch
开始计算总长度不能超过我们设置的长度,且将文本数据转化成切分词,后面有转换成对应的token 数字的方法,input_ids,这里没有截图 ,在计算长度时去除一些特定的token。
input_mask 的计算方式按照input_ids 是否为0 ,如果为0 无意义,不为0 则为有效的单词
总共bert 有效的三种输入,input_mask,input_ids,segment_ids,抓住这三个就行
作者举个例子,关于type_ids 的值 也就是segment_ids
padding 最后在三个input 的后面统一添加
4.主要类:modeling_bert.py
总共三个文件,除了上面的task_utils.py, run_gqa.py 还有一个modeling_bert.py 其他可以略看。