本文以 XO 图像集为例,使用 torch 实现简单图像分类。.
数据集网址:https://www.optophysiology.uni-freiburg.de/Research/research_DL/CNNsWithMatlabAndCaffe
可能是网址错了吧,找不到这个页面。。。我把数据集放在最后有兴趣的可以浅浅下载一下。
Let’s do it!
定义模型结构
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 9, 3) # 输入通道数:1,输出通道数:9,卷积核大小:3*3
self.maxpool = nn.MaxPool2d(2, 2) # 最大池化,2*2池化
self.conv2 = nn.Conv2d(9, 5, 3) # 输入通道数:9,输出通道数:1,卷积核大小:3*3
self.relu = nn.ReLU() # relu函数,非线性函数
self.fc1 = nn.Linear(27 * 27 * 5, 1200) # [((116-2)/2-2)/2]=27
self.fc2 = nn.Linear(1200, 64)
self.fc3 = nn.Linear(64, 2)
def forward(self, x):
x = self.maxpool(self.relu(self.conv1(x)))
x = self.maxpool(self.relu(self.conv2(x)))
x = x.view(-1, 27 * 27 * 5)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
这时卷积神经网络的效果图,后面的是多层感知机上的一些东西,不是讨论的重点。
训练模型
model = Net()
criterion = torch.nn.CrossEntropyLoss() # 损失函数 交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=0.1) # 优化函数:随机梯度下降
# 数据集加载
data_loader = DataLoader(
dataset=datasets.ImageFolder(
root='training_data_sm',
transform=transforms.Compose([
transforms.Grayscale(),
transforms.ToTensor()
])
),
batch_size=64,
shuffle=True
)
epochs = 10
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(data_loader):
images, label = data
out = model(images)
loss = criterion(out, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
if (i + 1) % 10 == 0:
print('[%d %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
running_loss = 0.0
print('finished train')
# 保存模型
torch.save(model, 'model_name.pth') # 保存的是模型, 不止是w和b权重值
由于torch框架已经发展成熟,因此torch会将反向传播给计算好,这时代码看起来就会和之前的MLP是差不多的。
测试模型
# 读取模型
model_load = torch.load('model_name.pth')
correct = 0
total = 0
with torch.no_grad(): # 进行评测的时候网络不更新梯度
for data in data_loader: # 读取测试集
images, labels = data
outputs = model_load(images)
_, predicted = torch.max(outputs.data, 1) # 取出 最大值的索引 作为 分类结果
total += labels.size(0) # labels 的长度
correct += (predicted == labels).sum().item() # 预测正确的数目
print('Accuracy of the network on the test images: %f %%' % (100. * correct / total))
查看特征图
# 看看每层的 卷积核 长相,特征图 长相
# 获取网络结构的特征矩阵并可视化
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL