HuggingFace TAPAS Model细节记录:
目前google的TAPAS model在huggingface上都能找到各种预训练model和训练好的model,所以这里就做一下model的尝试,以及finetune过程。
首先看一下最为常用的训练好的model:
hugging face TableQA model
这里面有大量的model,举例一下常用的:
google/tapas-base-finetuned-wtq
(download 19.2k)model
那这个和基础的TAPAS原论文model有啥区别呢?
-
version
首先是这个model有两个版本,区别在于position-embedding,一个用的是相对位置索引(在表格的每个cell开始时重置一下position index,默认是使用这个,效果会好一点),另一个是绝对位置索引。 -
intermediate pre-training (不感兴趣的可以忽略这个)
这个是训练模型的一部分,在基础的TAPAS的pre-training和fintune之间增加的一步,这个很多huggingface上很多基于tapas的model都用了。下面用的都是wikipedia的table和相关text,wikipedia的table保留至少两行两列(header加上就是三行),递归地将表格逐行拆分为上半部分和下半部分,直到它们最多有 50 个单元格。这样我们就获得了 370 万张表。具体来说,又分为两步:
-
Synthetic Statements
这部分是在wikipedia的table上合成句子和SQL,用于提高模型的数值运算处理和比较的能力。实际就是按照一定的模版生成与table相关的句子以及SQL,而且相关的句子以一定的概率进行扰乱。句子和SQL的语法规则如下:
<>尖括号包裹的是SQL语句,其余的是生成句子的短语,按照这种语法规则,生成的句子以及对应的SQL,并且在正常的句子的前面或者后面再加一层数字对比短语,这样的句子可能不会通顺,但是没关系,比如:
example table:Rank Player Country Earnings Events Wins 1 Greg Norman Australia 1,654,959 16 3 2 Billy Mayfair United States 1,543,192 28 2 3 Lee Janzen United States 1,378,966 28 3 4 Corey Pavin United States 1,340,079 22 2 5 Steve Elkington Australia 1,254,352 21 2 Australia 1,254,352 21 2 合成句子:
2 is less than
wins when Player is Lee Janzen.
SQL:SELECT wins FROM table WHERE player = “Lee Janzen”
结果:通过SQL产生的结果为3,2 is less than 3这是正确的,positive
负例句子:3 is less than
wins when Player is Lee Janzen. 这样的就是negative
像这样的生成了370万对,模型是将生成的句子和表格作为输入,输出是判断正错。 -
Counterfactual Statements
这里是对wikipedia表附近的文本进行entity replace,比如说,原句是:Greg Norman
has the highest earnings,我们替换成:Steve Elkington
has the highest earnings.
这样生成了410万句子,模型任务是判断句子是否被替换。
可能有些理解不是很正确,感兴趣的可以参考:
github描述
论文3.1 和3.2章节
-
fintune
依次通过 SQA, WikiSQL and finally WTQ.
code test
环境
- torch.version = ‘1.6.0+cu101’
- torch-scatter (这个包一定要安装):
pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
这两个对应上就好,哪个版本无所谓,官网给的是1.8.0+cu101 - torch-sparse:
pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
tqa = pipeline(task="table-question-answering", model="google/tapas-base-finetuned-wtq")
table = pd.DataFrame({
"fund":["CSI 300", "Bank of communications","China Shipping"],
"type":["income","hybrid","income"],
"annual increase":[0.15,0.18,0.16]})
table = table.astype(str)
table
query = ["Which funds its type is income?",
"Which fund has the highest annual return?",
"Is there a fund of income type?",
"What are the funds with an annual increase of more than 0.2?"