data_load.py
The _load_vocab function (builds the token-to-index maps)
This function loads the vocabulary: it defines `vocab` as a list, opens the vocabulary file, and appends each token to the list.
The enumerate() function combines an iterable (such as a list, tuple, or string) into an indexed sequence, yielding both the index and the value; it is typically used in a for loop.
>>> seq = ['one', 'two', 'three']
>>> for i, element in enumerate(seq):
...     print(i, element)
...
0 one
1 two
2 three
A for…in statement iterates over the elements of a sequence or the properties of an object: for x in list.
The function returns two dictionaries, token2idx and idx2token.
def _load_vocab(vocab_fpath):
    '''Loads vocabulary file and returns idx<->token maps
    vocab_fpath: string. vocabulary file path.
    Note that these are reserved
    0: <pad>, 1: <unk>, 2: <s>, 3: </s>
    Returns
    two dictionaries.
    '''
    vocab = []
    with open(vocab_fpath, 'r', encoding='utf-8') as f:
        for line in f:
            vocab.append(line.replace('\n', ''))
    token2idx = {token: idx for idx, token in enumerate(vocab)}  # enumerate yields (index, token) pairs
    idx2token = {idx: token for idx, token in enumerate(vocab)}
    return token2idx, idx2token
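As a quick sanity check, the mapping logic above can be exercised with an in-memory vocabulary (the token list here is made up for illustration; the real function reads it from a file):

```python
# Minimal sketch of the token<->index maps built by _load_vocab,
# using an in-memory list instead of a vocabulary file.
vocab = ['<pad>', '<unk>', '<s>', '</s>', '我', '你']

token2idx = {token: idx for idx, token in enumerate(vocab)}
idx2token = {idx: token for idx, token in enumerate(vocab)}

print(token2idx['<pad>'])  # 0 -- the reserved padding id
print(token2idx['我'])     # 4
print(idx2token[3])        # </s>
```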
The load_stop function (loads stop words)
def load_stop(vocab_path):
    """
    load stop word
    :param vocab_path: stop word path
    :return: stop word list
    """
    stop_words = []
    with open(vocab_path, 'r', encoding='utf-8') as f:
        for line in f:
            stop_words.append(line.replace('\n', ''))
    return sorted(stop_words, key=lambda i: len(i), reverse=True)
The function defines a list and appends each line of the file with the trailing newline stripped.
Syntax: sorted(iterable, key=function (the sort rule), reverse=(True for descending order))
sorted() creates a new list.
Syntax: function_name = lambda arguments : return_value
Here the iterable is the stop-word list, the sort key is the length of each stop word, and the order is descending.
Each element of the iterable is passed to the lambda, which returns a number used as the sort key.
So the function returns a list of stop words sorted from longest to shortest.
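The descending length sort can be sketched on a toy stop-word list (the words here are arbitrary examples, not the project's real stop-word file):

```python
# Sketch of the sort used by load_stop: sort stop words by length,
# longest first, so that longer phrases are matched before shorter ones.
stop_words = ['的', '因为', '虽然但是']

ordered = sorted(stop_words, key=lambda w: len(w), reverse=True)
print(ordered)  # ['虽然但是', '因为', '的']
```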
The _load_data function (loads and filters the raw corpus)
def _load_data(fpaths, maxlen1, maxlen2):
    '''Loads source and target data and filters out too lengthy samples.
    fpaths: '|'-separated file paths. string. (each file holds target,source pairs)
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    Returns
    sents1: list of source sents
    sents2: list of target sents
    '''
    sents1, sents2 = [], []
    for fpath in fpaths.split('|'):
        with open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                splits = line.split(',')
                if len(splits) != 2: continue
                sen1 = splits[1].replace('\n', '').strip()  # source sentence
                sen2 = splits[0].strip()                    # target sentence
                if len(list(sen1)) + 1 > maxlen1 - 2: continue
                if len(list(sen2)) + 1 > maxlen2 - 1: continue
                sents1.append(sen1.encode('utf-8'))
                sents2.append(sen2.encode('utf-8'))
    return sents1[:400000], sents2[:400000]
Syntax: split() slices a string on a given separator: str.split(sep="", num=string.count(str))
sep is the separator; by default any whitespace, including spaces, newlines (\n), and tabs (\t). num is the maximum number of splits; the default -1 means split everywhere. The method returns a list of the resulting substrings.
Each line is split on the comma. If the result is not exactly two parts, the line is skipped; otherwise both parts have whitespace and newlines stripped. If the source or target sentence exceeds its configured maximum length, the sample is filtered out and not added to the lists; otherwise it is appended to sents1/sents2. The function returns the two lists of source and target sentences.
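The per-line filtering can be sketched in isolation, using in-memory lines instead of files (the sample lines and length limits here are invented for illustration):

```python
# Sketch of _load_data's per-line filtering: keep only well-formed
# "target,source" lines whose sentences fit the length limits.
lines = ['target one,source one\n', 'malformed line\n', 'tgt,src\n']
maxlen1, maxlen2 = 20, 20  # arbitrary limits for the sketch

sents1, sents2 = [], []
for line in lines:
    splits = line.split(',')
    if len(splits) != 2:
        continue  # skip lines that are not exactly "target,source"
    sen1 = splits[1].replace('\n', '').strip()  # source sentence
    sen2 = splits[0].strip()                    # target sentence
    if len(list(sen1)) + 1 > maxlen1 - 2:
        continue  # source too long
    if len(list(sen2)) + 1 > maxlen2 - 1:
        continue  # target too long
    sents1.append(sen1)
    sents2.append(sen2)

print(sents1)  # ['source one', 'src']
print(sents2)  # ['target one', 'tgt']
```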
The _encode function (string-to-id encoding)
def _encode(inp, token2idx, maxlen, type):
    '''Converts string to number. Used for `generator_fn`.
    inp: 1d byte array.
    type: "x" (source side) or "y" (target side)
    dict: token2idx dictionary
    Returns
    list of numbers
    '''
    inp = inp.decode('utf-8')
    if type == 'x':
        tokens = ['<s>'] + list(inp) + ['</s>']
        while len(tokens) < maxlen:
            tokens.append('<pad>')
        return [token2idx.get(token, token2idx['<unk>']) for token in tokens]
    else:
        inputs = ['<s>'] + list(inp)
        target = list(inp) + ['</s>']
        while len(target) < maxlen:
            inputs.append('<pad>')
            target.append('<pad>')
        return [token2idx.get(token, token2idx['<unk>']) for token in inputs], \
               [token2idx.get(token, token2idx['<unk>']) for token in target]
inp is a byte string; Python's decode() method decodes it with the specified encoding (the string's encoding by default, UTF-8 here) and returns the decoded string. For the source side, a start token and an end token are added around the sentence.
While the token list is shorter than maxlen, padding tokens are appended.
The dictionary method dict.get() returns the value for a given key:
dict.get(key[, value])
key – the key to look up in the dictionary.
value – optional; the default returned when the key is absent.
The function returns the sentence as a list of token ids (in other words, the numeric encoding of the sentence).
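A minimal sketch of the source-side ("x") branch, using a tiny made-up vocabulary (the tokens and maxlen are invented for illustration):

```python
# Sketch of _encode's source-side branch: wrap the sentence in
# <s>/</s>, pad to maxlen, then map tokens to ids, falling back
# to <unk> for out-of-vocabulary characters.
token2idx = {'<pad>': 0, '<unk>': 1, '<s>': 2, '</s>': 3, '我': 4}
maxlen = 6

inp = '我X'  # 'X' is not in the vocabulary
tokens = ['<s>'] + list(inp) + ['</s>']
while len(tokens) < maxlen:
    tokens.append('<pad>')
ids = [token2idx.get(t, token2idx['<unk>']) for t in tokens]
print(ids)  # [2, 4, 1, 3, 0, 0]
```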
The _generator_fn function (creates a generator)
def _generator_fn(sents1, sents2, vocab_fpath, maxlen1, maxlen2):
    '''Generates training / evaluation data
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.
    yields
    xs: tuple of
        x: list of source token ids in a sent
        sent1: str. raw source (=input) sentence
    labels: tuple of
        inputs: list of encoded decoder inputs
        targets: list of target token ids in a sent
        sent2: str. target sentence
    '''
    token2idx, _ = _load_vocab(vocab_fpath)
    for sent1, sent2 in zip(sents1, sents2):
        x = _encode(sent1, token2idx, maxlen1, "x")
        inputs, targets = _encode(sent2, token2idx, maxlen2, "y")
        yield (x, sent1.decode('utf-8')), (inputs, targets, sent2.decode('utf-8'))
zip() takes iterables as arguments, packs their corresponding elements into tuples, and returns an object made up of those tuples.
Each paired sentence is then passed to _encode as shown above.
A function containing yield is a generator. Unlike an ordinary function, calling it looks like a function call but executes no body code; execution begins only when next() is called (a for loop calls next() automatically). Execution follows the normal flow of the function, but every yield statement suspends it and returns the current value; the next call resumes from the statement after the yield. It is as if the function is interrupted at each yield, handing back the current iteration value every time.
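The suspend/resume behaviour described above can be seen with a toy generator that mirrors _generator_fn's zip-and-yield structure (the data is made up):

```python
# Toy generator mirroring _generator_fn's structure: pair two lists
# with zip() and yield one pair at a time.
def pair_gen(sents1, sents2):
    for s1, s2 in zip(sents1, sents2):
        yield (s1, s2)  # execution suspends here until the next next() call

g = pair_gen(['a', 'b'], ['x', 'y'])
print(next(g))  # ('a', 'x')
print(next(g))  # ('b', 'y')
```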
The _input_fn function (converts the data into tensors)
def _input_fn(sents1, sents2, vocab_fpath, batch_size, gpu_nums, maxlen1, maxlen2, shuffle=False):
    '''Batchify data
    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean
    Returns
    xs: tuple of
        x: int32 tensor. (N, T1)
        sents1: str tensor. (N,)
    ys: tuple of
        decoder_input: int32 tensor. (N, T2)
        y: int32 tensor. (N, T2)
        sents2: str tensor. (N,)
    '''
    shapes = (([maxlen1], ()),
              ([maxlen2], [maxlen2], ()))
    types = ((tf.int32, tf.string),
             (tf.int32, tf.int32, tf.string))
    dataset = tf.data.Dataset.from_generator(
        _generator_fn,
        output_shapes=shapes,
        output_types=types,
        args=(sents1, sents2, vocab_fpath, maxlen1, maxlen2))  # arguments for _generator_fn, converted to np string arrays
    if shuffle:  # for training
        dataset = dataset.shuffle(128 * batch_size * gpu_nums)
    dataset = dataset.repeat()  # iterate forever
    dataset = dataset.batch(batch_size * gpu_nums)
    return dataset
The dataset is built with TensorFlow's from_generator, taking shapes and types as parameters. If shuffle is enabled, the data is shuffled: 128*batch_size*gpu_nums is the size of the shuffle buffer, and batch_size*gpu_nums examples are then drawn from the buffer as each batch. repeat(x) repeats the dataset x times;
if x is omitted, the dataset repeats indefinitely.
The function returns a Dataset.
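The shuffle-buffer and batching semantics can be illustrated with a plain-Python sketch; this is not how tf.data is implemented internally, only a model of what shuffle(buffer_size) followed by batch(n) produces:

```python
import random

# Plain-Python model of Dataset.shuffle(buffer_size).batch(batch_size):
# keep a bounded buffer, emit random picks from it, then group the
# shuffled stream into fixed-size batches.
def shuffle_stream(items, buffer_size, seed=0):
    rng = random.Random(seed)
    buf = []
    for item in items:
        buf.append(item)
        if len(buf) >= buffer_size:
            yield buf.pop(rng.randrange(len(buf)))
    while buf:  # drain the buffer at end of stream
        yield buf.pop(rng.randrange(len(buf)))

def batch_stream(items, batch_size):
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # last, possibly smaller, batch

data = list(range(10))
batches = list(batch_stream(shuffle_stream(data, buffer_size=4), batch_size=3))
print(batches)  # four batches of sizes 3, 3, 3, 1; order depends on the seed
```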
The get_batch function (gets training and evaluation mini-batches)
def get_batch(fpath, maxlen1, maxlen2, vocab_fpath, batch_size, gpu_nums, shuffle=False):
    '''Gets training / evaluation mini-batches
    fpath: source file path. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    vocab_fpath: string. vocabulary file path.
    batch_size: scalar
    shuffle: boolean
    Returns
    batches
    num_batches: number of mini-batches
    num_samples
    '''
    sents1, sents2 = _load_data(fpath, maxlen1, maxlen2)
    batches = _input_fn(sents1, sents2, vocab_fpath, batch_size, gpu_nums, maxlen1, maxlen2, shuffle=shuffle)
    num_batches = calc_num_batches(len(sents1), batch_size * gpu_nums)
    return batches, num_batches, len(sents1)
This function builds the batched dataset; it returns the batches, the number of mini-batches, and the number of samples in sents1.
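calc_num_batches is defined elsewhere in the project (its utils module); it is presumably the usual ceiling division of the sample count by the effective batch size, which can be sketched as:

```python
# Assumed implementation of calc_num_batches (not shown in this file):
# ceiling division, adding one extra batch whenever total_num is not
# an exact multiple of batch_size.
def calc_num_batches(total_num, batch_size):
    return total_num // batch_size + int(total_num % batch_size != 0)

print(calc_num_batches(400000, 128))  # 3125
print(calc_num_batches(10, 3))        # 4
```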