import torch
import divide
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable
from model import Net
epochs = 10 #迭代训练五次
cirterion = nn.CrossEntropyLoss() #定义损失函数
optimizer = optim.SGD(Net.parameters(), lr=0.0001, momentum=0.9) #定义优化器
for epoch in range(epochs):
running_loss = 0 #损失值
train_correct = 0 #分类样本正确的总数
train_total = 0 #分类样本总数
for i, data in enumerate(divide.train_loader, 0):
inputs, train_labels = data
inputs, labels = Variable(inputs), Variable(train_labels)
optimizer.zero_grad()
outputs = Net(inputs)
_, train_predicted = torch.max(outputs.data, 1)
itrain_correct += (train_predicted == labels.data).sum()
loss = cirterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
train_total += train_labels.size(0)
print('train %d epoch loss: %.3f acc: %.3f ' % (
epoch + 1, running_loss / train_total, 100 * train_correct / train_total))
print('finished training!')
运行:
train 1 epoch loss: 0.172 acc: 56.029
train 2 epoch loss: 0.163 acc: 62.057
train 3 epoch loss: 0.155 acc: 65.274
train 4 epoch loss: 0.149 acc: 68.006
train 5 epoch loss: 0.144 acc: 70.131
train 6 epoch loss: 0.139 acc: 71.749
train 7 epoch loss: 0.135 acc: 72.994
train 8 epoch loss: 0.130 acc: 74.360
train 9 epoch loss: 0.126 acc: 75.486
train 10 epoch loss: 0.122 acc: 76.520
finished training!
- 定义epoch:迭代训练的次数
- 定义cirterion、optimizer:损失函数、优化器
- 使用循环语句进行数据的迭代训练,一共迭代训练epoch次
- runing_loss:总损失值、train_correct:分类正确总数量、train_total:训练图片总数量
- enumerate返回两个值,i是下标,由enumerate()函数给定,data表示数据,数据包括Tensor(图像)与标签
- input存储Tensor,train_labels存储标签,将这两个参数设置为Variable,表示参数存储的值是随着训练不断更新的
- zero_grad()将梯度置零,将损失值关于权重的导数置零,每次训练不同的batch,将上一个batch反向传播训练的导数置零,也就是优化器权重的清除
- Tensor传入神经网络模型Net
- torch.max返回两个值,第一个值是tensor中每行最大值,第二个值是最大值的索引,最大值就是神经网络判断属于各种类别概率值中的最大概率值,索引就是这个概率值最大的类别
- 用train_predicted存储这个索引,即神经网络判断出的类别,概率不记录
- 如果这个类别与实际的类别相同即分类正确,我们将train_correct进行+1操作,记录分类正确图片总数量
- 将神经网络判断出的类别与正确类别传入loss中计算损失值,再进行loss的反向传播,使用step()进行优化器权重的更新
- 将损失值相加记录总损失值
- size(0)表示train_labels的第一维度,就是每次训练图片的数量:batch_size的大小,求和并由train_total记录,表示训练图片总数量
- 输出:running_loss / train_total为平均的损失值,train_correct / train_total为正确率