A Complete PyTorch Image Classification Pipeline

Introduction: This article walks through a complete image classification workflow in PyTorch. Because the images in my application have simple features and computation speed matters, the network is kept very small (the final production model stays confidential, of course), and an SPP (spatial pyramid pooling) layer is added so the model itself places no constraint on the input image size.

My training dataset is organized as follows:

[Figure: dataset directory structure, with train and val folders, each containing one subfolder per class (neg, pos)]

A minimal dataset-split sketch is provided below for reference.
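Since the split script itself is not part of this article, here is a minimal sketch of one way to produce this train/val layout, assuming one flat source folder per class; src_dir, train_ratio, and the example paths are illustrative, not from the original:

import os, random, shutil

def split_class(src_dir, train_dir, val_dir, train_ratio=0.8, seed=0):
    # Randomly copy one class's images into train/val folders (sketch).
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    files = sorted(os.listdir(src_dir))
    random.Random(seed).shuffle(files)
    cut = int(len(files) * train_ratio)
    for name in files[:cut]:
        shutil.copy(os.path.join(src_dir, name), train_dir)
    for name in files[cut:]:
        shutil.copy(os.path.join(src_dir, name), val_dir)

# e.g. split_class(r"K:\imageData\polarity\raw\pos",
#                  r"K:\imageData\polarity\data3\train\pos",
#                  r"K:\imageData\polarity\data3\val\pos")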


The complete PyTorch image classification workflow is as follows:

  • Import dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import torchvision
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
  • Define the model; thanks to the SPP layer, it places no requirement on the input image size
class net(nn.Module):
    def __init__(self, channels=3, height=128, width=128, numLevels=3):
        super(net, self).__init__()  # initialize the parent class
        self.numLevels = numLevels
        self.conv1 = nn.Conv2d(channels, 16, 3)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 64, 3)
        # 896 input features are fixed by the 3-level SPP, independent of image size
        self.fc1 = nn.Linear(896, 64)
        self.fc2 = nn.Linear(64, 2)
        
    def SPPLayer(self, x):
        num, c, h, w = x.size()  # batch size, channels, height, width
        for i in range(self.numLevels):
            level = i + 1
            kernel_size = (math.ceil(h / level), math.ceil(w / level))
            stride = kernel_size
            padding = (math.floor((kernel_size[0] * level - h + 1) / 2),
                       math.floor((kernel_size[1] * level - w + 1) / 2))
            tensor = F.max_pool2d(x, kernel_size=kernel_size, stride=stride,
                                  padding=padding).view(num, -1)

            # flatten each pyramid level and concatenate
            if i == 0:
                x_flatten = tensor
            else:
                x_flatten = torch.cat((x_flatten, tensor), 1)
        return x_flatten
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.SPPLayer(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
        
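Where the 896 in fc1 comes from: pyramid level l pools each channel into l × l bins, so 3 levels give 1 + 4 + 9 = 14 bins per channel, and conv2 outputs 64 channels. The flattened SPP vector therefore always has 64 × 14 = 896 entries, whatever the input resolution:

# Sanity check on the SPP feature width (my addition):
# levels 1, 2, 3 contribute 1, 4, and 9 bins per channel.
assert 64 * (1 * 1 + 2 * 2 + 3 * 3) == 896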
  • Test the model; if it produces output without errors, the architecture is wired correctly
image_w, image_h = 125, 127
model = net(3, image_w, image_h)
x = torch.ones(1, 3, image_w, image_h)
y = model(x)
y.size()
torch.Size([1, 2])
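Because the pooling windows scale with the feature map, the same model also accepts other input sizes; a quick check (my addition, with two arbitrary resolutions):

# The output shape stays [1, 2] across input resolutions (my addition).
for h, w in [(64, 96), (200, 150)]:
    assert model(torch.ones(1, 3, h, w)).size() == torch.Size([1, 2])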
  • Define a function that checks whether an image file is valid
def check_image(path):
    try:
        Image.open(path)
        return True
    except Exception:  # unreadable or non-image file
        return False
  • Define the image preprocessing transforms
img_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.3, 0.3, 0.3])
])
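The mean/std values above are round placeholders; per-channel statistics can instead be estimated from the training images. A minimal sketch (my addition), meant to run on a dataset built with only ToTensor() so pixel values are still in [0, 1]; batch_size=1 keeps variable-size images collate-safe:

def channel_stats(dataset):
    # Accumulate per-channel pixel sums over the whole dataset (sketch).
    loader = torch.utils.data.DataLoader(dataset, batch_size=1)
    n_pixels = 0
    total = torch.zeros(3)
    total_sq = torch.zeros(3)
    for image, _ in loader:
        n_pixels += image.numel() / image.size(1)   # pixels per channel
        total += image.sum(dim=[0, 2, 3])
        total_sq += (image ** 2).sum(dim=[0, 2, 3])
    mean = total / n_pixels
    std = (total_sq / n_pixels - mean ** 2).sqrt()
    return mean, std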
  • Define the training, test, and validation datasets
train_data_path = r"K:\imageData\polarity\data3\train"
train_data = torchvision.datasets.ImageFolder(root=train_data_path,transform=img_transforms,is_valid_file=check_image)
#test_data_path = r"K:\imageData\polarity\data2\test"
#test_data = torchvision.datasets.ImageFolder(root=test_data_path,transform=img_transforms,is_valid_file=check_image)
val_data_path = r"K:\imageData\polarity\data3\val"
val_data = torchvision.datasets.ImageFolder(root=val_data_path,transform=img_transforms,is_valid_file=check_image)
  • Define the data loaders
batch_size = 32
# shuffle=True so training batches are drawn in random order
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
#test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
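One caveat worth noting (my addition): the default collate function can only stack images of identical size into a batch. SPP removes the size constraint inside the model, not inside the loader, so if the raw images vary in size, either add a Resize to img_transforms or train with batch_size=1:

# Sketch (my addition): a loader that tolerates mixed image sizes,
# at the cost of one image per step.
variable_size_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)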
  • Define the training loop
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=30, device=torch.device("cpu")):
    train_loss_list = []
    valid_loss_list = []
    valid_accuracy_list = []
    epoch_list = []
    for epoch in range(1, epochs + 1):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)

        model.eval()
        num_correct = 0
        num_examples = 0
        with torch.no_grad():  # no gradients needed during validation
            for batch in val_loader:
                inputs, targets = batch
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)
                loss = loss_fn(outputs, targets)
                valid_loss += loss.item() * inputs.size(0)
                correct = torch.eq(torch.max(F.softmax(outputs, dim=1), dim=1)[1], targets)
                num_correct += torch.sum(correct).item()
                num_examples += correct.shape[0]
        valid_loss /= len(val_loader.dataset)
        valid_accuracy = num_correct / num_examples

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(
            epoch, training_loss, valid_loss, valid_accuracy))

        train_loss_list.append(training_loss)
        valid_loss_list.append(valid_loss)
        valid_accuracy_list.append(valid_accuracy)
        epoch_list.append(epoch)

    return train_loss_list, valid_loss_list, valid_accuracy_list, epoch_list
  • Define the loss function, optimizer, and device (CrossEntropyLoss expects raw logits, which is why forward() applies no softmax)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
loss_fn = torch.nn.CrossEntropyLoss()

if torch.cuda.is_available():
    device = torch.device("cuda") 
else:
    device = torch.device("cpu")

model.to(device)
net(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=896, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=2, bias=True)
)
  • Inspect the number of model parameters
def get_parameter_number(net):
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}
get_parameter_number(model)
{'Total': 67266, 'Trainable': 67266}
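The 67,266 total can be checked by hand (my addition): a Conv2d layer has in_channels × out_channels × k × k weights plus out_channels biases, and a Linear layer has in × out weights plus out biases:

# Hand check of the parameter count (my addition):
conv1 = 3 * 16 * 3 * 3 + 16     # = 448
conv2 = 16 * 64 * 3 * 3 + 64    # = 9,280
fc1 = 896 * 64 + 64             # = 57,408
fc2 = 64 * 2 + 2                # = 130
assert conv1 + conv2 + fc1 + fc2 == 67266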
  • Train
train_loss_list, valid_loss_list, valid_accuracy_list, epoch_list = \
    train(model, optimizer, loss_fn, train_data_loader, val_data_loader, epochs=100, device=device)
Epoch: 1, Training Loss: 0.76, Validation Loss: 0.71, accuracy = 0.44
Epoch: 2, Training Loss: 0.70, Validation Loss: 0.69, accuracy = 0.44
Epoch: 3, Training Loss: 0.68, Validation Loss: 0.68, accuracy = 0.52
Epoch: 4, Training Loss: 0.68, Validation Loss: 0.67, accuracy = 0.72
Epoch: 5, Training Loss: 0.67, Validation Loss: 0.65, accuracy = 0.70
Epoch: 6, Training Loss: 0.66, Validation Loss: 0.64, accuracy = 0.70
Epoch: 7, Training Loss: 0.66, Validation Loss: 0.62, accuracy = 0.72
Epoch: 8, Training Loss: 0.65, Validation Loss: 0.61, accuracy = 0.72
Epoch: 9, Training Loss: 0.64, Validation Loss: 0.59, accuracy = 0.68
Epoch: 10, Training Loss: 0.63, Validation Loss: 0.58, accuracy = 0.70
Epoch: 11, Training Loss: 0.61, Validation Loss: 0.56, accuracy = 0.68
Epoch: 12, Training Loss: 0.60, Validation Loss: 0.55, accuracy = 0.68
Epoch: 13, Training Loss: 0.59, Validation Loss: 0.54, accuracy = 0.68
Epoch: 14, Training Loss: 0.58, Validation Loss: 0.53, accuracy = 0.66
Epoch: 15, Training Loss: 0.58, Validation Loss: 0.52, accuracy = 0.66
Epoch: 16, Training Loss: 0.57, Validation Loss: 0.51, accuracy = 0.68
Epoch: 17, Training Loss: 0.56, Validation Loss: 0.50, accuracy = 0.68
Epoch: 18, Training Loss: 0.56, Validation Loss: 0.50, accuracy = 0.68
Epoch: 19, Training Loss: 0.55, Validation Loss: 0.49, accuracy = 0.72
Epoch: 20, Training Loss: 0.55, Validation Loss: 0.49, accuracy = 0.76
Epoch: 21, Training Loss: 0.55, Validation Loss: 0.49, accuracy = 0.76
Epoch: 22, Training Loss: 0.54, Validation Loss: 0.48, accuracy = 0.78
Epoch: 23, Training Loss: 0.54, Validation Loss: 0.48, accuracy = 0.78
Epoch: 24, Training Loss: 0.53, Validation Loss: 0.47, accuracy = 0.78
Epoch: 25, Training Loss: 0.53, Validation Loss: 0.47, accuracy = 0.78
Epoch: 26, Training Loss: 0.53, Validation Loss: 0.47, accuracy = 0.78
Epoch: 27, Training Loss: 0.53, Validation Loss: 0.47, accuracy = 0.78
Epoch: 28, Training Loss: 0.52, Validation Loss: 0.46, accuracy = 0.78
Epoch: 29, Training Loss: 0.52, Validation Loss: 0.46, accuracy = 0.78
Epoch: 30, Training Loss: 0.52, Validation Loss: 0.46, accuracy = 0.78
Epoch: 31, Training Loss: 0.51, Validation Loss: 0.45, accuracy = 0.78
Epoch: 32, Training Loss: 0.51, Validation Loss: 0.45, accuracy = 0.78
Epoch: 33, Training Loss: 0.51, Validation Loss: 0.45, accuracy = 0.78
Epoch: 34, Training Loss: 0.50, Validation Loss: 0.45, accuracy = 0.78
Epoch: 35, Training Loss: 0.50, Validation Loss: 0.44, accuracy = 0.78
Epoch: 36, Training Loss: 0.50, Validation Loss: 0.44, accuracy = 0.80
Epoch: 37, Training Loss: 0.49, Validation Loss: 0.44, accuracy = 0.80
Epoch: 38, Training Loss: 0.49, Validation Loss: 0.44, accuracy = 0.80
Epoch: 39, Training Loss: 0.49, Validation Loss: 0.43, accuracy = 0.82
Epoch: 40, Training Loss: 0.48, Validation Loss: 0.43, accuracy = 0.82
Epoch: 41, Training Loss: 0.48, Validation Loss: 0.43, accuracy = 0.82
Epoch: 42, Training Loss: 0.48, Validation Loss: 0.42, accuracy = 0.82
Epoch: 43, Training Loss: 0.47, Validation Loss: 0.42, accuracy = 0.82
Epoch: 44, Training Loss: 0.47, Validation Loss: 0.42, accuracy = 0.82
Epoch: 45, Training Loss: 0.46, Validation Loss: 0.41, accuracy = 0.82
Epoch: 46, Training Loss: 0.46, Validation Loss: 0.41, accuracy = 0.82
Epoch: 47, Training Loss: 0.46, Validation Loss: 0.41, accuracy = 0.82
Epoch: 48, Training Loss: 0.45, Validation Loss: 0.40, accuracy = 0.82
Epoch: 49, Training Loss: 0.45, Validation Loss: 0.40, accuracy = 0.86
Epoch: 50, Training Loss: 0.44, Validation Loss: 0.40, accuracy = 0.88
Epoch: 51, Training Loss: 0.44, Validation Loss: 0.39, accuracy = 0.88
Epoch: 52, Training Loss: 0.43, Validation Loss: 0.39, accuracy = 0.88
Epoch: 53, Training Loss: 0.43, Validation Loss: 0.38, accuracy = 0.88
Epoch: 54, Training Loss: 0.42, Validation Loss: 0.38, accuracy = 0.88
Epoch: 55, Training Loss: 0.42, Validation Loss: 0.38, accuracy = 0.88
Epoch: 56, Training Loss: 0.41, Validation Loss: 0.37, accuracy = 0.88
Epoch: 57, Training Loss: 0.41, Validation Loss: 0.37, accuracy = 0.88
Epoch: 58, Training Loss: 0.40, Validation Loss: 0.36, accuracy = 0.88
Epoch: 59, Training Loss: 0.40, Validation Loss: 0.36, accuracy = 0.88
Epoch: 60, Training Loss: 0.39, Validation Loss: 0.36, accuracy = 0.88
Epoch: 61, Training Loss: 0.39, Validation Loss: 0.35, accuracy = 0.90
Epoch: 62, Training Loss: 0.38, Validation Loss: 0.35, accuracy = 0.90
Epoch: 63, Training Loss: 0.38, Validation Loss: 0.34, accuracy = 0.90
Epoch: 64, Training Loss: 0.37, Validation Loss: 0.34, accuracy = 0.90
Epoch: 65, Training Loss: 0.37, Validation Loss: 0.33, accuracy = 0.90
Epoch: 66, Training Loss: 0.36, Validation Loss: 0.33, accuracy = 0.92
Epoch: 67, Training Loss: 0.35, Validation Loss: 0.32, accuracy = 0.92
Epoch: 68, Training Loss: 0.35, Validation Loss: 0.32, accuracy = 0.92
Epoch: 69, Training Loss: 0.34, Validation Loss: 0.31, accuracy = 0.92
Epoch: 70, Training Loss: 0.33, Validation Loss: 0.31, accuracy = 0.94
Epoch: 71, Training Loss: 0.33, Validation Loss: 0.30, accuracy = 0.94
Epoch: 72, Training Loss: 0.32, Validation Loss: 0.30, accuracy = 0.94
Epoch: 73, Training Loss: 0.32, Validation Loss: 0.29, accuracy = 0.94
Epoch: 74, Training Loss: 0.31, Validation Loss: 0.29, accuracy = 0.94
Epoch: 75, Training Loss: 0.30, Validation Loss: 0.28, accuracy = 0.94
Epoch: 76, Training Loss: 0.30, Validation Loss: 0.28, accuracy = 0.94
Epoch: 77, Training Loss: 0.29, Validation Loss: 0.27, accuracy = 0.94
Epoch: 78, Training Loss: 0.28, Validation Loss: 0.27, accuracy = 0.94
Epoch: 79, Training Loss: 0.28, Validation Loss: 0.26, accuracy = 0.94
Epoch: 80, Training Loss: 0.27, Validation Loss: 0.26, accuracy = 0.94
Epoch: 81, Training Loss: 0.26, Validation Loss: 0.25, accuracy = 0.94
Epoch: 82, Training Loss: 0.26, Validation Loss: 0.25, accuracy = 0.94
Epoch: 83, Training Loss: 0.25, Validation Loss: 0.24, accuracy = 0.94
Epoch: 84, Training Loss: 0.25, Validation Loss: 0.24, accuracy = 0.94
Epoch: 85, Training Loss: 0.24, Validation Loss: 0.23, accuracy = 0.94
Epoch: 86, Training Loss: 0.23, Validation Loss: 0.23, accuracy = 0.94
Epoch: 87, Training Loss: 0.23, Validation Loss: 0.22, accuracy = 0.94
Epoch: 88, Training Loss: 0.22, Validation Loss: 0.22, accuracy = 0.94
Epoch: 89, Training Loss: 0.22, Validation Loss: 0.21, accuracy = 0.94
Epoch: 90, Training Loss: 0.21, Validation Loss: 0.21, accuracy = 0.94
Epoch: 91, Training Loss: 0.20, Validation Loss: 0.20, accuracy = 0.94
Epoch: 92, Training Loss: 0.20, Validation Loss: 0.20, accuracy = 0.96
Epoch: 93, Training Loss: 0.19, Validation Loss: 0.19, accuracy = 0.96
Epoch: 94, Training Loss: 0.19, Validation Loss: 0.19, accuracy = 0.96
Epoch: 95, Training Loss: 0.18, Validation Loss: 0.18, accuracy = 0.96
Epoch: 96, Training Loss: 0.18, Validation Loss: 0.18, accuracy = 0.96
Epoch: 97, Training Loss: 0.17, Validation Loss: 0.17, accuracy = 0.96
Epoch: 98, Training Loss: 0.17, Validation Loss: 0.17, accuracy = 0.96
Epoch: 99, Training Loss: 0.16, Validation Loss: 0.16, accuracy = 0.96
Epoch: 100, Training Loss: 0.16, Validation Loss: 0.16, accuracy = 0.96
  • Save the model
torch.save(model, "K:\\classifier3.pt")  # save the full model (architecture + weights)
  • Load the model
load_model = torch.load("K:\\classifier3.pt")
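Saving the whole module pickles the class definition along with the weights, so loading requires the same source layout. A commonly preferred alternative (a sketch, my addition; the *_state.pt path is illustrative) is to save only the state_dict:

# Sketch (my addition): save/load parameters only.
torch.save(model.state_dict(), "K:\\classifier3_state.pt")

reloaded = net(3, image_w, image_h)  # rebuild the architecture first
reloaded.load_state_dict(torch.load("K:\\classifier3_state.pt"))
reloaded.eval()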
  • Predict
img_path = r"K:\imageData\polarity\data3\val\pos\00002.bmp"
#img_path = r"K:\imageData\polarity\data3\val\neg\00001.bmp"
labels = ["neg","pos"]
img = Image.open(img_path)
img = img_transforms(img).to(device)
img = torch.unsqueeze(img,0)

model.eval()
with torch.no_grad():  # inference only; no gradients needed
    prediction = F.softmax(model(img), dim=1)
prediction = prediction.argmax().item()
print(labels[prediction])
pos
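To classify a whole folder rather than a single file, a small loop does the job (a sketch, my addition; it reuses check_image, img_transforms, and the val/pos directory from above):

import os

# Sketch (my addition): run the classifier over every image in a directory.
folder = r"K:\imageData\polarity\data3\val\pos"
for name in sorted(os.listdir(folder)):
    path = os.path.join(folder, name)
    if not check_image(path):
        continue
    tensor = torch.unsqueeze(img_transforms(Image.open(path)).to(device), 0)
    with torch.no_grad():
        pred = F.softmax(model(tensor), dim=1).argmax().item()
    print(name, labels[pred])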
  • Visualize the network model with Netron (install with pip install netron if needed)
import netron
netron.start("K:\\classifier3.pt")
Serving 'K:\classifier3.pt' at http://localhost:8080
('localhost', 8080)
  • Plot the training curves
def visualize(train_loss, val_loss, val_acc):
    train_loss = np.array(train_loss)
    val_loss = np.array(val_loss)
    val_acc = np.array(val_acc)
    plt.grid(True)
    plt.xlabel("epoch")
    plt.ylabel("value")
    plt.title("training/validation loss and validation accuracy")
    plt.plot(np.arange(len(val_acc)), val_acc, label=r"valid_acc", c="g")
    plt.plot(np.arange(len(train_loss)), train_loss, label=r"train_loss", c="r")
    plt.plot(np.arange(len(val_loss)), val_loss, label=r"valid_loss", c="b")
    plt.legend()
    plt.savefig("K:\\a.png")
    
visualize(train_loss_list,valid_loss_list,valid_accuracy_list)

[Figure: training curves, with training loss falling and validation accuracy rising over the 100 epochs]

Note: judging by the training loss and validation accuracy in the plot, the number of training epochs should be increased further: the training loss is still falling and the validation accuracy is still rising, so training has not yet saturated. A checkpointing sketch for longer runs follows below.
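When extending the run, it helps to checkpoint whenever the validation loss improves, so the best weights survive even if later epochs overfit. A minimal sketch (my addition; the helper name and path are illustrative, not from the original):

# Sketch (my addition): call once per epoch from the training loop.
def save_if_best(model, valid_loss, best_val, path="K:\\classifier3_best.pt"):
    # Save a state_dict checkpoint and return the new best loss on improvement.
    if valid_loss < best_val:
        torch.save(model.state_dict(), path)
        return valid_loss
    return best_val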
