1 README.md:有两个参数比较重要,--model_size(可以选择large或base)和--use_gpt(决定是否使用GPT,即是否使用隐式知识)
1 train_KAT.py(训练主代码):
1 evaluate_KAT.py(测试主代码):
2 function1:evaluate(测试函数)
1 wikidata_ontology.pkl(这是一个dict数据,里面包括187308个键值对,其中键表示知识ID,值是一个实体+一个caption描述)
1 build_wikidata_entity_db.py(这个函数非常重要,用于显式的知识提取)
1 okvqa(数据集存放位置,这个数据集中共包括14031个图像,14055个问题):
2 train2014(训练数据集的所有构成):
3 train2014.pkl:这个是一个元组,包括两个元素:第一个元素是dict类型,包括9009个键值对,其中键为图像名字,值为question;第二个元素是dict类型,包括9009个键值对,其中键为图像名字,值为answer(包括10个元素)
3 wikidata_okvqa_train2014_topentities.pkl:这是一个字典,包括8998个键值对,其中键为图像名字,值为元组,长度为2:第一个值是列表,长度为30,每一个是一个元组,表示实体+对应知识描述;第二个值是元组,长度30,这里应该表示每条知识概率???
3 gpt3_okvqa_train2014_answers.pkl:这是一个字典,包括9009个键值对,其中键表示图像名字,值为列表,列表里面是元组,表示实体以及对应的知识描述
2 val2014(验证数据集的所有构成):
3 val2014.pkl:这个是一个元组,包括两个元素:第一个元素是dict类型,包括5046个键值对,其中键为图像名字,值为question;第二个元素是dict类型,包括5046个键值对,其中键为图像名字,值为answer(包括10个元素)
3 wikidata_okvqa_val2014_topentities.pkl:这是一个字典,包括5033个键值对,其中键为图像名字,值为元组,长度为2:第一个值是列表,长度为30,每一个是一个元组,表示实体+对应知识描述;第二个值是元组,长度30,这里应该表示每条知识概率???
3 gpt3_okvqa_val2014_answers.pkl:这是一个字典,包括5046个键值对,其中键表示图像名字,值为列表,列表里面是元组,表示实体以及对应的知识描述
1 src:
2 options.py(参数设置):
2 evaluation.py:
3 function1:okvqa_ems:
2 data.py(数据集读取和处理):
3 function1:load_okvqa_data(下载OKVQA数据集,最终得到examples,这是一个列表,每个元素表示一个图像文本对,是用dict呈现,包括id(图像id#问题id),question(问题),answers(答案),entities(显式知识,维基百科知识),gpt3(隐式知识))
3 function2:encode_passages(问题+知识编码)
3 class1:OkvqaDataset(对数据集做了处理,增加了前缀,最终返回id(图像id#问题id),index(问题索引),question(问题,其中前缀是question),target(答案),passages(显式知识+隐式知识,前缀为context))
3 class2:OKvqaCollator(对数据进行了编码,最终返回img_ids(图像id), index(顺序索引), target_ids(答案id), target_mask(答案掩码), passage_ids(段落id), passage_masks(段落掩码))