为了小论文之跟着李沐学AI(二十一)

学习一下数据的预处理

d2l.DATA_HUB['fra-eng'] = (d2l.DATA_URL + 'fra-eng.zip',
                           '94646ad1522d915e7b0f9296181140edcf86a4f5')

#@save
def read_data_nmt():
    """载入“英语-法语”数据集。"""
    data_dir = d2l.download_extract('fra-eng')
    with open(os.path.join(data_dir, 'fra.txt'), 'r',
             encoding='utf-8') as f:
        return f.read()

raw_text = read_data_nmt()
print(raw_text[:75])

其实这个函数都多余,意思就是,下载一个数据集。
在这里插入图片描述
这是一个字符数据集,就是每一个元素都是一个字符

def preprocess_nmt(text):
    """预处理“英语-法语”数据集。"""
    def no_space(char, prev_char):
        return char in set(',.!?') and prev_char != ' '

    # 使用空格替换不间断空格
    # 使用小写字母替换大写字母
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    # 在单词和标点符号之间插入空格
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
           for i, char in enumerate(text)]
    return ''.join(out)

text = preprocess_nmt(raw_text)
print(text[:80])

这个函数更像是一个标准化的函数,吧这个输入格式化了每一个单词之后有一个标点符号,单词与单词之间有一个制表符

def tokenize_nmt(text, num_examples=None):
    """词元化“英语-法语”数据数据集。"""
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if num_examples and i > num_examples:
            break
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    return source, target

source, target = tokenize_nmt(text)
source[:6], target[:6]

词元化,这里它输出了两个此表。一个是source可以理解为英文词典,一个是target可以理解为法语词典。我们想我们的输入经过RNN之后,可以返回target出来
词典的每一个元素,都是两个长度,一个是表示单词,一个是标识字符的标点符号

src_vocab = d2l.Vocab(source, min_freq=2,
                      reserved_tokens=['<pad>', '<bos>', '<eos>'])
len(src_vocab)

在这里,我们创建字典的时候,还加入了三个特殊的符号,分别标识开始符号,bos,终止符号eos,间隔符号pad,并且我们去除了出现频率小于2的字符。

def truncate_pad(line, num_steps, padding_token):
    """截断或填充文本序列。"""
    if len(line) > num_steps:
        return line[:num_steps]  # 截断
    return line + [padding_token] * (num_steps - len(line))  # 填充

truncate_pad(src_vocab[source[0]], 10, src_vocab['<pad>'])

对于数据,我们希望它的一个batch中的语句长度都能蛮num_step的效果,所以如果长度不足,我们需要用给它补齐,如果长度过长需要截取。

def build_array_nmt(lines, vocab, num_steps):
    """将机器翻译的文本序列转换成小批量。"""
    lines = [vocab[l] for l in lines]
    lines = [l + [vocab['<eos>']] for l in lines]
    array = torch.tensor([truncate_pad(
        l, num_steps, vocab['<pad>']) for l in lines])
    valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)
    return array, valid_len

对于这个函数需要好好理解一下
首先,我们先将lines变化一下
本来Lines = [[go, .],[hello,?]…]这样的形式的
现在把里面的字符变成对应的index,(我们的vocab是按照字符来分类的)
第二步,在每一个单词的背后加上eos这个符号,
第三步,补齐每一个Minibatch,把batch中的每一个单词都补齐成num_stemps的长度,
第四步,计算每一句话的有效长度。

def load_data_nmt(batch_size, num_steps, num_examples=600):
    """返回翻译数据集的迭代器和词汇表。"""
    text = preprocess_nmt(read_data_nmt())
    source, target = tokenize_nmt(text, num_examples)
    src_vocab = d2l.Vocab(source, min_freq=2,
                          reserved_tokens=['<pad>', '<bos>', '<eos>'])
    tgt_vocab = d2l.Vocab(target, min_freq=2,
                          reserved_tokens=['<pad>', '<bos>', '<eos>'])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    data_iter = d2l.load_array(data_arrays, batch_size)
    return data_iter, src_vocab, tgt_vocab

这个函数就是实现一个封装.
num_steps表示你一个RNN循想处理几个时间步
batch_size:表示你一次性想读取几条数据出来
对于返回值 data_iter,其实是一个四元组,前面两个元素表示,batch个英文句子和对应的长度,后面两个元素表示batch个法文和对应的有效长度

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值