# --- 1. Add the following code to data_utils.py ---
def batch_flow_bucket(data, ws, batch_size, raw=False, add_end=True,
                      n_bucket=5, bucket_ind=1, debug=False):
    """Infinite batch generator that groups samples into length buckets.

    Samples are bucketed by the length of their ``bucket_ind``-th field so
    that every batch is drawn from sequences of similar length (minimizing
    padding waste).

    Args:
        data: list/tuple of parallel data streams, e.g. [x_data, y_data].
        ws: a word-sequence transformer, a list/tuple of them (one per
            stream), or None to yield the raw sequences untransformed.
        batch_size: number of samples per yielded batch. NOTE(review):
            ``random.sample`` raises ValueError if the chosen bucket holds
            fewer than batch_size samples — assumed not to happen here.
        raw: if True, additionally yield the untransformed line per stream.
        add_end: bool, or per-stream list/tuple of bools — whether to
            append WordSequence.END_TAG to each sequence.
        n_bucket: number of length buckets to split the data into.
        bucket_ind: index of the stream whose length drives bucketing.
        debug: print bucket boundaries and sampling choices.

    Yields:
        list of np.ndarray: [x, x_len, y, y_len, ...] per stream (plus the
        raw line per stream when ``raw`` is True).
    """
    all_data = list(zip(*data))
    # Distinct lengths of the bucketing stream, ascending.
    lengths = sorted(list(set([len(x[bucket_ind]) for x in all_data])))
    if n_bucket > len(lengths):
        n_bucket = len(lengths)
    # Bucket lower bounds: evenly spaced quantiles over the distinct lengths.
    # BUG FIX: the original hard-coded 5 buckets here (np.linspace(0,1,5,...)),
    # silently ignoring the n_bucket argument.
    splits = np.array(lengths)[
        (np.linspace(0, 1, n_bucket, endpoint=False) * len(lengths)).astype(int)
    ].tolist()
    splits += [np.inf]  # upper bound of the last bucket
    if debug:
        print(splits)
    # Assign each sample to its bucket.  BUG FIX: the original condition
    # (l >= s and l <= splits[ind+1]) made adjacent ranges overlap at the
    # boundary, so a bucket could mix lengths up to the NEXT bucket's lower
    # bound; half-open ranges keep every bucket's length range disjoint.
    ind_data = {}
    for x in all_data:
        l = len(x[bucket_ind])
        for ind, s in enumerate(splits[:-1]):
            if s <= l < splits[ind + 1]:
                if ind not in ind_data:
                    ind_data[ind] = []
                ind_data[ind].append(x)
                break
    inds = sorted(list(ind_data.keys()))
    # Sampling probability of each bucket, proportional to its size.
    ind_p = [len(ind_data[x]) / len(all_data) for x in inds]
    if debug:
        print(np.sum(ind_p), ind_p)
    if isinstance(ws, (list, tuple)):
        assert len(ws) == len(data), \
            "len(ws)必须等于len(data), ws是list或者是tuple"
    # Normalize add_end to one flag per data stream.
    if isinstance(add_end, bool):
        add_end = [add_end] * len(data)
    else:
        assert (isinstance(add_end, (list, tuple))), \
            "add_end不是boolean,就应该是一个list(tuple) "
        assert len(add_end) == len(data), \
            "如果add_end是list(tuple),那么add_end的查那个度应该和输入数据长度是一致的"
    # Each stream yields (values, lengths); one extra slot when raw=True.
    mul = 2
    if raw:
        mul = 3
    while True:
        # Pick a bucket weighted by its size, then sample a batch inside it.
        choice_ind = np.random.choice(inds, p=ind_p)
        if debug:
            print('choice_ind', choice_ind)
        data_batch = random.sample(ind_data[choice_ind], batch_size)
        batches = [[] for i in range(len(data) * mul)]
        # Per-stream max length in this batch (+1 slot for END_TAG if added).
        max_lens = []
        for j in range(len(data)):
            max_len = max([
                len(x[j]) if hasattr(x[j], '__len__') else 0
                for x in data_batch
            ]) + (1 if add_end[j] else 0)
            max_lens.append(max_len)
        for d in data_batch:
            for j in range(len(data)):
                if isinstance(ws, (list, tuple)):
                    w = ws[j]
                else:
                    w = ws
                # Append the end-of-sequence tag.
                line = d[j]
                if add_end[j] and isinstance(line, (tuple, list)):
                    line = list(line) + [WordSequence.END_TAG]
                if w is not None:
                    x, xl = transform_sentence(line, w, max_lens[j], add_end[j])
                    batches[j * mul].append(x)
                    batches[j * mul + 1].append(xl)
                else:
                    batches[j * mul].append(line)
                    batches[j * mul + 1].append(line)
                if raw:
                    batches[j * mul + 2].append(line)
        batches = [np.asarray(x) for x in batches]
        yield batches
def test_batch_flow():
    """Smoke-test batch_flow: draw one batch of 4 and print the shapes."""
    from fake_data import generate
    xs, ys, ws_in, ws_out = generate(size=10000)
    batch = next(batch_flow([xs, ys], [ws_in, ws_out], 4))
    x, xl, y, yl = batch
    print(x.shape, y.shape, xl.shape, yl.shape)
def test_batch_flow_bucket():
    """Smoke-test batch_flow_bucket: draw ten batches and print the shapes."""
    from fake_data import generate
    xs, ys, ws_in, ws_out = generate(size=10000)
    flow = batch_flow_bucket([xs, ys], [ws_in, ws_out], 4, debug=True)
    for _ in range(10):
        x, xl, y, yl = next(flow)
        print(x.shape, y.shape, xl.shape, yl.shape)
if __name__ == '__main__':
    # size = 30000
    # print(_get_embed_device(size))
    # test_batch_flow()
    test_batch_flow_bucket()
# --- 2. Create the fake-data module fake_data.py ---
# (note: lowercase filename, to match "from fake_data import generate" above)
import random
import numpy as np
from word_sequence import WordSequence
def generate(max_len=10, size=1000, same_len=False, seed=0):
    """Generate fake parallel data for seq2seq testing.

    Each sample x is a random sequence of letter tokens, and y is its
    digit "translation" through a fixed dictionary.

    Args:
        max_len: maximum sequence length (actual length is 1..max_len).
        size: number of (x, y) pairs to generate.
        same_len: if True, y has exactly the same length as x; otherwise
            some target digits are duplicated so len(y) can exceed len(x).
        seed: random seed for reproducibility (None to leave RNG untouched).

    Returns:
        (x_data, y_data, ws_input, ws_target): the token sequences plus a
        fitted WordSequence vocabulary for each side.
    """
    dictionary = {
        'a': '1',
        'b': '2',
        'c': '3',
        'd': '4',
        'aa': '1',
        'bb': '2',
        'cc': '3',
        'dd': '4',
        'aaa': '1',
    }
    if seed is not None:
        random.seed(seed)
    input_list = sorted(list(dictionary.keys()))
    x_data = []
    y_data = []
    # FIX: the original used "for x in range(size)" and then immediately
    # shadowed x with "x = []"; use a throwaway loop variable instead.
    for _ in range(size):
        a_len = int(random.random() * max_len) + 1
        x = []
        y = []
        for _ in range(a_len):
            word = input_list[int(random.random() * len(input_list))]
            x.append(word)
            y.append(dictionary[word])
            if not same_len:
                # Duplicate some target tokens so len(y) can differ from
                # len(x).  NOTE(review): indentation was lost in the paste;
                # the '4' append is assumed to belong to the elif-'3'
                # branch — confirm against the original fake_data.py.
                if y[-1] == '2':
                    y.append('2')
                elif y[-1] == '3':
                    y.append('3')
                    y.append('4')
        x_data.append(x)
        y_data.append(y)
    ws_input = WordSequence()
    ws_input.fit(x_data)
    ws_target = WordSequence()
    ws_target.fit(y_data)
    return x_data, y_data, ws_input, ws_target
def test():
    """Sanity-check generate(): dataset sizes, max length, vocab sizes."""
    xs, ys, ws_in, ws_out = generate()
    print(len(xs))
    assert len(xs) == 1000
    print(len(ys))
    assert len(ys) == 1000
    longest = np.max([len(seq) for seq in xs])
    print(longest)
    assert longest == 10
    print(len(ws_in))
    assert len(ws_in) == 14
    print(len(ws_out))
if __name__ == '__main__':
    test()