第P4周:猴痘病识别

重点: 

  1. 训练过程中保存效果最好的模型参数。加载最佳模型参数识别本地的一张图片。
  2. 调整网络结构使测试集accuracy到达88%(重点)。
  3. 加载最佳模型参数识别本地的一张图片。

环境: 

  1. 语言环境:Python3.8
  2. 编译器:jupyter notebook
  3. 深度学习环境:Pytorch

一、 前期准备

1. 导入必要的包

import os
import pathlib
import PIL
import random

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import transforms, datasets

设GPU,如果设备上支持GPU就使用GPU,否则使用CPU 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

 输出:(当前设备没有GPU,先用CPU凑合...)

device(type='cpu')

2. 导入数据

读取数据路径

重点:从文件中导入数据

首先查看当前py文件所在目录,和数据集目录一致可使用相对路径导入,否则以绝对路径导入数据

os.getcwd() # 获取当前py文件路径

输出:

'C:\\Users\\98613'

 修改py目录路径

new_dir = 'D:\pytorch'
os.chdir(new_dir) #修改当前系统路径
os.getcwd() # 再次查看py文件目录

输出:

'D:\\pytorch'

获取并查看数据的标签类别名称 

data_dir = './猴痘病识别/'
# 使用pathlib.Path()函数将字符串类型的文件夹路径转换为pathlib.Path对象
data_dir = pathlib.Path(data_dir)
# 使用glob()方法获取data_dir路径下的所有文件路径,并以列表形式存储在data_paths中
data_paths = list(data_dir.glob('*'))
# 通过split()函数对data_paths中的每个文件路径执行分割操作,获得各个文件所属的类别名称,并存储在classeNames中
classeNames = [str(path).split('\\')[1] for path in data_paths]
# 打印classeNames列表,显示每个文件所属的类别名称
classeNames

输出:      'Monkeypox' 表示猴痘病类别,'Others' 表示非猴痘病类别

['Monkeypox', 'Others']

查看输出数据格式:

data_dir

WindowsPath('猴痘病识别')

data_paths

[WindowsPath('猴痘病识别/Monkeypox'), WindowsPath('猴痘病识别/Others')]
 数据可视化
import matplotlib.pyplot as plt
from PIL import Image


# 指定图像文件夹路径
image_folder = './猴痘病识别/Monkeypox/'


# 获取文件夹中的所有图像文件
image_files = [f for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))]


# 创建Matplotlib图像
fig, axes = plt.subplots(3, 8, figsize=(16, 6))


# 使用列表推导式加载和显示图像
for ax, img_file in zip(axes.flat, image_files):
    img_path = os.path.join(image_folder, img_file)
    img = Image.open(img_path)
    ax.imshow(img)
    ax.axis('off')


# 显示图像
plt.tight_layout()
plt.show()

图像可能引起不适...在这里先不放了

os文件路径相关操作:

os.getcwd():得到当前工作目录,即当前Python脚本工作的目录路径
os.listdir():返回指定目录下的所有文件和目录名
os.chdir("path") :换路径
os.path.join(path1, path2):连接两个str格式路径
 
其他图片加载操作:
Image.open(path):path为图片地址,打开该图片
读取数据
train_dir = './猴痘病识别/'
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(  # # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]  
    )  # # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])
total_data = datasets.ImageFolder(train_dir, transform=train_transforms)
total_data

 输出:

Dataset ImageFolder
    Number of datapoints: 2142
    Root location: ./猴痘病识别/
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
total_data.class_to_idx

输出:

 {'Monkeypox': 0, 'Others': 1}

torchvision相关

torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision.transforms主要是用于常见的一些图形变换。以下是torchvision的构成:

1.torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
2.torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
3.torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
4.torchvision.utils: 其他的一些有用的方法。

本文的主题是其中的torchvision.transforms.Compose()类。这个类的主要作用是串联多个图片变换的操作
from torchvision.transforms import transforms

train_transforms = transforms.Compose([
transforms.Resize([224, 224]), # 将输入图片resize成统一尺寸
transforms.RandomRotation(degrees=(-10, 10)), # 随机旋转,-10到10度之间随机选
transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转 选择一个概率概率
transforms.RandomVerticalFlip(p=0.5), # 随机垂直翻转
transforms.RandomPerspective(distortion_scale=0.6, p=1.0), # 随机视角
transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)), # 随机选择的高斯模糊模糊图像
transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
transforms.Normalize( # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
mean=[0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225]) # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

3. 划分数据集

按照8:2的比例将全量数据随机划分为训练集和测试集。训练集和测试集中各包含猴痘病和非猴痘病的图片

train_size = int(len(total_data)*0.8)
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
train_dataset, test_dataset

输出:

(<torch.utils.data.dataset.Subset at 0x1a434f04a60>,
 <torch.utils.data.dataset.Subset at 0x1a434f04e50>)

 查看训练集和测试集数据量

train_size, test_size

 输出:   

(1713, 429)

训练集1713张图片,测试集429张图片

4. 加载数据 

batch_size = 32
 
train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=1)
 
test_dl = torch.utils.data.DataLoader(test_dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1)
for X, y in test_dl: # 查看测试集某一个batch的图片
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

输出: 

Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

torch.utils.data.DataLoader()参数详解

torch.utils.data.DataLoader 是 PyTorch 中用于加载和管理数据的一个实用工具类。它允许你以小批次的方式迭代你的数据集,这对于训练神经网络和其他机器学习任务非常有用。DataLoader 构造函数接受多个参数,下面是一些常用的参数及其解释:

  • dataset(必需参数):这是你的数据集对象,通常是 torch.utils.data.Dataset 的子类,它包含了你的数据样本
  • batch_size(可选参数):指定每个小批次中包含的样本数。默认值为 1。
  • shuffle(可选参数):如果设置为 True,则在每个 epoch 开始时对数据进行洗牌,以随机打乱样本的顺序。这对于训练数据的随机性很重要,以避免模型学习到数据的顺序性。默认值为 False
  • num_workers(可选参数):用于数据加载的子进程数量。通常,将其设置为大于 0 的值可以加快数据加载速度,特别是当数据集很大时。默认值为 0,表示在主进程中加载数据
  • pin_memory(可选参数):如果设置为 True,则数据加载到 GPU 时会将数据存储在 CUDA 的锁页内存中,这可以加速数据传输到 GPU。默认值为 False
  • drop_last(可选参数):如果设置为 True,则在最后一个小批次可能包含样本数小于 batch_size 时,丢弃该小批次。这在某些情况下很有用,以确保所有小批次具有相同的大小。默认值为 False
  • timeout(可选参数):如果设置为正整数,它定义了每个子进程在等待数据加载器传递数据时的超时时间(以秒为单位)。这可以用于避免子进程卡住的情况。默认值为 0,表示没有超时限制
  • worker_init_fn(可选参数):一个可选的函数,用于初始化每个子进程的状态。这对于设置每个子进程的随机种子或其他初始化操作很有用
     

关于shuffle流程:

二、构建简单的CNN网络

对于一般的CNN网络来说,都是由特征提取网络和分类网络构成,其中特征提取网络用于提取图片的特征,分类网络用于将图片进行分类

⭐1. torch.nn.Conv2d()详解

函数原型:

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

关键参数说明:

in_channels ( int ) – 输入图像中的通道数
out_channels ( int ) – 卷积产生的通道数
kernel_size ( int or tuple ) – 卷积核的大小
stride ( int or tuple , optional ) -- 卷积的步幅。默认值:1
padding ( int , tuple或str , optional ) – 添加到输入的所有四个边的填充。默认值:0
padding_mode (字符串,可选) – 'zeros', 'reflect', 'replicate'或'circular'. 默认:'zeros'
⭐2. torch.nn.Linear()详解

函数原型:

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

关键参数说明:

in_features:每个输入样本的大小
out_features:每个输出样本的大小
⭐3. torch.nn.MaxPool2d()详解

函数原型:

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

关键参数说明:

kernel_size:最大的窗口大小
stride:窗口的步幅,默认值为kernel_size
padding:填充值,默认为0
dilation:控制窗口中元素步幅的参数
大家注意一下在卷积层和全连接层之间,我们可以使用之前是torch.flatten()也可以使用我下面的x.view()亦或是torch.nn.Flatten()。torch.nn.Flatten()与TensorFlow中的Flatten()层类似,前两者则仅仅是一种数据集拉伸操作(将二维数据拉伸为一维),torch.flatten()方法不会改变x本身,而是返回一个新的张量。而x.view()方法则是直接在原有数据上进行操作。

网络结构图:

import torch.nn.functional as F

class Network_bn(nn.Module):
    def __init__(self):
        super(Network_bn, self).__init__()
        """
        nn.Conv2d()函数:
        第一个参数(in_channels)是输入的channel数量
        第二个参数(out_channels)是输出的channel数量
        第三个参数(kernel_size)是卷积核大小
        第四个参数(stride)是步长,默认为1
        第五个参数(padding)是填充大小,默认为0
        """
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(24)
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn5 = nn.BatchNorm2d(24)
        self.fc1 = nn.Linear(24*50*50, len(classeNames))

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = self.pool(x)
        x = x.view(-1, 24*50*50)
        x = self.fc1(x)
        return x

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

model = Network_bn().to(device)
model

 输出:

Using cpu device
Out[8]: 
Network_bn(
  (conv1): Conv2d(3, 12, kernel_size=(5, 5), stride=(1, 1))
  (bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(12, 12, kernel_size=(5, 5), stride=(1, 1))
  (bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(12, 24, kernel_size=(5, 5), stride=(1, 1))
  (bn4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(24, 24, kernel_size=(5, 5), stride=(1, 1))
  (bn5): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=60000, out_features=2, bias=True)
)
# 打印模型
from torchinfo import summary
summary(model)

输出:

=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Network_bn                               --
├─Conv2d: 1-1                            912
├─BatchNorm2d: 1-2                       24
├─Conv2d: 1-3                            3,612
├─BatchNorm2d: 1-4                       24
├─MaxPool2d: 1-5                         --
├─Conv2d: 1-6                            7,224
├─BatchNorm2d: 1-7                       48
├─Conv2d: 1-8                            14,424
├─BatchNorm2d: 1-9                       48
├─Linear: 1-10                           240,004
=================================================================
Total params: 266,320
Trainable params: 266,320
Non-trainable params: 0
=================================================================

三、 训练模型

1. 设置超参数

loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4 # 学习率
opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

2. 编写训练函数
 

1. optimizer.zero_grad()

函数会遍历模型的所有参数,通过内置方法截断反向传播的梯度流,再将每个参数的梯度值设为0,即上一次的梯度记录被清空。

2. loss.backward()

PyTorch的反向传播(即tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计算其对应的梯度。

具体来说,torch.tensor是autograd包的基础类,如果你设置tensor的requires_grads为True,就会开始跟踪这个tensor上面的所有运算,如果你做完运算后使用tensor.backward(),所有的梯度就会自动运算,tensor的梯度将会累加到它的.grad属性里面去。

更具体地说,损失函数loss是由模型的所有权重w经过一系列运算得到的,若某个w的requires_grads为True,则w的所有上层参数(后面层的权重w)的.grad_fn属性中就保存了对应的运算,然后在使用loss.backward()后,会一层层的反向传播计算每个w的梯度值,并保存到该w的.grad属性中。

如果没有进行tensor.backward()的话,梯度值将会是None,因此loss.backward()要写在optimizer.step()之前。

3. optimizer.step()

step()函数的作用是执行一次优化步骤,通过梯度下降法来更新参数的值。因为梯度下降是基于梯度的,所以在执行optimizer.step()函数前应先执行loss.backward()函数来计算梯度。

注意:optimizer只负责通过梯度下降进行优化,而不负责产生梯度,梯度是tensor.backward()方法产生的
 

# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共1713张图片
    num_batches = len(dataloader)   # 批次数目,54(1713/32)
 
    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches
 
    return train_acc, train_loss

3. 编写测试函数

测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器

def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小,一共429张图片
    num_batches = len(dataloader)          # 批次数目,14(429/32=13.4,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()
 
    test_acc  /= size
    test_loss /= num_batches
 
    return test_acc, test_loss

4. 正式训练


1. model.train()

model.train()的作用是启用 Batch Normalization 和 Dropout。

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

2. model.eval()

model.eval()的作用是不启用 Batch Normalization 和 Dropout。

如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质
 

epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []
 
for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')

输出:

Using cpu device
Epoch: 1, Train_acc:61.5%, Train_loss:0.676, Test_acc:68.5%, Test_loss:0.615
Epoch: 2, Train_acc:69.1%, Train_loss:0.589, Test_acc:71.8%, Test_loss:0.610
Epoch: 3, Train_acc:74.2%, Train_loss:0.523, Test_acc:73.2%, Test_loss:0.540
Epoch: 4, Train_acc:77.3%, Train_loss:0.491, Test_acc:75.1%, Test_loss:0.531
Epoch: 5, Train_acc:78.6%, Train_loss:0.467, Test_acc:76.0%, Test_loss:0.511
Epoch: 6, Train_acc:82.1%, Train_loss:0.436, Test_acc:76.7%, Test_loss:0.499
Epoch: 7, Train_acc:83.4%, Train_loss:0.408, Test_acc:80.7%, Test_loss:0.471
Epoch: 8, Train_acc:84.2%, Train_loss:0.397, Test_acc:79.5%, Test_loss:0.461
Epoch: 9, Train_acc:86.1%, Train_loss:0.371, Test_acc:79.5%, Test_loss:0.440
Epoch:10, Train_acc:86.9%, Train_loss:0.356, Test_acc:79.7%, Test_loss:0.426
Epoch:11, Train_acc:86.7%, Train_loss:0.345, Test_acc:80.9%, Test_loss:0.425
Epoch:12, Train_acc:88.7%, Train_loss:0.333, Test_acc:82.1%, Test_loss:0.424
Epoch:13, Train_acc:88.6%, Train_loss:0.322, Test_acc:81.6%, Test_loss:0.415
Epoch:14, Train_acc:89.6%, Train_loss:0.312, Test_acc:82.3%, Test_loss:0.401
Epoch:15, Train_acc:90.3%, Train_loss:0.299, Test_acc:81.6%, Test_loss:0.395
Epoch:16, Train_acc:90.8%, Train_loss:0.289, Test_acc:82.3%, Test_loss:0.398
Epoch:17, Train_acc:90.4%, Train_loss:0.289, Test_acc:82.8%, Test_loss:0.380
Epoch:18, Train_acc:91.4%, Train_loss:0.276, Test_acc:83.7%, Test_loss:0.372
Epoch:19, Train_acc:91.9%, Train_loss:0.266, Test_acc:83.2%, Test_loss:0.379
Epoch:20, Train_acc:91.5%, Train_loss:0.261, Test_acc:83.0%, Test_loss:0.377
Done

 四、 结果可视化 

1. Loss与Accuracy图

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率
 
epochs_range = range(epochs)
 
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
 
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2. 指定图片进行预测

补充图片压缩的基本操作

1. torch.squeeze()详解

对数据的维度进行压缩,去掉维数为1的的维度

函数原型:

torch.squeeze(input, dim=None, *, out=None)

关键参数说明:

  • input (Tensor):输入Tensor
  • dim (int, optional):如果给定,输入将只在这个维度上被压缩

 实战案例:

x = torch.zeros(2, 1, 2, 1, 2)
x.size()
y = torch.squeeze(x)
y.size()
y = torch.squeeze(x, 0)
y.size()
y = torch.squeeze(x, 1)
y.size()

以上输出分别为:

torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 2])
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 1, 2])

2. torch.unsqueeze()

对数据维度进行扩充。给指定位置加上维数为一的维度

函数原型:

torch.unsqueeze(input, dim)

关键参数说明:

  • input (Tensor):输入Tensor
  • dim (int):插入单例维度的索引

 实战案例:

x = torch.tensor([1, 2, 3, 4])
torch.unsqueeze(x, 0)
torch.unsqueeze(x, 1)
x.size()
torch.unsqueeze(x, 0).size()
torch.unsqueeze(x, 1).size()

以上输出分别为:

tensor([[1, 2, 3, 4]])
tensor([[1],
        [2],
        [3],
        [4]])
torch.Size([4])
torch.Size([1, 4])
torch.Size([4, 1])
import matplotlib.pyplot as plt
from PIL import Image
classes = list(total_data.class_to_idx)
def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img) # 展示需要预测的图片

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)

    model.eval()
    output = model(img)

    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'预测结果是: {pred_class}')


#预测训练集中的某张照片
predict_one_image(image_path='./猴痘病识别/M01_01_08.jpg',
                  model=model,
                  transform=train_transforms,
                  classes=classes)

输出:

预测结果是: Monkeypox

五、保存并加载模型

# 模型保存
PATH = './model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)

# 将参数加载到model当中
model.load_state_dict(torch.load(PATH, map_location=device))

输出: 

<All keys matched successfully>

小结

        本周内容相比上周新增通过训练的模型预测本地图片的部分,大致流程为:

        1.读取图片

        2.处理读取图片(方式和处理训练集图片一样)

        3. 维度扩充,图片索引0维度加上维数为一的维度

        4.模型预测,使用训练好的模型预测图片类别

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值