使用paddle.batch自定义读取数据函数

自定义读取数据函数时,想要使用paddle.batch对数据进行读取时,train_reader需要返回list类型的data,且data[i]的类型为tuple,具体数据如下

train_reader = paddle.batch(paddle.reader.shuffle(reader('train'), buf_size=5e4), batch_size=8)
for batch_id, data in enumerate(train_reader()):
    imgs = np.reshape((x[0] for x in data), [-1, 1, 28, 28]).astype('float32')
    labels = np.reshape((x[1] for x in data), [-1,1]).astype('float32')
    imgs = np.array(imgs)
    labels = np.array(labels)
# 每个批次的data,<class 'list'>,shape=(8, 2)
[(array([0.68235296, 0.6039216 , 0.65882355, ..., 0.4627451 , 0.47058824,
       0.44705883], dtype=float32), 7), (array([0.99215686, 0.9882353 , 0.99215686, ..., 0.8901961 , 0.9607843 ,
       0.9137255 ], dtype=float32), 7), (array([0.49019608, 0.5058824 , 0.5137255 , ..., 0.49411765, 0.49019608,
       0.5058824 ], dtype=float32), 2), (array([0.30980393, 0.36078432, 0.2       , ..., 0.5372549 , 0.5176471 ,
       0.50980395], dtype=float32), 7), (array([0.58431375, 0.5764706 , 0.5921569 , ..., 0.53333336, 0.52156866,
       0.5176471 ], dtype=float32), 4), (array([0.44705883, 0.43137255, 0.43529412, ..., 0.43529412, 0.49411765,
       0.4627451 ], dtype=float32), 4), (array([0.25490198, 0.5686275 , 0.6745098 , ..., 0.6       , 0.7607843 ,
       0.7764706 ], dtype=float32), 9), (array([0.54901963, 0.5803922 , 0.67058825, ..., 0.84313726, 0.8862745 ,
       0.89411765], dtype=float32), 5)]

# data[0],<class 'tuple'>,shape=(2,)
(array([0.68235296, 0.6039216 , 0.65882355, ..., 0.4627451 , 0.47058824, 0.44705883], dtype=float32), 7)

 自定义的数据集制作函数:

import numpy as np

a = [[1,2,3,4],[5,6,7,8]]
b = [[1],[2]]
d = []
for i in range(len(a)):
    c = ()
    ai = [a[i]]
    bi = [b[i]]
    ai = np.array(ai).astype('float32')
    bi = np.array(bi)
    c = tuple(ai) + tuple(bi)
    print(c)
    d.append(c)
print(d)
print(type(d))
print(d[0])
print(type(d[0]))

 运行结果如下,可以满足上述的返回数据要求:

(array([1., 2., 3., 4.], dtype=float32), array([1]))
(array([5., 6., 7., 8.], dtype=float32), array([2]))
[(array([1., 2., 3., 4.], dtype=float32), array([1])), (array([5., 6., 7., 8.], dtype=float32), array([2]))]
<class 'list'>
(array([1., 2., 3., 4.], dtype=float32), array([1]))
<class 'tuple'>

依靠上述数据制作方法,对加载MNIST的json数据文件进行如下处理,制作满足paddle.batch读取的数据集类型

def reader(mode='train'):
    datafile = './work/mnist.json.gz'
    print('loading mnist dataset from {} ......'.format(datafile))
    # 加载json数据文件
    data = json.load(gzip.open(datafile))
    print('mnist dataset load done')
    if mode == 'train':
        return trainset_reader
    elif mode == 'valid':
        return valset_reader
    elif mode == 'test':
        return evalset_reader  

def trainset_reader():
    train_set,_,_ = data
    # print(np.array(train_set).T.shape) # (50000,2)
    imgs, labels = train_set[0], train_set[1]
    trainset = []
    for i in range(len(imgs)):
        tup = ()
        img_i = [imgs[i]]
        label_i = [labels[i]]
        img_i = np.array(img_i).astype('float32')
        label_i = np.array(label_i)
        tup = tuple(img_i) + tuple(label_i)
        trainset.append(tup)
    return trainset

定义数据生成函数,对读取的MNIST训练数据进行变形和转换处理

def data_generator(train_data):
    IMG_ROWS = 28
    IMG_COLS = 28
    imgs = []
    labels = []
    for x in train_data:
        imgs.append(np.reshape(x[0], [-1,28,28]).astype('float32'))
        labels.append(np.reshape(x[1], [-1]).astype('float32'))
    imgs = np.array(imgs)
    labels = np.array(labels)
    imgs = fluid.dygraph.to_variable(imgs)
    labels = fluid.dygraph.to_variable(labels)
    return imgs, labels

使用上述自定义的数据读取函数后,仅需对训练过程中加载和读取数据的那两行代码进行替换,其他和手写数字模型代码相同

model = MNIST()
model.train()
BATCHSIZE = 100
train_reader = paddle.batch(paddle.fluid.io.shuffle(reader('train'), buf_size=5e4), batch_size=BATCHSIZE) # 替换前为train_loader = load_data('train')
optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.001, parameter_list=model.parameters())
for batch_id, d in enumerate(train_reader()):
    # 读取数据的方式更加简洁
    image, label = data_generator(d) 
    # 替换前为image_data, label_data = d
    # image = fluid.dygraph.to_variable(image_data)
    # label = fluid.dygraph.to_variable(label_data)

 

 

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值