一、安装配置环境
1. 下载安装相关工具包
!pip install numpy pandas matplotlib requests tqdm opencv-python pillow -i https://pypi.tuna.tsinghua.edu.cn/simple
!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf
2. 创建目录
import os

# Create the working directories.  makedirs(..., exist_ok=True) makes the
# cell idempotent: re-running the notebook no longer raises FileExistsError
# the way os.mkdir does.
for dir_name in ('test_img', 'output', 'checkpoints'):
    os.makedirs(dir_name, exist_ok=True)
3. 下载模型文件、映射字典、和测试图片和视频
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/idx_to_labels.npy
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_fruits.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_orange_2.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_bananan.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_kiwi.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_石榴.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_orange.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_lemon.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_火龙果.jpg -P test_img
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/watermelon1.jpg -P test_img
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/banana1.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/fruits_video.mp4 -P test_img
4. 设置matplotlib中文字体
import matplotlib. pyplot as plt
%matplotlib inline
# Register SimHei so Chinese labels render, and keep the minus sign
# displayable while a non-ASCII sans-serif font is active.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Sanity plot: if the title and axis labels show Chinese glyphs instead of
# empty boxes, the font is configured correctly.
xs = [1, 2, 3]
ys = [100, 500, 300]
plt.plot(xs, ys)
plt.title('matplotlib中文字体测试', fontsize=25)
plt.xlabel('X轴', fontsize=15)
plt.ylabel('Y轴', fontsize=15)
plt.show()
二、用训练好的模型预测新图像
1. 导入工具包
import torch
import torchvision
import torch. nn. functional as F
import numpy as np
import pandas as pd
import matplotlib. pyplot as plt
%matplotlib inline
# Select the first GPU when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
2. 设置matplotlib中文字体&导入pillow中文字体
# Matplotlib side: SimHei for Chinese text, keep '-' renderable.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# Pillow side: a 32 px SimHei font used to draw Chinese class names
# directly onto images (SimHei.ttf was downloaded in section 一).
from PIL import Image, ImageFont, ImageDraw
font = ImageFont.truetype('SimHei.ttf', 32)
3. 载入类别&训练好的模型
# Mapping class index -> Chinese class name (a pickled dict inside the .npy).
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()

# Load the trained classifier.  map_location=device lets a checkpoint that
# was saved on GPU load on a CPU-only machine (consistent with the real-time
# section below, which already passes map_location).
model = torch.load('checkpoints/best-0.880.pth', map_location=device)
model = model.eval().to(device)
4. 预处理
from torchvision import transforms

# Standard ImageNet evaluation pipeline: resize the shorter side to 256,
# take a 224x224 center crop, convert to tensor, normalise with the
# ImageNet channel means/stds.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
5. 对一张测试图像进行测试
# Classify one test image and draw the top-10 predictions onto it.
img_path = 'test_img/watermelon1.jpg'
img_pil = Image.open(img_path)

# Preprocess, add a batch dimension, move to the compute device.
input_img = test_transform(img_pil).unsqueeze(0).to(device)

pred_logits = model(input_img)
pred_softmax = F.softmax(pred_logits, dim=1)

# Keep the n most confident classes (ids + probabilities as 1-D arrays).
n = 10
top_n = torch.topk(pred_softmax, n)
pred_ids = top_n.indices.detach().cpu().numpy().squeeze()
confs = top_n.values.detach().cpu().numpy().squeeze()

# Overlay "class  confidence%" lines onto the image, 50 px apart.
draw = ImageDraw.Draw(img_pil)
for rank, (class_id, conf) in enumerate(zip(pred_ids, confs)):
    text = '{:<15} {:>.4f}'.format(idx_to_labels[class_id], conf * 100)
    print(text)
    draw.text((50, 100 + 50 * rank), text, font=font, fill=(255, 0, 0, 1))
6. 预测图+柱状图可视化
# Side-by-side figure: the annotated image (left) and a confidence bar
# chart over all classes (right), saved to output/.
fig = plt.figure(figsize=(18, 6))

ax1 = plt.subplot(1, 2, 1)
ax1.imshow(img_pil)
ax1.axis('off')

ax2 = plt.subplot(1, 2, 2)
x = idx_to_labels.values()
y = pred_softmax.cpu().detach().numpy()[0] * 100  # confidences in percent
bars = ax2.bar(x, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
# BUG FIX: bar_label must receive the BarContainer returned by bar();
# the original passed an undefined name `ax`, raising NameError.
plt.bar_label(bars, fmt='%.2f', fontsize=10)
plt.title('{} 图像分类预测结果'.format(img_path), fontsize=30)
plt.xlabel('类别', fontsize=20)
plt.ylabel('置信度', fontsize=20)
plt.ylim([0, 110])  # headroom above 100% so the bar labels stay visible
ax2.tick_params(labelsize=16)
plt.xticks(rotation=90)
plt.tight_layout()
fig.savefig('output/预测图+柱状图.jpg')
7. 预测结果表格输出
# Collect the top-n predictions into a table.  DataFrame.append was
# deprecated in pandas 1.4 and removed in 2.0, so build a list of row
# dicts and construct the DataFrame once.
rows = []
for i in range(n):
    rows.append({
        'Class': idx_to_labels[pred_ids[i]],
        'Class_ID': int(pred_ids[i]),
        'Confidence(%)': confs[i] * 100,
    })
pred_df = pd.DataFrame(rows)
display(pred_df)  # `display` is provided by the notebook environment
三、预测视频文件
1. 导入工具包
import os
import time
import shutil
import tempfile
from tqdm import tqdm
import cv2
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib. pyplot as plt
%matplotlib inline
# Chinese-capable font plus a renderable minus sign for all later plots.
plt.rcParams.update({
    'axes.unicode_minus': False,
    'font.sans-serif': ['SimHei'],
})
import gc
import torch
import torch. nn. functional as F
from torchvision import models
import mmcv
# Prefer the first GPU; otherwise run inference on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('device:', device)

import matplotlib
# Agg backend renders figures off-screen — needed because the video loop
# below saves each frame's figure to disk without a display.
matplotlib.use('Agg')
2. 设置matplotlib中文字体&导入pillow中文字体
# Matplotlib: SimHei for Chinese text; keep '-' displayable.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# Pillow font (32 px) for drawing Chinese class names onto video frames.
from PIL import ImageFont, ImageDraw
font = ImageFont.truetype('SimHei.ttf', 32)
3. 载入类别&训练好的模型
# Class index -> Chinese label mapping (pickled dict inside the .npy).
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()

# Load the trained classifier; map_location=device allows a GPU-saved
# checkpoint to load on a CPU-only machine.
model = torch.load('checkpoints/best-0.880.pth', map_location=device)
model = model.eval().to(device)
4. 图像预处理
from torchvision import transforms

# Same ImageNet evaluation preprocessing as in section 二: 256 resize,
# 224 center crop, tensor conversion, ImageNet mean/std normalisation.
_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]
test_transform = transforms.Compose(_steps)
5. 图像分类预测函数(同上一个)
def pred_single_frame(img, n=5):
    '''
    Classify one camera/video frame and draw the top-n predictions on it.

    Parameters
    ----------
    img : np.ndarray
        Frame in OpenCV BGR channel order.
    n : int
        Number of top classes to draw (default 5).

    Returns
    -------
    (np.ndarray, torch.Tensor)
        The annotated frame (BGR) and the full softmax row of shape
        (1, num_classes).
    '''
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV BGR -> PIL RGB
    img_pil = Image.fromarray(img_rgb)
    input_img = test_transform(img_pil).unsqueeze(0).to(device)
    pred_logits = model(input_img)
    pred_softmax = F.softmax(pred_logits, dim=1)
    top_n = torch.topk(pred_softmax, n)
    # BUG FIX: index row 0 instead of squeeze() — squeeze() collapses the
    # (1, n) result to a 0-d array when n == 1, breaking len()/indexing.
    pred_ids = top_n[1].cpu().detach().numpy()[0]
    confs = top_n[0].cpu().detach().numpy()[0]
    draw = ImageDraw.Draw(img_pil)
    for i in range(len(confs)):
        pred_class = idx_to_labels[pred_ids[i]]
        text = '{:<15} {:>.3f}'.format(pred_class, confs[i])
        draw.text((50, 100 + 50 * i), text, font=font, fill=(255, 0, 0, 1))
    img_bgr = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    return img_bgr, pred_softmax
6. 视频预测
input_video = 'test_img/fruits_video.mp4'

# Timestamped scratch directory so repeated runs don't collide.
temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('创建临时文件夹 {} 用于存放每帧预测结果'.format(temp_out_dir))

# Annotate every frame and save it as a zero-padded JPEG so the frames
# sort in order for reassembly.
imgs = mmcv.VideoReader(input_video)
prog_bar = mmcv.ProgressBar(len(imgs))
for frame_id, img in enumerate(imgs):
    img, pred_softmax = pred_single_frame(img, n=5)
    # BUG FIX: the f-string must be exactly '{frame_id:06d}' — the original
    # had spaces inside the braces/format spec, which is a syntax error.
    cv2.imwrite(f'{temp_out_dir}/{frame_id:06d}.jpg', img)
    prog_bar.update()

# Stitch the saved frames into an mp4 at the source frame rate, then
# remove the scratch directory.
mmcv.frames2video(temp_out_dir, 'output/output_pred.mp4', fps=imgs.fps, fourcc='mp4v')
shutil.rmtree(temp_out_dir)
print('删除临时文件夹', temp_out_dir)
6.2 可视化方案二:原始图像+预测结果文字+各类别置信度柱状图
def pred_single_frame_bar(img):
    '''
    Render one output frame: the annotated image (left) next to a bar chart
    of all class confidences (right), saved as this frame's JPEG.

    NOTE(review): this function reads module-level globals set by the
    calling loop — `pred_softmax` (latest prediction), `frame_id` and
    `temp_out_dir` (output path) — so it only works when called from the
    video loop below.  It returns None; its result is the file on disk.
    '''
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR frame -> RGB for imshow
    fig = plt.figure(figsize=(18, 6))
    ax1 = plt.subplot(1, 2, 1)
    ax1.imshow(img)
    ax1.axis('off')
    ax2 = plt.subplot(1, 2, 2)
    x = idx_to_labels.values()
    y = pred_softmax.cpu().detach().numpy()[0] * 100  # percent
    ax2.bar(x, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
    ax2.tick_params(labelsize=16)
    plt.ylim([0, 100])
    # The original set xlabel/ylabel twice; only the final (larger) fonts
    # take effect, so the redundant first pair is removed.
    plt.xlabel('类别', fontsize=25)
    plt.ylabel('置信度', fontsize=25)
    plt.title('图像分类预测结果', fontsize=30)
    plt.xticks(rotation=90)
    plt.tight_layout()
    # BUG FIX: the f-string format spec must be ':06d' with no spaces.
    fig.savefig(f'{temp_out_dir}/{frame_id:06d}.jpg')
    # Release figure memory — essential when processing thousands of frames.
    fig.clf()
    plt.close()
    gc.collect()
# Visualisation scheme 2: per frame, draw the text predictions
# (pred_single_frame), then render and save the image+bar-chart figure
# (pred_single_frame_bar writes the JPEG itself).
temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('创建临时文件夹 {} 用于存放每帧预测结果'.format(temp_out_dir))
imgs = mmcv.VideoReader(input_video)
prog_bar = mmcv.ProgressBar(len(imgs))
for frame_id, img in enumerate(imgs):
    img, pred_softmax = pred_single_frame(img, n=5)
    # pred_single_frame_bar reads the globals pred_softmax / frame_id /
    # temp_out_dir set here and returns None (it saves to disk instead),
    # so `img` becomes None after this call — it is not used again.
    img = pred_single_frame_bar(img)
    prog_bar.update()
# Stitch the saved frames into a video at the source fps, then clean up.
mmcv.frames2video(temp_out_dir, 'output/output_bar.mp4', fps=imgs.fps, fourcc='mp4v')
shutil.rmtree(temp_out_dir)
print('删除临时文件夹', temp_out_dir)
四、预测实时画面
1. 导入工具包
import os
import numpy as np
import pandas as pd
import cv2
from PIL import Image, ImageFont, ImageDraw
from tqdm import tqdm
import matplotlib. pyplot as plt
%matplotlib inline
import torch
import torch. nn. functional as F
from torchvision import models
# Run on GPU 0 when CUDA is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print('device:', device)
2. 导入中文字体
# 32 px SimHei Pillow font for drawing Chinese class names onto frames;
# SimHei.ttf must be in the working directory (downloaded in section 一).
font = ImageFont.truetype('SimHei.ttf', 32)
3. 载入类别&训练好的模型
# Class index -> Chinese label mapping (pickled dict inside the .npy file).
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()

# Load the checkpoint onto the CPU first (works on CUDA-less machines),
# switch to inference mode, then move to the active device.
model = torch.load('checkpoints/fruit30_pytorch_20220814.pth',
                   map_location=torch.device('cpu'))
model.eval()
model = model.to(device)
4. 图像预处理
from torchvision import transforms

# ImageNet-style evaluation preprocessing, identical to the earlier
# sections: 256 resize, 224 center crop, tensor, mean/std normalisation.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
])
5. 处理摄像头的一帧画面
import cv2
import time

# NOTE(review): VideoCapture(1) requests camera index 1, but cap.open(0)
# immediately re-opens index 0, so device 0 is what actually gets used —
# confirm which camera index is intended.
cap = cv2.VideoCapture(1)
cap.open(0)
time.sleep(1)  # give the camera a moment to warm up before grabbing a frame
success, img_bgr = cap.read()
cap.release()
cv2.destroyAllWindows()

# OpenCV delivers BGR; Pillow expects RGB.
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
img_pil = Image.fromarray(img_rgb)

# Preprocess, add a batch dimension, classify, softmax over classes.
input_img = test_transform(img_pil).unsqueeze(0).to(device)
pred_logits = model(input_img)
pred_softmax = F.softmax(pred_logits, dim=1)

# Top-5 class ids and confidences as 1-D numpy arrays.
n = 5
top_n = torch.topk(pred_softmax, n)
pred_ids = top_n[1].cpu().detach().numpy().squeeze()
confs = top_n[0].cpu().detach().numpy().squeeze()

# Draw "class  confidence" lines onto the frame, 50 px apart.
draw = ImageDraw.Draw(img_pil)
for i in range(len(confs)):
    pred_class = idx_to_labels[pred_ids[i]]
    text = '{:<15} {:>.3f}'.format(pred_class, confs[i])
    draw.text((50, 100 + 50 * i), text, font=font, fill=(255, 0, 0, 1))

img = np.array(img_pil)
plt.imshow(img)
plt.show()
6. 处理单帧画面的函数
def process_frame(img):
    '''
    Annotate one BGR frame with the top-5 class predictions and an FPS
    counter.

    Parameters
    ----------
    img : np.ndarray
        Frame in OpenCV BGR channel order.

    Returns
    -------
    np.ndarray
        Annotated frame in BGR channel order.
    '''
    start_time = time.time()
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR -> RGB for Pillow
    img_pil = Image.fromarray(img_rgb)
    input_img = test_transform(img_pil).unsqueeze(0).to(device)
    pred_logits = model(input_img)
    pred_softmax = F.softmax(pred_logits, dim=1)
    top_n = torch.topk(pred_softmax, 5)
    # BUG FIX: index row 0 instead of squeeze() so the arrays stay 1-D
    # even if the top-k count were ever reduced to 1.
    pred_ids = top_n[1].cpu().detach().numpy()[0]
    confs = top_n[0].cpu().detach().numpy()[0]
    draw = ImageDraw.Draw(img_pil)
    for i in range(len(confs)):
        pred_class = idx_to_labels[pred_ids[i]]
        text = '{:<15} {:>.3f}'.format(pred_class, confs[i])
        draw.text((50, 100 + 50 * i), text, font=font, fill=(255, 0, 0, 1))
    img = np.array(img_pil)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # back to BGR for imshow
    end_time = time.time()
    # Guard against a zero interval on coarse clocks (ZeroDivisionError).
    FPS = 1 / max(end_time - start_time, 1e-6)
    img = cv2.putText(img, 'FPS ' + str(int(FPS)), (50, 80),
                      cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 255), 4, cv2.LINE_AA)
    return img
7. 调用摄像头获取每帧(模板)
import cv2
import time

# NOTE(review): VideoCapture(1) requests camera index 1 but cap.open(0)
# re-opens index 0 — confirm which camera is intended.
cap = cv2.VideoCapture(1)
cap.open(0)

# Read -> process -> display loop; press 'q' or Esc (keycode 27) to quit.
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        print('Error')
        break
    frame = process_frame(frame)
    cv2.imshow('my_window', frame)
    if cv2.waitKey(1) in [ord('q'), 27]:
        break
cap.release()
cv2.destroyAllWindows()