PyTorch:保存/加载训练好的模型测试

保存
torch.save(model.state_dict(), './cnn.pth')

加载
model = VGG16() #加载模型前要创建一个模型的实例对象
model.load_state_dict(torch.load("./cnn.pth"))

例子

import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from tqdm import tqdm
from PIL import Image

'''定义网络模型'''
class VGG16(nn.Module):
    def __init__(self, num_classes=10):
        super(VGG16, self).__init__()
        self.features = nn.Sequential(
            #1
            nn.Conv2d(3,64,kernel_size=3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            #2
            nn.Conv2d(64,64,kernel_size=3,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2,stride=2),
            #3
            nn.Conv2d(64,128,kernel_size=3,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            #4
            nn.Conv2d(128,128,kernel_size=3,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2,stride=2),
            #5
            nn.Conv2d(128,256,kernel_size=3,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            #6
            nn.Conv2d(256,256,kernel_size=3,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            #7
            nn.Conv2d(256,256,kernel_size=3,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2,stride=2),
            #8
            nn.Conv2d(256,512,kernel_size=3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            #9
            nn.Conv2d(512,512,kernel_size=3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            #10
            nn.Conv2d(512,512,kernel_size=3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2,stride=2),
            #11
            nn.Conv2d(512,512,kernel_size=3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            #12
            nn.Conv2d(512,512,kernel_size=3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            #13
            nn.Conv2d(512,512,kernel_size=3,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2,stride=2),
            nn.AvgPool2d(kernel_size=1,stride=1),
            )
        self.classifier = nn.Sequential(
            #14
            nn.Linear(512,4096),
            nn.ReLU(True),
            nn.Dropout(),
            #15
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            #16
            nn.Linear(4096,num_classes),
            )
        #self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        out = self.features(x)
        #        print(out.shape)
        out = out.view(out.size(0), -1)
        #        print(out.shape)
        out = self.classifier(out)
        #        print(out.shape)
        return out

'''创建model实例对象,并检测是否支持使用GPU'''
model = VGG16()

use_gpu = torch.cuda.is_available()  # 判断是否有GPU加速
if use_gpu:
    model = model.cuda()

model.eval()

'''测试'''
def prediect(img_path):
    model.load_state_dict(torch.load("./cnn.pth"))
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.no_grad()

    transform_valid = transforms.Compose([
        transforms.Resize((32, 32), interpolation=2),
        transforms.ToTensor()
    ]
    )
    img = Image.open(img_path)
    img = transform_valid(img).unsqueeze(0)  # 拓展维度

    if use_gpu:
        img = Variable(img, volatile=True).cuda()
        # label = Variable(label, volatile=True).cuda()
    else:
        img = Variable(img)
        # label = Variable(label)
    out = model(img)
    
    _, pred = torch.max(out, 1)  # 求出out最大值索引

    print('this picture maybe :', classes[pred])

if __name__ == '__main__':
    prediect('./Test_Image/dog.jpg')

  • 4
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
测试加载训练好的模型,可以按照以下步骤进行操作: 1. 导入需要的库:import torchvision.models as models 2. 创建模型实例对象:model = models.resnet50(pretrained=True) 3. 加载训练好的模型参数:model.load_state_dict(torch.load("路径/模型文件.pth")) 在这个过程中,我们可以使用torchvision库中提供的已经训练好的模型,如resnet50。通过设置pretrained=True,我们可以加载已经在大规模数据集上进行预训练模型参数。接下来,我们可以使用load_state_dict()方法加载我们训练好的模型的参数,该方法需要传入一个模型参数的文件路径。 请注意,加载训练好的模型时,需要确保训练好的模型和要加载模型具有相同的结构和参数。如果模型结构不同,需要根据实际情况进行相应的调整,例如只加载部分预训练模型的参数。可以参考上述提供的代码示例来加载模型参数。 总结起来,测试加载训练好的模型的步骤如下: 1. 导入需要的库 2. 创建模型实例对象 3. 加载训练好的模型参数 希望对您有帮助!<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [pytorch加载训练好的模型用来测试或者处理](https://blog.csdn.net/weixin_40244676/article/details/117251048)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [PyTorch保存/加载训练好的模型测试](https://blog.csdn.net/cd_yourheart/article/details/113925708)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值