LeNet 网络结构的实现见这篇文章。下面使用该网络结构训练手写数字识别模型,并保存每个 epoch 的权重。
import paddle
import numpy as np
import paddle.nn.functional as F
from paddle.vision.transforms import ToTensor
from paddle.vision.datasets import MNIST
from LeNet import LeNet
def train(model, opt, train_loader, valid_loader):
    """Train *model* for EPOCH_NUM epochs, validating and checkpointing each epoch.

    Args:
        model: paddle.nn.Layer to train (e.g. LeNet).
        opt: paddle optimizer already bound to model.parameters().
        train_loader: DataLoader yielding (image, label) training batches.
        valid_loader: DataLoader yielding (image, label) validation batches.

    Side effects:
        Prints per-batch training loss and per-epoch validation metrics;
        saves 'mnist_epXXX.pdparams' after every epoch.

    Note:
        Reads the module-level constant EPOCH_NUM for the epoch count.
        Device selection is left to paddle's default; call
        paddle.set_device('gpu:0') before train() to run on GPU.
    """
    print("start training...")
    # Build the loss once instead of re-creating it on every batch.
    loss_func = paddle.nn.CrossEntropyLoss(reduction='none')
    model.train()
    for epoch in range(EPOCH_NUM):
        # --- training pass ---
        for batch_id, data in enumerate(train_loader()):
            img, label = data[0], data[1]
            logits = model(img)
            loss = loss_func(logits, label)
            avg_loss = paddle.mean(loss)
            if batch_id % 2000 == 0:
                print(f"epoch: {epoch}, batch_id: {batch_id}, loss is: {float(avg_loss.numpy()):.4f}")
            avg_loss.backward()   # back-propagate
            opt.step()            # apply gradient update
            opt.clear_grad()      # reset gradients for the next batch

        # --- validation pass after each epoch ---
        model.eval()
        accuracies = []
        losses = []
        # no_grad: validation needs no autograd graph, saving time and memory.
        with paddle.no_grad():
            for batch_id, data in enumerate(valid_loader):
                img, label = data[0], data[1]
                logits = model(img)
                pred = F.softmax(logits)
                loss = loss_func(logits, label)
                acc = paddle.metric.accuracy(pred, label)
                accuracies.append(acc.numpy())
                losses.append(loss.numpy())
        print(f"[validation] epoch: {epoch}, accuracy/loss: {np.mean(accuracies):.4f}/{np.mean(losses):.4f}")

        model.train()  # back to train mode for the next epoch
        # Persist this epoch's weights.
        paddle.save(model.state_dict(), f'mnist_ep{epoch:03d}.pdparams')
if __name__ == '__main__':
    # Epoch count; train() reads this module-level name.
    EPOCH_NUM = 5

    # LeNet classifier with one output per digit class (0-9).
    model = LeNet(num_classes=10)

    # Momentum SGD optimizer bound to the model's parameters.
    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.001, momentum=0.9, parameters=model.parameters())

    # MNIST datasets and their data loaders.
    mnist_train = MNIST(mode='train', transform=ToTensor())
    mnist_test = MNIST(mode='test', transform=ToTensor())
    loader_train = paddle.io.DataLoader(mnist_train, batch_size=10, shuffle=True)
    loader_valid = paddle.io.DataLoader(mnist_test, batch_size=10)

    # Kick off the training loop.
    train(model, optimizer, loader_train, loader_valid)
结果:
start training...
epoch: 0, batch_id: 0, loss is: 2.6200
epoch: 0, batch_id: 2000, loss is: 2.2494
epoch: 0, batch_id: 4000, loss is: 2.0231
[validation] epoch: 0, accuracy/loss: 0.7943/0.8004
epoch: 1, batch_id: 0, loss is: 0.9391
epoch: 1, batch_id: 2000, loss is: 0.5264
epoch: 1, batch_id: 4000, loss is: 0.3571
[validation] epoch: 1, accuracy/loss: 0.9157/0.2988
epoch: 2, batch_id: 0, loss is: 0.1449
epoch: 2, batch_id: 2000, loss is: 0.3414
epoch: 2, batch_id: 4000, loss is: 0.1088
[validation] epoch: 2, accuracy/loss: 0.9431/0.2011
epoch: 3, batch_id: 0, loss is: 0.3511
epoch: 3, batch_id: 2000, loss is: 0.0477
epoch: 3, batch_id: 4000, loss is: 0.0648
[validation] epoch: 3, accuracy/loss: 0.9555/0.1607
epoch: 4, batch_id: 0, loss is: 0.2642
epoch: 4, batch_id: 2000, loss is: 0.0261
epoch: 4, batch_id: 4000, loss is: 0.0641
[validation] epoch: 4, accuracy/loss: 0.9627/0.1284
Process finished with exit code 0