前言
本文所依赖环境如下,主要还是paddlepaddle和cifar数据集,cifar数据集在上一章讲过,就不重复了:
paddlepaddle(2.3.1)
cifar数据集
1.简单的自定义模型
话不多说,先放一个适配cifar数据集,最简单的自定义模型:
class Net(paddle.nn.Layer):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Sequential(
nn.Linear(3072, 10),
nn.Softmax())
def forward(self, image, label):
image = paddle.reshape(image, (1, -1))
return self.fc(image), label
可以看到这就是一个fc全连接层+一个softmax的结果。
全连接层的in配置的3072是因为cifar的图片大小是32*32*3,这里我们单个图片对其reshape到一维就是3072的长度。
接下来,先配置所需要的cifar数据集,代码如下:
import paddle
from paddle.vision.transforms import Normalize,Compose,Transpose,Resize
transform = Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], dat