Transformer-pg-generator-master源码解析(1)

本文详细介绍了数据加载、词汇表构建、停用词处理、原始语料下载筛选、编码转换、生成迭代器、数据批量化等关键步骤,涉及one-hot编码、词汇表构建、停用词列表生成、数据截断等技术,为自然语言处理任务提供数据预处理的完整流程。
摘要由CSDN通过智能技术生成

data_load.py

_load_vocab函数(构建one-hot词典)

词典下载函数,定义vocab为list,打开vocabulary词典,加入每一个词.
enumerate() 函数:用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中.

seq = ['one', 'two', 'three']
>>> for i, element in enumerate(seq):
...     print i, element
...
0 one
1 two
2 three

for…in 语句用于遍历数组或者对象的属性for x in list
该函数返回两个字典,分别是token2idx和idx2token

def _load_vocab(vocab_fpath):
    '''Loads vocabulary file and returns idx<->token maps
    vocab_fpath: string. vocabulary file path.
    Note that these are reserved
    0: <pad>, 1: <unk>, 2: <s>, 3: </s>
    Returns
    two dictionaries.
    '''
    vocab = []
    with open(vocab_fpath, 'r', encoding='utf-8') as f:
        for line in f:
            vocab.append(line.replace('\n', ''))
    token2idx = {token: idx for idx, token in enumerate(vocab)} # 利用迭代器,返回索引与对应数据
    idx2token = {idx: token for idx, token in enumerate(vocab)}

    return token2idx, idx2token

load_stop函数(停用词下载)

def load_stop(vocab_path):
    """
    load stop word
    :param vocab_path: stop word path
    :return: stop word list
    """
    stop_words = []
    with open(vocab_path, 'r', encoding='utf-8') as f:
        for line in f:
            stop_words.append(line.replace('\n', ''))

    return sorted(stop_words, key=lambda i: len(i), reverse=True)

定义一个list,根据空格和下一行切分放入list.
语法: sorted(可迭代对象,key=函数(排序规则),reverse=(是否倒序,等于True就是倒序))
sorted新建了一个新的list
语法 :函数名 = lambda 参数 : 返回值
可迭代对象为停用词list,排序规则是停用词的长度,倒序排序。
把迭代器中的每一个值传给匿名函数,返回一个数字作为key
所以该函数返回了一个从大到小排序的停用词列表。

_load_data函数(下载筛选原始语料)

def _load_data(fpaths, maxlen1, maxlen2):
    '''Loads source and target data and filters out too lengthy samples.
    fpath1: source file path. string.源语句
    fpath2: target file path. string.目标语句
    maxlen1: source sent maximum length. scalar.标量
    maxlen2: target sent maximum length. scalar.

    Returns
    sents1: list of source sents
    sents2: list of target sents
    '''
    sents1, sents2 = [], []
    for fpath in fpaths.split('|'):
        with open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                splits = line.split(',')
                if len(splits) != 2: continue
                sen1 = splits[1].replace('\n', '').strip()
                sen2 = splits[0].strip()
                if len(list(sen1)) + 1 > maxlen1-2: continue
                if len(list(sen2)) + 1 > maxlen2-1: continue
                
                sents1.append(sen1.encode('utf-8'))
                sents2.append(sen2.encode('utf-8'))

    return sents1[:400000], sents2[:400000]

语法: split() 通过指定分隔符对字符串进行切片str.split(str="", num=string.count(str))
str – 分隔符,默认为所有的空字符,包括空格、换行(\n)、制表符(\t)等num – 分割次数。默认为 -1, 即分隔所有。该方法返回分割后的字符串列表。
将每一行按照句号进行分割如果不是两部分直接跳出 如果是则进行下面,两部份分别去空格,去换行符。如果源语句/目标语句大于设定的长度,直接过滤掉不加入到列表中。否则加入列表sent1、sent2。返回两个list源语句/目标语句.

_encode函数(编码)

def _encode(inp, token2idx, maxlen, type):
    '''Converts string to number. Used for `generator_fn`.
    inp: 1d byte array.
    type: "x" (source side) or "y" (target side)
    dict: token2idx dictionary
    Returns
    list of numbers
    '''
    inp = inp.decode('utf-8')
    if type == 'x':
        tokens = ['<s>'] + list(inp) + ['</s>']
        while len(tokens) < maxlen:
            tokens.append('<pad>')
        return [token2idx.get(token, token2idx['<unk>']) for token in tokens]

    else:
        inputs = ['<s>'] + list(inp)
        target = list(inp) + ['</s>']
        while len(target) < maxlen:
            inputs.append('<pad>')
            target.append('<pad>')
        return [token2idx.get(token, token2idx['<unk>']) for token in inputs], [token2idx.get(token, token2idx['<unk>']) for token in target]

inp是一个字符串,调用python中的decoder方法,decode() 方法以 encoding 指定的编码格式解码字符串。默认编码为字符串编码。如果是源语句,前后加开始符结束符。该方法返回解码后的字符串。
如果小于最大数值加
Python 字典(Dictionary) get() 函数返回指定键的值。
dict.get(key[, value])
key – 字典中要查找的键。
value – 可选,如果指定键的值不存在时,返回该默认值
该函数返回值为句子数字列表。(换句话说也就是句子的词向量化)

_generator_fn函数(创建一个迭代器)

def _generator_fn(sents1, sents2, vocab_fpath, maxlen1, maxlen2):
    '''Generates training / evaluation data
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.

    yields
    xs: tuple of
        x: list of source token ids in a sent
        x_seqlen: int. sequence length of x
        sent1: str. raw source (=input) sentence
    labels: tuple of
        decoder_input: decoder_input: list of encoded decoder inputs
        y: list of target token ids in a sent
        y_seqlen: int. sequence length of y
        sent2: str. target sentence
    '''
    token2idx, _ = _load_vocab(vocab_fpath)
    for sent1, sent2 in zip(sents1, sents2):
        x = _encode(sent1, token2idx, maxlen1, "x")

        inputs, targets = _encode(sent2, token2idx, maxlen2, "y")

        yield (x, sent1.decode('utf-8')), (inputs, targets, sent2.decode('utf-8'))

zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的对象。
将遍历后的句子分别调用
yield语法一个带有 yield 的函数就是一个 generator,它和普通函数不同,生成一个 generator 看起来像函数调用,但不会执行任何函数代码,直到对其调用 next()(在 for 循环中会自动调用 next())才开始执行。虽然执行流程仍按函数的流程执行,但每执行到一个 yield 语句就会中断,并返回一个迭代值,下次执行时从 yield 的下一个语句继续执行。看起来就好像一个函数在正常执行的过程中被 yield 中断了数次,每次中断都会通过 yield 返回当前的迭代值。

_input_fn函数(将数据转化为tensor格式)

def _input_fn(sents1, sents2, vocab_fpath, batch_size, gpu_nums, maxlen1, maxlen2, shuffle=False):
    '''Batchify data批数据
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean

    Returns
    xs: tuple of
        x: int32 tensor. (N, T1)
        x_seqlens: int32 tensor. (N,)
        sents1: str tensor. (N,)
    ys: tuple of
        decoder_input: int32 tensor. (N, T2)
        y: int32 tensor. (N, T2)
        y_seqlen: int32 tensor. (N, )
        sents2: str tensor. (N,)
    '''
    shapes = (([maxlen1], ()),
              ([maxlen2], [maxlen2], ()))
    types = ((tf.int32, tf.string),
             (tf.int32, tf.int32, tf.string))

    dataset = tf.data.Dataset.from_generator(
        _generator_fn,
        output_shapes=shapes,
        output_types=types,
        args=(sents1, sents2, vocab_fpath, maxlen1, maxlen2))  # <- arguments for generator_fn. converted to np string arrays

    if shuffle: # for training
        dataset = dataset.shuffle(128*batch_size*gpu_nums)

    dataset = dataset.repeat()  # iterate forever
    dataset = dataset.batch(batch_size*gpu_nums)

    return dataset

利用Tensorflow方法生成dataset.其中shapes,types为参数。生成后进行洗牌打乱。128batch_sizagpu_nums表示每次shffle缓存区的数量,然后从缓存区取出batch_size*gpu_nums数量作为每一个batch的数据个数。repeat(x)方法是重复构建次数。
x为参数,不设置则表示无限次重复。
该函数返回了一个Dateset

get_batch函数(获得训练和评估的mini-batches)

def get_batch(fpath, maxlen1, maxlen2, vocab_fpath, batch_size, gpu_nums, shuffle=False):
    '''Gets training / evaluation mini-batches
    fpath: source file path. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean

    Returns
    batches
    num_batches: number of mini-batches
    num_samples
    '''
    sents1, sents2 = _load_data(fpath, maxlen1, maxlen2)
    batches = _input_fn(sents1, sents2, vocab_fpath, batch_size, gpu_nums, maxlen1, maxlen2, shuffle=shuffle)
    num_batches = calc_num_batches(len(sents1), batch_size*gpu_nums)
    return batches, num_batches, len(sents1)

该函数用于获取批量数据集,返回值为btaches,批数据的个数,已经sent1的个数

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值