pytorch目标检测ssd六__预测效果与预测过程代码详解

本篇博客是我学习(https://blog.csdn.net/weixin_44791964)博主写的pytorch的ssd的博客后写的,大家可以直接去看这位博主的博客(https://blog.csdn.net/weixin_44791964/article/details/104981486)。这位博主在b站还有配套视频,传送门:(https://www.bilibili.com/video/BV1A7411976Z)。这位博主的在GitHub的源代码(https://github.com/bubbliiiing/ssd-pytorch)。 侵删

这篇博客将要讲述ssd的预测效果与预测过程是怎么实现的

# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    boxes = torch.cat((
        #首先计算先验框调整之后的中心的位置
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        #计算出调整后的先验框的宽和高
        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    #计算先验框的左上角,
    boxes[:, :2] -= boxes[:, 2:] / 2
    #计算先验框的右下角
    boxes[:, 2:] += boxes[:, :2]
    return boxes
from ssd import SSD
from PIL import Image

#首先ssd.py的ssd类
ssd = SSD()

while True:
    img = input('Input image filename:')
    try:
        image = Image.open(img)
    except:
        print('Open Error! Try again!')
        continue
    else:
        r_image = ssd.detect_image(image)
        r_image.show()

#import cv2
import numpy as np
import colorsys
import os
import torch
from nets import ssd
import torch.backends.cudnn as cudnn
from utils.config import Config
from utils.box_utils import letterbox_image,ssd_correct_boxes
from PIL import Image,ImageFont, ImageDraw
from torch.autograd import Variable

MEANS = (104, 117, 123)
class SSD(object):
    _defaults = {
        #训练好的模型的权重存放的位置
        "model_path": 'model_data/ssd_weights.pth',
        #我们分的类,对应的txt文件
        "classes_path": 'model_data/voc_classes.txt',
        #输入图片的大小
        "model_image_size" : (300, 300, 3),
        "confidence": 0.5,
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   初始化RFB
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        self.class_names = self._get_class()
        self.generate()
    #---------------------------------------------------#
    #   获得所有的分类
    #---------------------------------------------------#
    def _get_class(self):
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names
    #---------------------------------------------------#
    #   获得所有的分类
    #---------------------------------------------------#
    def generate(self):
        # 计算总的种类
        self.num_classes = len(self.class_names) + 1

        # 载入模型,如果原来的模型里已经包括了模型结构则直接载入。
        # 否则先构建模型再载入
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = ssd.get_ssd("test",self.num_classes)
        self.net = model
        model.load_state_dict(torch.load(self.model_path))

        self.net = torch.nn.DataParallel(self.net)
        cudnn.benchmark = True
        self.net = self.net.cuda()

        print('{} model, anchors, and classes loaded.'.format(self.model_path))
        # 画框设置不同的颜色
        hsv_tuples = [(x / len(self.class_names), 1., 1.)
                      for x in range(len(self.class_names))]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(
            map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                self.colors))

    #---------------------------------------------------#
    #   检测图片
    #---------------------------------------------------#
    #输入参数是我们输入进来的图片
    def detect_image(self, image):
        #首先获得我们输入进来的图片的shape
        image_shape = np.array(np.shape(image)[0:2])
        #给我们输入进来的图片加上letterbox_image,目的是防止我们的图片失真,给图片边缘加上灰条,使图片整体的大小变成300x300x3的大小
        crop_img = np.array(letterbox_image(image, (self.model_image_size[0],self.model_image_size[1])))
        #将图片转化为numpy的float64的大小 
        photo = np.array(crop_img,dtype = np.float64)

        # 图片预处理,归一化,这里加了cuda
        photo = Variable(torch.from_numpy(np.expand_dims(np.transpose(crop_img-MEANS,(2,0,1)),0)).cuda().type(torch.FloatTensor))
        """
        这里将我们的图片传入网络进行预测
        这里的预测结果其实是包含了:
        我们这张图片里面,每一个类所对应的得分最高的两百个框的参数
        """
        preds = self.net(photo)
        
        top_conf = []
        top_label = []
        top_bboxes = []
        #下面就是对每一个类进行遍历,从1开始(0是指背景),然后判断先验框是否大于了self.confidence,
        #遍历完所有的类之后就得到我们这张图片所有的预测结果了
        for i in range(preds.size(1)):
            j = 0
            while preds[0, i, j, 0] >= self.confidence:
                score = preds[0, i, j, 0]
                label_name = self.class_names[i-1]
                pt = (preds[0, i, j, 1:]).detach().numpy()
                #如果大于了self.confidence就把这个框及其得分标签等保留下来
                coords = [pt[0], pt[1], pt[2], pt[3]]
                top_conf.append(score)
                top_label.append(label_name)
                top_bboxes.append(coords)
                j = j + 1
        # 将预测结果进行解码,如果检测到预测框里面有物体,就开始检测物体,如果没有物体,就返回这张图片了
        if len(top_conf)<=0:
            return image
        top_conf = np.array(top_conf)
        top_label = np.array(top_label)
        top_bboxes = np.array(top_bboxes)
        top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:,0],-1),np.expand_dims(top_bboxes[:,1],-1),np.expand_dims(top_bboxes[:,2],-1),np.expand_dims(top_bboxes[:,3],-1)

        # 去掉灰条
        boxes = ssd_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.model_image_size[0],self.model_image_size[1]]),image_shape)

        font = ImageFont.truetype(font='model_data/simhei.ttf',size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))

        thickness = (np.shape(image)[0] + np.shape(image)[1]) // self.model_image_size[0]

        for i, c in enumerate(top_label):
            predicted_class = c
            score = top_conf[i]

            top, left, bottom, right = boxes[i]
            top = top - 5
            left = left - 5
            bottom = bottom + 5
            right = right + 5

            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32'))
            right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32'))

            # 画框框
            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = draw.textsize(label, font)
            label = label.encode('utf-8')
            print(label)
            
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            for i in range(thickness):
                draw.rectangle(
                    [left + i, top + i, right - i, bottom - i],
                    outline=self.colors[self.class_names.index(predicted_class)])
            draw.rectangle(
                [tuple(text_origin), tuple(text_origin + label_size)],
                fill=self.colors[self.class_names.index(predicted_class)])
            draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
            del draw
        return image


总结一下,预测其实就是先将我们所有的先验框都来遍历一遍,遍历的时候将每个先验框框都用类别(就是我们分出的类)来对比一遍,然后选出每个类别对应得分最高的200个先验框,然后再次进行遍历,来判断是先验框否满足了我们设定的confidence的值,然后再判断是否含有物体,如果含有物体的话,我们就进行预测,反之就不进行预测了。

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值