加载pytorch已有模型,修改最后分类头

在加载pytorch已有模型的时候,我们必须要明确的事情:

1 如何获取到pytorch所提供的模型,通过什么方式。
2 模型的结构,也就是模型的每个层的名字(key)。
3 我们要把需要加载的模型,尽量封装成一个类。

下面我们针对上面来给出答案。
答1:以 resnet18 举例

# ----------------1 导入库 -----------------
import torchvision.models as model
# -------------2 将resnet18导入到新模型。--------------
base_model = 'resnet18'
if 'resnet' in base_model:
	model = getattr(model,base_model)

答2 :我们在了解模型的时候,经常使用 dict(model.named_parameters()),它会返回一个字典,我们通过 字典.items()来得到字典的key和value值。我们要知道最后一层的分类层名字叫什么。

for (key,value) in dict(model.named_parameters()).items():
    print(key)

在这里插入图片描述
最后一层的名字叫 fc ,这样我们可以通过最后一层的名字来修改最后一层。

num_class = 51
fc = getattr(model, 'fc')
feature_dim = fc.in_features
setattr(model,'fc',nn.Linear(feature_dim,num_class))
print(model)

在这里插入图片描述
这样就把最后一层修改完成了。

答3 :最后封装成新的模型类

import torchvision.models as model
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, num_class,base_model= 'resnet18'):
        super().__init__()
        self._prepare_base_model(num_class = num_class,base_model = base_model )
  
    def _prepare_base_model(self, base_model,num_class):  
        if 'resnet' in base_model:
            self.model = getattr(model, base_model)(pretrained=True)
            feature_dim = getattr(self.model, 'fc').in_features
            setattr(self.model,'fc',nn.Linear(feature_dim,num_class))      
        else:
            raise ValueError('Unknown base model: {}'.format(base_model))
            
    def forward(self,x):
        out = self.model(x)
        return out

net = Model(num_class=51,base_model='resnet18')
print(net)

在这里插入图片描述

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
可以利用PyTorch和OpenCV进行已有模型分类结果显示。首先,需要加载已有PyTorch模型和相应的标签文件。然后,通过OpenCV调用摄像获取图像,并将图像处理为模型所需的尺寸。最后,将处理后的图像输入到PyTorch模型中进行分类,并显示分类结果。 以下是一个简单的示例代码: ``` import torch import cv2 import numpy as np # load model model = torch.load('model.pth', map_location=torch.device('cpu')) model.eval() # load labels with open('labels.txt', 'r') as f: labels = [line.strip() for line in f.readlines()] # initialize camera cap = cv2.VideoCapture(0) while True: # read frame from camera ret, frame = cap.read() # preprocess image img = cv2.resize(frame, (224, 224)) img = np.transpose(img, (2, 0, 1)) img = np.expand_dims(img, axis=0) img = img / 255.0 img = torch.tensor(img, dtype=torch.float32) # classify image with torch.no_grad(): output = model(img) pred = torch.argmax(output, dim=1) # display classification result cv2.putText(frame, labels[pred.item()], (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) cv2.imshow('frame', frame) # exit on 'q' key press if cv2.waitKey(1) & 0xFF == ord('q'): break # release resources cap.release() cv2.destroyAllWindows() ``` 在这个示例中,模型文件为'model.pth',标签文件为'labels.txt'。摄像通过OpenCV的'cv2.VideoCapture(0)'进行初始化,在每次循环中读取一帧图像。图像被处理为模型所需的大小(224x224),并将其输入到模型中进行分类最后分类结果被绘制在原始图像上,并在窗口中显示。按下'q'键退出程序。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值