mxnet中,每个输入到RNN的batch 长度可不一样,但是mxnet中要求 batch内的长度是一样的。这里采用的方法是,利用 gluonnlp让每个batch内 seq的长度尽量一样。
假设进行的是 单输入的分类任务
import gluonnlp as nlp
from mxnet.gluon import data as gdata
'''
获取句子长度
'''
train_data_lengths=list()
for q in querys: # querys是list数据,是载入的序列数据,querys每个元素是一个样本,每个样本是一个list,list的元素是 单词的编号
train_data_lengths.append(len(q))
'''
准备数据
'''
train_gdata=gdata.ArrayDataset(querys, labels) # labels 是list数据,是类别标签,labels每个元素是一个样本的标签编号
'''
准备处理工具 batchify_fn
nlp.data.batchify.Tuple 把工具整合起来
nlp.data.batchify.Pad 处理 序列数据的pad
nlp.data.batchify.Stack 把标签数据整理成 mxnet.ndarray
'''
batchify_fn = nlp.data.batchify.Tuple(
nlp.data.batchify.Pad(axis=0,pad_val=0), # 处理 序列数据的pad,得到的是mxnet.ndarray
nlp.data.batchify.Stack() # 处理标签数据,得到的是mxnet.ndarray
)
'''
准备batch采样工具 batch_sampler
'''
batch_size=16
# batch_sampler = nlp.data.sampler.FixedBucketSampler(train_data_lengths,batch_size=batch_size,num_buckets=10,ratio=0.5,shuffle=True) # 不一定能得到 大小为batch_size的batch
batch_sampler=nlp.data.sampler.SortedBucketSampler(train_data_lengths,batch_size=batch_size,mult=100,shuffle=True) # 一定能得到 大小为batch_size的batch
'''
准备DataLoader
'''
train_dataloader = gluon.data.DataLoader(train_gdata,batch_sampler=batch_sampler,batchify_fn=batchify_fn)
step=0
for query,label in train_dataloader:
print query.shape,label.shape
假设进行的是 多输入的生成任务
import gluonnlp as nlp
from mxnet.gluon import data as gdata
'''
获取句子长度 !需要注意, train_data_lengths的每个元素需要换成 tuple of int or list of int,表示一个样本中多个序列的长度
'''
train_data_lengths=list()
# querys、reply1s、reply2s、target_replies都是list数据,是序列数据,每个元素是一个样本,每个样本是一个list,list的元素是 单词的编号
for i in range(len(train_data.queries)):
length1=len(querys[i])
length2=len(reply1s[i])
length3=len(reply2s[i])
length4=len(target_replies[i])
train_data_lengths.append((length1, length2, length3, length4))
'''
准备数据
'''
train_gdata=gdata.ArrayDataset(querys, reply1s, reply2s, target_replies) # labels 是list数据,是类别标签,labels每个元素是一个样本的标签编号
'''
准备处理工具 batchify_fn
nlp.data.batchify.Tuple 把工具整合起来
nlp.data.batchify.Pad 处理 序列数据的pad
nlp.data.batchify.Stack 把标签数据整理成 mxnet.ndarray
'''
batchify_fn = nlp.data.batchify.Tuple( # 有多少个序列数据,就有多少个nlp.data.batchify.Pad
nlp.data.batchify.Pad(axis=0,pad_val=0), # 处理 序列数据的pad,得到的是mxnet.ndarray
nlp.data.batchify.Pad(axis=0,pad_val=0), # 处理 序列数据的pad,得到的是mxnet.ndarray
nlp.data.batchify.Pad(axis=0,pad_val=0), # 处理 序列数据的pad,得到的是mxnet.ndarray
nlp.data.batchify.Pad(axis=0,pad_val=0) # 处理 序列数据的pad,得到的是mxnet.ndarray
)
'''
准备batch采样工具 batch_sampler
'''
batch_size=16
# batch_sampler = nlp.data.sampler.FixedBucketSampler(train_data_lengths,batch_size=batch_size,num_buckets=10,ratio=0.5,shuffle=True) # 不一定能得到 大小为batch_size的batch
batch_sampler=nlp.data.sampler.SortedBucketSampler(train_data_lengths,batch_size=batch_size,mult=100,shuffle=True) # 一定能得到 大小为batch_size的batch
'''
准备DataLoader
'''
train_dataloader = gluon.data.DataLoader(train_gdata,batch_sampler=batch_sampler,batchify_fn=batchify_fn)
step=0
for query, reply1,reply2,target_reply in train_dataloader:
print query.shape, reply1.shape, reply2.shape, target_reply.shape