10分钟熟练掌握torchtext库

 
 我们知道做机器学习,在现有很多问题的解决方案和模型算法成熟的条件下,很大部分时间都花在数据清洗和预处理上面。之前做NLP问题的时候,数据预处理的整个工作都是自己手写的,代码保存着,等到下一次做其他NLP的任务时候把代码拿来改改。
      其实NLP的整个从   文本->Tenso r     的工作流程基本是固定,无非就是 preprocess->tokenizer->建立vocab->id化  差异只是在一些细节上的差异。最近用PyTroch做机器翻译,用到了torchtext 库,研究了一下发现这个库简洁好用而且很灵活。把NLP任务中的文本预处理流程都抽象出来了,然后标准化。所以我们用它处理文本的时候只要按流程来就行了。而掌握它只要10分钟左右,掌握它的设计理念,剩下就是用它了。
     我们知道任何一个数据集 由一条条 记录  构成的。每条记录由很多 字段(属性)构成。比如一个记录学生成绩的信息的CSV文件,name, age, sex, job
我们加载数据后 会对不同的字段做不同的数据(用过pandas的用户经常会用的一个操作就是 df[’col_name’].apply(func) )
然后把 处理后的数据 输入 Model 里面。
    torchtext就是这样的直接了当的思想处理文本数据的。下面我们就介绍tortext的几个 定义(类名词)
 
Example 中文名 样例,一个Example代表一条记录。
Field       字段(类型),一个Field定义了该字段的可能的处理操作。(主要是文本处理的一些操作)
我们看看Field的构造字段

都是文本处理中的一些操作。

 

如上 一条记录包含很多字段,一个Example包含很多Field,每个Field都一个名字。

我们看Example的 fromlist方法

data是一个列表, 存储了一条记录的 所有字段值。fields 则表示fields列表, fields中的field根data中的值是一一对应的。我们看到 每条记录在构建Example的时候 都会调用对应的field。需要说明的 是 fields 是一个 二元组列表。

每个二元组 第一个 表示的是 字段名,第二个就是对应的Field对象。看代码可以知道,Example的属性就是哪些field 名。

 

Dataset 则是很多一堆Example,我们看Dataset的构造函数,输入是examples和fields(跟Example的输入相同)。

 

 
Iterator   表示完 数据集了,下面就是遍历数据集了。神经网络的Model输入都是 mini batch的。所以就有了 Iterator工具。
 
Iterator定义了如何遍历数据集,他有一个split方法用来生成数据迭代器的。

 

我们用 batch.{field_name} 获取每个字段的批量值。

 

torchtext基本用法就这些,掌握这些概念就能熟练运用torchtext了,torchtext的代码也很精简,不懂的可以直接看源码。

 

下面是是一个torchtext实践,中英文于翻译语料 可以直接用于机器翻译模型训练的。需要安装spacy库

#!/usr/bin/env python
# encoding: utf-8

"""
@site: 
@software: PyCharm
@file: translation2019_dataset.py
"""
import os
import io
import json
import spacy
from spacy.lang.zh import Chinese
from typing import Tuple, List
from torchtext.data import Example, Field, Dataset, BucketIterator

ZH_CFG = {"pkuseg_model": "default", "require_pkuseg": True}
spacy_zh = Chinese(meta={"tokenizer": {"config": ZH_CFG}})
spacy_en = spacy.load('en_core_web_sm')


def get_translation2019_examples(data_dir, fields) -> Tuple[List[Example], List[Example], List[Example]]:
    train_data_file = os.path.join(data_dir, 'train.json')
    valid_data_file = os.path.join(data_dir, 'valid.json')
    test_data_file = os.path.join(data_dir, 'test.json')
    train_examples = []
    valid_examples = []
    test_examples = []
    with io.open(train_data_file, encoding='utf-8') as fd1, io.open(valid_data_file, encoding='utf-8') as fd2, io.open(
            test_data_file, encoding='utf-8') as fd3:
        for line in fd1:
            record = json.loads(line)
            en = record['english']
            zh = record['chinese']
            example = Example.fromlist([en, zh], fields)
            train_examples.append(example)
        for line in fd2:
            record = json.loads(line)
            en = record['english']
            zh = record['chinese']
            example = Example.fromlist([en, zh], fields)
            valid_examples.append(example)
        for line in fd3:
            record = json.loads(line)
            en = record['english']
            zh = record['chinese']
            example = Example.fromlist([en, zh], fields)
            test_examples.append(example)
    return train_examples, valid_examples, test_examples


def en_tokenize(text):
    return [str(token) for token in spacy_en.tokenizer(text)]


def zh_tokenize(text):
    return [str(token) for token in spacy_zh.tokenizer(text)]


def get_translation2019_dataset(data_dir):
    en_field = Field(init_token='<sos>', eos_token='<eos>', tokenize=en_tokenize)
    zh_field = Field(init_token='<sos>', eos_token='<eos>', tokenize=zh_tokenize)
    fields = [('en', en_field), ('zh', zh_field)]
    train_exampls, valid_examples, test_examples = get_translation2019_examples(data_dir, fields)
    train_dataset = Dataset(train_exampls, fields=fields)
    valid_dataset = Dataset(valid_examples, fields=fields)
    test_dataset = Dataset(test_examples, fields=fields)
    en_field.build_vocab(train_dataset, min_freq=2)
    zh_field.build_vocab(train_dataset, min_freq=2)
    train_iter, valid_iter, test_iter = BucketIterator.splits((train_dataset, valid_dataset, test_dataset),
                                                              batch_size=16)
    return train_iter, valid_iter, test_iter


def main():
    train_iter, valid_iter, test_iter = get_translation2019_dataset('translation2019zh')
    for i, bath in enumerate(train_iter):
        print(bath.en)
        print(bath.zh)


if __name__ == '__main__':
    main()

语料链接

 

链接: https://pan.baidu.com/s/1nqg8hjTWe4S4caelEHCFIg 提取码: btzv 

 

 

 

 

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值