模型构建

基于百度飞桨的手写体数字识别实现(一)数据预处理

使用百度的飞桨框架,总结一些普遍的规律和此框架的简单使用方法

手写体数字识别任务

1. 整体流程

①数据处理:读取数据和预处理操作
②模型设计:网络结构(假设)
③训练配置:优化器(寻解算法)和计算资源配置
④训练过程:循环调用训练过程,前向计算+损失函数(优化目标)+后向传播
⑤保存模型:将训练好的模型保存
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gb7EQNaw-1597576573000)(框架.png )]

2. 读入数据和飞桨API查询方法

①加载飞桨和相关类库
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
import numpy as np
import os
from PIL import Image
②使用飞桨框架提供的Mnist数据集处理函数
③paddle.dataset.mnist.train()
④常见的学术数据集均有现成处理函数(查API可见)

查阅API的方法
a.搜索:在飞桨官网https://aistudio.baidu.com进行查阅
b.分类浏览:在飞桨的API功能分类中寻找
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WE09RAd2-1597576573002)(API.png )]

3. 模型设计、训练和测试

步骤
①声明实例
②加载参数
③灌入数据
④打印结果
注意点
①图片数据归一化
②正确设置路径
③模型“校验”状态

4. 处理数据

####分析数据集结构,并拆分训练集和测试集
处理数据的五大操作

  1. 读入数据
  2. 拆分样本集合(分成训练集、验证集和测试集)
  3. 训练样本集乱序
  4. 生成批次数据
  5. 校验数据有效性

完整处理流程和异步读取数据

  1. 训练样本集乱序
    a. 建立ID集 index_list
    b. 乱序 index_list
    c. 以新顺序读取数据
  2. 生成批次数据
    a. 设置batchsize
    b. 数据转变成符合要求的np.array格式
    c. Python生成器:yield,减少内存占用
  3. 校验数据的有效性
    a. 校验并刨除不合预期的数据
    b. 某些案例中图片数量和标签数量相同

异步读取VS同步读取
同步读取:IO和网络计算串行,速度慢
异步读取:IO和计算通过一个"异步队列"交互,IO把数据不停放入队列,网络计算不停的从队列取 数据,二者同时进行。

PyReader
飞桨提供的异步数据读取器,只需要修改两行代码
创建一个DataLoader对象用于加载Python生成器产生的数据,数据会由Python线程预先读取,并异步送入设定了容量上限的队列中。

#定义DataLoader对象用于加载Python生成器产生的数据
data_loader = fuild.io.DataLoader.from_generator(capacity=5,return_list=True)
#设置数据生成器
data_loader.set_batch_generator(train_loader,places=place)

5. 数据预处理实战演练

题目要求:查询API文档,写一个cifar-10数据集的数据读取器,并执行乱序,分批次读取,打印第一个batch数据的shape、类型信息。

#加载飞桨和相关数据处理的库
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
import numpy as np
import os
import random

#设置数据读取器,读取cifar-10数据训练集
trainset = paddle.dataset.cifar.train10()
#包装数据读取器,每次读取的数据数量设置为batch_size=5
train_reader = paddle.batch(trainset, batch_size=5)

#以迭代的形式读取数据
for batch_id,data in enumerate(train_reader()):
    #获取图像数据,并转为float32类型
    img_data = np.array([x[0] for x in data]).astype('float32')
    #获取图像标签数据,并转为float32类型
    label_data = np.array([x[1] for x in data]).astype('float32')
    #打印数据形状
    print("图像数据形状和对应数据为:", img_data.shape, img_data[0])
    print("图像标签形状和对应数据为:", label_data.shape, label_data[0])
    break
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(label_data[0]))
#打乱顺序
imgs_length = len(img_data)
#定义数据集每个数据的序号,根据序号读取数据
index_list = list(range(imgs_length))
random.shuffle(index_list)
imgs_list = []
for i in index_list:
    img = np.array(img_data[i]+1)*127.5
    img = np.reshape(img,[3,32,32]).astype(np.uint8)
    img = np.transpose(img,(1,2,0))
    imgs_list.append(img)
#显示第一个batch的第一个图像
import matplotlib.pyplot as plt
print(img.shape)
print(img.dtype.name)
plt.figure("This is the first picture of cifar10")
plt.imshow(img)
plt.axis('on')
plt.title('image')
plt.show()
(result.png )]
far10")
plt.imshow(img)
plt.axis('on')
plt.title('image')
plt.show()

输出结果为:
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值