Preface
This is the third and final part of the source-code walkthrough of a PyTorch implementation of BERT (Bidirectional Encoder Representations from Transformers). This part covers wiki_dataset.py, which implements the data preprocessing for BERT pretraining. Reading this code will help you understand how the model's input data is constructed.
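Before diving into the source, here is a minimal sketch of the kind of input construction this file performs: mapping characters to indices with a word2idx dictionary, wrapping a sentence in [CLS]/[SEP], and padding to a fixed length. The vocabulary entries and the `encode` helper below are made up for illustration; only the special-token indices (pad=0, unk=1, cls=2, sep=3) mirror the ones defined in the class.

```python
# Hypothetical vocabulary; the special-token indices match BERTDataset.__init__,
# but the word entries are invented for this sketch.
word2idx = {
    "#PAD": 0, "#UNK": 1, "#CLS": 2, "#SEP": 3, "#MASK": 4, "#NUM": 5,
    "我": 6, "爱": 7, "你": 8,
}

def encode(sentence, word2idx, seq_len=8):
    """Map characters to indices, add [CLS]/[SEP], pad to seq_len."""
    ids = [word2idx.get(ch, 1) for ch in sentence]   # 1 = unk_index
    ids = [2] + ids + [3]                            # 2 = cls_index, 3 = sep_index
    ids = ids[:seq_len]
    ids += [0] * (seq_len - len(ids))                # 0 = pad_index
    return ids

print(encode("我爱你", word2idx))  # [2, 6, 7, 8, 3, 0, 0, 0]
```

Unknown characters fall back to the unk index, which is why the class reserves index 1 for them before loading the real dictionary from disk.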
BERT source-code walkthrough:
1. Model architecture: bert_model.py
2. Model pretraining: bert_training.py
3. Data preprocessing: wiki_dataset.py
Getting Started
1. Initialization
class BERTDataset(Dataset):
def __init__(self, corpus_path, word2idx_path, seq_len, hidden_dim=384, on_memory=True):
# hidden dimension for positional encoding
self.hidden_dim = hidden_dim
# path of the word2idx dictionary file
self.word2idx_path = word2idx_path
# maximum sequence length
self.seq_len = seq_len
# load whole corpus at once or not
self.on_memory = on_memory
# path of the corpus dataset
self.corpus_path = corpus_path
# define special symbols
self.pad_index = 0
self.unk_index = 1
self.cls_index = 2
self.sep_index = 3
self.mask_index = 4
self.num_index = 5
# load the word-to-index dictionary
with open(word2idx_path, "r", encoding="utf-8") as f:
self.word2idx = json.load(f)
# load the corpus
with open(corpus_path, "r", encoding="utf-8") as f:
if not on_memory: