def main():
dataset = CNNDataset()
output_dir = 'data/cnn'
os.makedirs(output_dir)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
split_size = {
'train': 10000,
'dev': 5000,
'test': 5000
}
train_texts = []
for i in trange(20000, desc='Getting Train Text'):
_, story_lines, _ = dataset[i]
text = '\n\n'.join(story_lines)
if len(tokenizer.tokenize(text)) > 1022:
continue
train_texts.append({'condition': '', 'text': text})
if len(train_texts) >= split_size['train']:
break
print('#train:', len(train_texts))
pickle.dump(train_texts, open(
os.path.join(output_dir, 'train.pickle'), 'wb'))
dev_texts = []
for i in trange(20000, 30000, desc='Getting Dev Text'):
_, story_lines, _ = dataset[i]
text = '\n\n'.join(story_lines)
if len(tokenizer.tokenize(text)) > 1022:
continue
dev_texts.append({'condition': '', 'text': text})
if len(dev_texts) >= split_size['dev']:
break
print('#dev:', len(dev_texts))
pickle.dump(dev_texts, open(
os.path.join(output_dir, 'dev.pickle'), 'wb'))
test_texts = []
for i in trange(30000, 40000, desc='Getting Test Text'):
_, story_lines, _ = dataset[i]
text = '\n\n'.join(story_lines)
if len(tokenizer.tokenize(text)) > 1022:
continue
test_texts.append({'condition': '', 'text': text})
if len(test_texts) >= split_size['test']:
break
print('#test:', len(test_texts))
pickle.dump(test_texts, open(
os.path.join(output_dir, 'test.pickle'), 'wb'))
os.makedirs,创建一个cnn存放目录
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
选用分词器,这里使用gpt2的分词器(属于PretrainTokenizer)
split_size = {
'train': 10000,
'dev': 5000,
'test': 5000
}
这里训练集,验证集,测试集的比例是2:1:1,对于超大型数据集可以做出更多种的分类变化。
_, story_lines, _ = dataset[i]
这里前后都有两个无意义变量,是因为在之前的CNNData中的返回结果
return document_name, story_lines, summary_lines
所以使用两个无意义变量来取到故事行
text = '\n\n'.join(story_lines)
将故事行中的变量以两个分行符分割赋给text,下为分割样式,可以看到其实是将故事中的每句话分割开来,方便后续处理
使用 if len(tokenizer.tokenize(text)) > 1022:
continue
如果text内的句子数量大于1022,则跳过本次循环,用以限制句子数量在1022以下的数据
train_texts.append({'condition': '', 'text': text})
将字典{'condition': '', 'text': text}添加到train_texts
'condition': ''预设为空
'text': 填充为上面的text,即为句子行
if len(train_texts) >= split_size['train']:
break
设置结束条件
print('#train:', len(train_texts))
pickle.dump(train_texts, open(
os.path.join(output_dir, 'train.pickle'), 'wb'))
输出train_texts 的数量,将train_texts以二进制编码存放入,output_dir路径,命名为train.pickle,'wb'表示以二进制写打开,如果文件不存在则创建存在则覆盖。
剩下的验证集与测试集为相同的原理,只有开头的区间不同,分别为0-20000,20000-30000,30000-40000.
——————————————————————————————————————————
分词器:bert第三篇:tokenizer_iterate7的博客-CSDN博客_bert tokenizer
join():将序列(也就是字符串、元组、列表、字典)中的元素以指定的字符连接生成一个新的字符串
join()函数
语法: 'sep'.join(seq)
参数说明
sep:分隔符。可以为空
seq:要连接的元素序列、字符串、元组、字典
上面的语法即:以sep作为分隔符,将seq所有的元素合并成一个新的字符串
返回值:返回一个以分隔符sep连接各个元素后生成的字符串
———————————————————————————————————————————