完整的已训练模型的使用
1. 已训练模型的使用相关事宜
- 当我们通过训练集训练模型,并通过测试集评判模型,按着评价指标获得最终的模型后,最终目的是将这个模型运用到实际任务中,应用过程中存在几个需要关注的问题:
- 数据的预处理:我们需要把数据转为一定的维度
(batch_size, data_dim)
batch_size:
如果没有多个样本进行预测,可以直接设置为 1 ,这是必须设置的,因为模型训练的时候接受的数据的维度是包含batch_size
的data_dim:
这是每个样本的数据维度,需要确保进行预测的样本维度与模型训练时使用的维度完全相同
- 模型的加载:除了按照固定的方式加载模型外,我们还需要考虑模型是否基于GPU训练的还是CPU训练的,如果基于GPU训练的模型在CPU中加载,需要指定模型的参数,才能确保模型正常使用
- 数据的预处理:我们需要把数据转为一定的维度
2. 使用已经训练的模型运行实际任务
-
假设我们已经使用GPU训练好了之前建立
CIFAR10
分类网络架构,模型的保存目录为model_path = "./models/model.pth"
-
有一个待分类的图片保存路径为
image_path = "./imgs/image.jpg"
-
第一步:数据的加载,并进行维度矫正
image = Image.open(image_path) # load image transform = torchvision.transforms.Compose([ torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor(), ]) # Transform image = transform(image) image = torch.reshape(image, (1, 3, 32, 32))
- 对于加载的图片
image
是Image
类型的数据,通过Resize
,将图片的尺寸进行修改为符合模型输入的 32*32 - 再通过
ToTensor()
将数据转换为Tensor
数据类型,会将图像的像素值从 [0, 255] 缩放到 [0, 1] 的范围,并且将图像从 HWC(高度、宽度、通道)格式转换为 CHW(通道、高度、宽度)格式 - 最后通过一个
reshape
给图像的张量添加一个维度,这个维度是batch_size
表明共计有多少个样本 --> 至此,该图像已经可以投入到模型中使用了
- 对于加载的图片
-
第二步:模型的加载
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = torch.load(model_path, map_location=device)
- 要确定本次预测要使用什么
device
然后在模型的加载过程中,指定map_location
参数为选择的device
这样可以保证模型是可以在当前设备上正常运行的
- 要确定本次预测要使用什么
-
第三步:数据的预测
image = image.to(device) model.eval() with torch.no_grad(): output = model(image)
- 将
image
数据转入到需要使用的device
中,这里一定要保证model
和image
是使用的同一个device
- 预测过程就是模型的测试过程,为了规范性,建议先使用
eval()
方法,并设定模型的梯度不发生变化with torch.no_grad
- 然后将数据投入到模型中,获取最后的
output
即可!
- 将
3. 完整的代码
-
主要包括基本参数定义、数据加载和格式转换、模型加载、数据预测
from PIL import Image import torch from models import Model import torchvision # define para model_path = "./models/model.pth" image_path = "./imgs/image.jpg" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # load and transform image image = Image.open(image_path) transform = torchvision.transforms.Compose([ torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor(), ]) image = transform(image) image = torch.reshape(image, (1, 3, 32, 32)) image = image.to(device) # load model model = torch.load(model_path, map_location=device) # predict model.eval() with torch.no_grad(): output = model(image) print(output)