深度学习(9):Inception危险物品检测

目标:基于Inception网络实现对危险物品检测,将危险物品图片或视频经过图像预处理后输入模型推理,最后将检测结果进行可视化输出。

一、原理

Google的Inception网络介绍

Inception为Google开源的CNN模型,至今已经公开四个版本,每一个版本都是基于大型图像数据库ImageNet中的数据训练而成。因此我们可以直接利用Google的Inception模型来实现图像分类。

二、过程

1.导入库

首先导入需要的组件包,包括tensorflow、keras、IPython等,代码如下:

# 安装完成需要重启kernel
!pip3 install pygame
!pip3 install opencv_python
# !pip3 install tensorflow==1.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

from PIL import Image
import numpy as np
import cv2
import time
import os
import sys
import logging as log
import pygame
from tensorflow.keras.applications.inception_v3 import decode_predictions
from keras.applications import InceptionV3
from keras.applications import imagenet_utils

from IPython.display import clear_output, Image, display, HTML

 

 2.导入数据

#准备数据,从OSS中获取数据并解压到当前目录:

import oss2
access_key_id = os.getenv('OSS_TEST_ACCESS_KEY_ID', 'LTAI4G1MuHTUeNrKdQEPnbph')
access_key_secret = os.getenv('OSS_TEST_ACCESS_KEY_SECRET', 'm1ILSoVqcPUxFFDqer4tKDxDkoP1ji')
bucket_name = os.getenv('OSS_TEST_BUCKET', 'mldemo')
endpoint = os.getenv('OSS_TEST_ENDPOINT', 'https://oss-cn-shanghai.aliyuncs.com')
# 创建Bucket对象,所有Object相关的接口都可以通过Bucket对象来进行
bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, bucket_name)
# 下载到本地文件
bucket.get_object_to_file('data/c12/danger_detect_data.zip', 'danger_detect_data.zip')
#解压数据
!unzip -o -q danger_detect_data.zip -d danger_detect_input 
!rm -rf __MACOSX
!ls danger_detect_input -ilht

3.定义工具方法

#数据预处理
def pre_process_image(image, img_height=299):
    n, c, h, w = [1, img_height, img_height,3]
    processedImg = image
    # 图像归一化处理
    processedImg = (np.array(processedImg) - 0) / 255.0

    # Change data layout from HWC to CHW
    processedImg = processedImg.transpose((2, 0, 1))
    processedImg = processedImg.reshape((n, c, h, w))

    return image, processedImg

# 视频显示
def arrayShow(imageArray):
    ret, png = cv2.imencode('.jpg', imageArray)
    return Image(png)

# 将dlib中rect对像转化为(top, right, bottom, left)形式
def _rect_to_css(rect):
    return rect.top(), rect.right(), rect.bottom(), rect.left()

# 确保(top, right, bottom, left)在图片内部
def _trim_css_to_bounds(css, image_shape):
    return max(css[0], 0), min(css[1], image_shape[1]), min(css[2], image_shape[0]), max(css[3], 0)

4.加载模型

print("[INFO] loading InceptionV3 model...")
model = InceptionV3(weights="imagenet")

Inception-v3:针对Inception-v2的升级,增加了以下内容:(1)RMSProp优化器。(2)分解为7*7卷积。(3)辅助分类BatchNorm。(4)标签平滑(Label Smoothing,添加到损失公式中的正则化组件类型,防止网络过于准确,防止过度拟合)。

5.查看模型信息

#查看模型信息
model.summary()

6.查看模型的输入要求

#查看模型的输入要求
model.input

 7.查看模型的输出
 

#查看模型的输出
model.output

8.初始化参数

#可视化字体颜色
textColor = (255, 0, 0)
#摄像头输入图像宽度
camera_width = 299*2
#摄像头输入图像高度
camera_height = 299*2
#定义模型输入
inputShape = (299, 299)

#图片地址
path = "danger_detect_input/"
#初始化声音报警
pygame.init()
alarm = None
try:
    pygame.mixer.init()
    pygame.mixer.pre_init(44100, -16, 2, 2048)
    alarm = pygame.mixer.music.load('alarm.mp3')
except:
    alarm = None

 9.危险物品检测

from pygame.locals import *

input_file = "danger_detect_input/input/out1.mov"
video_capture = cv2.VideoCapture(input_file)

video_capture.set(cv2.CAP_PROP_FPS, 10)
video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, camera_width)
video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, camera_height)

ret, frame = video_capture.read()

elapsedTime = 0
fps = ""

danger_classes = ["assault_rifle","lighter"]
  
    
print("begin process..",ret)
while ret:
    t1 = time.time()

    ret, frame = video_capture.read()
    if not ret: break

    frame = cv2.resize(frame, inputShape)

    time.sleep(2)   
    _, image = pre_process_image(frame)


    #模型预测
    preds = model.predict(image)
    
    #解析识别结果
    P = imagenet_utils.decode_predictions(preds)

    
    result =  P[0][0]

    cv2.putText(frame, fps, (20, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 255), 1, cv2.LINE_AA)
    cv2.putText(frame, str(result[1]+" prob:"+str(result[2])), (20, 35), cv2.FONT_HERSHEY_SIMPLEX, 0.5, textColor, 1, cv2.LINE_AA)
    if result[1] in danger_classes and alarm:
        pygame.mixer.music.play(0)
     # 清空绘图空间
    clear_output(wait=True)
    # 显示处理结果
    display(arrayShow(frame))
    

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
    elapsedTime = time.time() - t1
    fps = "{:.1f} FPS".format(1 / elapsedTime)

cv2.destroyAllWindows()

  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

流萤数点

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值