一、阅读理解基本介绍
四种广义的机器阅读理解任务:
- 完形填空
-
- 定义:给定文章C,将其中的一个词或者实体a(a属于C)隐去作为待填空的问题,完形填空任务要求通过最大化条件概率P(a|C-a)来利用正确的词或实体a进行填空。
- 数据集:CNN & Daily Mail、CBT、LAMBADA、Who-did-What、CLOTH、CliCR
- 多项选择
-
- 定义:给定文章C、问题Q和一系列候选答案集合,多项选择任务通过最大化条件概率来从候选答案集合A中挑选出正确答案回答问题Q。
- 数据集:MCTest、RACE
- 片段抽取
-
- 定义:给定文章C(其中包含 n 个词)和问题Q,片段抽取任务通过最大化条件概率P(a|C,Q)来从文章中抽取连续的子序列作为问题的正确答案。
- 数据集:SQuAD、NewsQA、TriviaQA、DuoRC
- 自由作答
-
- 定义:给定文章C和问题Q,自由作答的正确答案a有时可能不是文章C的子序列。自由作答任务通过最大化条件概率P(a|C,Q)来预测回答问题Q的正确答案a。
- 数据集:bAbI、MS MARCO、SearchQA、NarrativeQA、DuReader
本次比赛的场景是单项选择,即给定文章C、问题Q和一系列候选答案集合,单项选择任务通过最大化条件概率来从候选答案集合A中挑选出正确答案回答问题Q。
二、baseline介绍
本次baseline是https://mp.weixin.qq.com/s/CTMn4Mb25d0A1V8eZREiYw,不是本人,我也只是学习者。
比赛地址:https://www.biendata.xyz/competition/haihua_2021/
三、跑baseline中遇到的问题
1、bert问题
下载地址:https://github.com/ymcui/Chinese-BERT-wwm
但是地址中给的pytorch版本的放在了google云盘,我用的这个梯子不稳定,总是下一部分就断了,又得重新下,于是就想着把tensorflow版本的转化为pytorch版本的。
- 首先,下载好tensorfflow版本的bert
- 再者,导包
-
pip install transformers
运行该代码
-
import transformers.convert_bert_original_tf_checkpoint_to_pytorch as con con.convert_tf_checkpoint_to_pytorch( r'chinese_wwm_ext_L-12_H-768_A-12/bert_model.ckpt', r'chinese_wwm_ext_L-12_H-768_A-12/bert_config.json', r'chinese_wwm_ext_L-12_H-768_A-12/pytorch_bert.bin' )
-
-
得到pytorch版本bert
-
**************问题1解决********************
之后,就跑通了。
四、理解baseline代码中遇到的问题
1、读原始数据
原作者写该baseline的时候,是将json文件转化为csv文件,我就当练代码了,就直接用json文件读入,把代码整体过一遍。
这里主要的问题是读入的数据如何拼接,怎样的数据才能算是能喂入模型的一条数据。
原始数据的一条包含:content、question、choices、label。(注意:除了choices有多个,其他都是一个)
因此,我们处理后的一条数据为:([(content,question,choice1),(content,question,choice2),...],label),整体分为两部分,data和label,其中data再进行细分。
2、DataLoder的collate_fn属性
这里的输入data跟Dataset的数据处理密切相关。
class MyDataset(Dataset):
def __init__(self, data, label):
self.data = data
self.label = label
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
## data [(content,question,choice),(),...]
content = []
question = []
choice = []
for item in self.data[idx]:
content.append(item[0])
question.append(item[1])
choice.append(item[2])
if len(choice) < 4:
for i in range(4 - len(choice)):
content.append(content[len(choice) - 1])
question.append(question[len(choice) - 1])
choice.append('不知道')
question_choice = [q + ' ' + c for q, c in zip(question, choice)]
return content, question_choice, self.label[idx]
这里返回值有三个,那么重新的collate_fn函数中的data的维度为(batch_size,3)。
可以看到:
- x[0]为list,内容为[content,content,content,content],即4个一样的文本内容;
- x[1]为list,内容为[question+choice1,questin+choice2,question+choice3,questin+choice4],即4个一样的question分别和4个不同choice拼接。
- x[2]为scalar,内容为这道题的答案,即label。
def collate_fn(data): # 将文章问题选项拼在一起后,得到分词后的数字id,输出的size是(batch, n_choices, max_len)
input_ids, attention_mask, token_type_ids = [], [], []
for x in data:
text = tokenizer(x[1], text_pair=x[0], padding='max_length', truncation=True,
max_length=conf['max_len'],
return_tensors='pt')
input_ids.append(text['input_ids'].tolist())
attention_mask.append(text['attention_mask'].tolist())
token_type_ids.append(text['token_type_ids'].tolist())
input_ids = torch.tensor(input_ids)
attention_mask = torch.tensor(attention_mask)
token_type_ids = torch.tensor(token_type_ids)
label = torch.tensor([x[2] for x in data])
return input_ids, attention_mask, token_type_ids, label
五、学习到的知识
1、交叉验证
k-fold交叉验证:首先将全部样本划分成k个大小相等的样本子集;依次遍历这k个子集,每次把当前子集作为验证集,其余所有子集作为训练集,进行模型的训练和评估;最后把k次评估指标的平均值作为最终的评估指标。在实际实验中,k经常取10。
原作者在这里用的是StratifiedKFold,StratifiedKFold与k-fold类似,但取得数据集有区别,其分层采样交叉切分,确保训练集,测试集中各类别样本的比例与原始数据集中相同。
from sklearn.model_selection import StratifiedKFold
folds = StratifiedKFold(n_splits=conf['fold_num'], shuffle=True, random_state=conf['seed']).split(np.arange(len(X)), y)
## fold代表第几个fold
for fold, (trn_idx, val_idx) in enumerate(folds):
train_x = np.array(X)[trn_idx]
train_y = np.array(y)[trn_idx]
val_x = np.array(X)[val_idx]
val_y = np.array(y)[val_idx]
2、半精度训练
https://blog.csdn.net/HUSTHY/article/details/109485088
3、梯度累加
https://www.zhihu.com/question/303070254/answer/573037166
之前,由于数据量太大,显存不足,我们把数据切成一个个batch,而每训练一个batch,再反向传播之前,我们都会将梯度清零。
而梯度累加则是有种变相扩大batch_size的感觉,训练了一个batch后,不将梯度清零,到了自己设置的阈值,再进行梯度清零。
六、github个人主页
之后,会将自己过滤一遍的baseline放在我的主页,欢迎来访。