progressive-generation-master代码记录【下载处理数据】(主函数,划分训练集,验证集,测试集)

def main():
    dataset = CNNDataset()

    output_dir = 'data/cnn'
    os.makedirs(output_dir)

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')

    split_size = {
        'train': 10000,
        'dev': 5000,
        'test': 5000
    }

    train_texts = []
    for i in trange(20000, desc='Getting Train Text'):
        _, story_lines, _ = dataset[i]
        text = '\n\n'.join(story_lines)
        if len(tokenizer.tokenize(text)) > 1022:
            continue
        train_texts.append({'condition': '', 'text': text})

        if len(train_texts) >= split_size['train']:
            break

    print('#train:', len(train_texts))
    pickle.dump(train_texts, open(
        os.path.join(output_dir, 'train.pickle'), 'wb'))

    dev_texts = []
    for i in trange(20000, 30000, desc='Getting Dev Text'):
        _, story_lines, _ = dataset[i]
        text = '\n\n'.join(story_lines)
        if len(tokenizer.tokenize(text)) > 1022:
            continue
        dev_texts.append({'condition': '', 'text': text})

        if len(dev_texts) >= split_size['dev']:
            break

    print('#dev:', len(dev_texts))
    pickle.dump(dev_texts, open(
        os.path.join(output_dir, 'dev.pickle'), 'wb'))

    test_texts = []
    for i in trange(30000, 40000, desc='Getting Test Text'):
        _, story_lines, _ = dataset[i]
        text = '\n\n'.join(story_lines)
        if len(tokenizer.tokenize(text)) > 1022:
            continue
        test_texts.append({'condition': '', 'text': text})

        if len(test_texts) >= split_size['test']:
            break

    print('#test:', len(test_texts))
    pickle.dump(test_texts, open(
        os.path.join(output_dir, 'test.pickle'), 'wb'))

os.makedirs,创建一个cnn存放目录

tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')

选用分词器,这里使用gpt2的分词器(属于PretrainTokenizer)

   split_size = {
        'train': 10000,
        'dev': 5000,
        'test': 5000
    }

这里训练集,验证集,测试集的比例是2:1:1,对于超大型数据集可以做出更多种的分类变化。

 _, story_lines, _ = dataset[i]

这里前后都有两个无意义变量,是因为在之前的CNNData中的返回结果

return document_name, story_lines, summary_lines

 所以使用两个无意义变量来取到故事行

text = '\n\n'.join(story_lines)

将故事行中的变量以两个分行符分割赋给text,下为分割样式,可以看到其实是将故事中的每句话分割开来,方便后续处理

 使用       if len(tokenizer.tokenize(text)) > 1022:
            continue

如果text内的句子数量大于1022,则跳过本次循环,用以限制句子数量在1022以下的数据

train_texts.append({'condition': '', 'text': text})

将字典{'condition': '', 'text': text}添加到train_texts

'condition': ''预设为空

'text': 填充为上面的text,即为句子行

if len(train_texts) >= split_size['train']:
            break

设置结束条件 

print('#train:', len(train_texts))
    pickle.dump(train_texts, open(
        os.path.join(output_dir, 'train.pickle'), 'wb'))

输出train_texts 的数量,将train_texts以二进制编码存放入,output_dir路径,命名为train.pickle,'wb'表示以二进制写打开,如果文件不存在则创建存在则覆盖。

剩下的验证集与测试集为相同的原理,只有开头的区间不同,分别为0-20000,20000-30000,30000-40000.

——————————————————————————————————————————

分词器:bert第三篇:tokenizer_iterate7的博客-CSDN博客_bert tokenizer

join():将序列(也就是字符串、元组、列表、字典)中的元素以指定的字符连接生成一个新的字符串 

join()函数

语法:  'sep'.join(seq)

参数说明

sep:分隔符。可以为空

seq:要连接的元素序列、字符串、元组、字典

上面的语法即:以sep作为分隔符,将seq所有的元素合并成一个新的字符串

返回值:返回一个以分隔符sep连接各个元素后生成的字符串
———————————————————————————————————————————

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值