一、准备数据
1、建立样本数据读取路径与样本标签之间的关系。
2、构造读取器与数据预处理。自定义数据读取器,继承PaddlePaddle2.0的dataset类,在__getitem__方法中把自定义的预处理方法加载进去。
二、建立模型
先选用比较成熟的基础模型,看看基础模型所能够达到的准确度。之后再试试模型融合,准确度是否有提升。
三、应用高阶API训练模型
1、定义输入数据形状大小和数据类型。
2、实例化模型。如果要用高阶API,需要用Paddle.Model()对模型进行封装,如model = paddle.Model(model,inputs=input_define,labels=label_define)。
3、定义优化器。这个使用Adam优化器,学习率设置为0.0001,优化器中的学习率(learning_rate)参数很重要。要是训练过程中得到的准确率呈震荡状态,忽大忽小,可以试试进一步把学习率调低。
4、准备模型。这里用到高阶API,model.prepare()。
5、训练模型。这里用到高阶API,model.fit()。
四、应用已经训练好的模型进行预测
1、构建数据读取器。因为预测数据集没有标签,该读取器写法和训练数据读取器不一样,建议重新写一个类,继承于Dataset基类。
2、实例化模型。如果要用高阶API,需要用Paddle.Model()对模型进行封装,如paddle.Model(MyNet(),inputs=input_define),由于是预测模型,所以仅设定输入数据格式就好了。
3、读取刚刚训练好的参数。这个保存在/home/aistudio/work目录之下,如果指定的是final则是最后一轮训练后的结果。可以指定其他轮次的结果,比如model.load('/home/aistudio/work/30'),这里用到了高阶API,model.load()
4、准备模型。这里用到高阶API,model.prepare()。
5、读取待预测集合中的数据,利用已经训练好的模型进行预测。
6、结果保存。