1.训练并保存自己的模型
保存的模型格式为:XXX.pth
torch.save(model, "./weight/last.pth")
if best_acc <(validation_acc / len_val):
torch.save(model, "./weight/best.pth")
2.转化为ONNX格式
2.1环境安装(window10)
pip install onnx
pip install onnxruntime
#验证安装配置是否成功
import torch
print('PyTorch 版本', torch.__version__)
import onnx
print('ONNX 版本', onnx.__version__)
import onnxruntime as ort
print('ONNX Runtime 版本', ort.__version__)
2.2.pth格式转ONNX格式
import torch
from torchvision import models
# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)
model = torch.load('best.pth')
model = model.eval().to(device)
x = torch.randn(1, 3, 256, 256).to(device) #这里要构造一个数据,保证和自己输入的图片大小一致3*256*256
output = model(x) #output.shape = torch.Size([1, 10]) 这是一个10分类问题
#Pytorch模型转ONNX模型
x = torch.randn(1, 3, 256, 256).to(device)
with torch.no_grad():
torch.onnx.export(
model, # 要转换的模型
x, # 模型的任意一组输入
'best.onnx', # 导出的 ONNX 文件名
opset_version=11, # ONNX 算子集版本
input_names=['input'], # 输入 Tensor 的名称(自己起名字)
output_names=['output'] # 输出 Tensor 的名称(自己起名字)
import onnx
# 读取 ONNX 模型
onnx_model = onnx.load('resnet18_fruit30.onnx')
# 检查模型格式是否正确
onnx.checker.check_model(onnx_model)
print('无报错,onnx模型载入成功')
这是project中就出现了“best.onnx”文件,表示转化ONNX格式成功!
3.可视化实时检测
3.1在PC电脑端查看
3.1.1环境安装(待补充)
pip install onnxruntime
需要提前保存一个类别ID和类别名称对应的文件
3.1.2 摄像头实时捕捉分类处理
import os
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm # 进度条
import torch
import torch.nn.functional as F
from torchvision import transforms
import time
import onnxruntime
from PIL import Image, ImageFont, ImageDraw
import matplotlib.pyplot as plt
# 导入中文字体,指定字体大小
font = ImageFont.truetype('SimHei.ttf', 32)
ort_session = onnxruntime.InferenceSession('resnet18_imagenet.onnx')
# 载入ImageNet 1000图像分类标签
df = pd.read_csv('imagenet_class_index.csv')
idx_to_labels = {}
for idx, row in df.iterrows():
idx_to_labels[row['ID']] = row['Chinese']
# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(256),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 处理帧函数
def process_frame(img_bgr):
# 记录该帧开始处理的时间
start_time = time.time()
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) # BGR转RGB
img_pil = Image.fromarray(img_rgb) # array 转 PIL
## 预处理
input_img = test_transform(img_pil) # 预处理
input_tensor = input_img.unsqueeze(0).numpy()
## onnx runtime 预测
ort_inputs = {'input': input_tensor} # onnx runtime 输入
pred_logits = ort_session.run(['output'], ort_inputs)[0] # onnx runtime 输出
pred_logits = torch.tensor(pred_logits)
pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算
## 解析图像分类预测结果
n = 5
top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
confs = top_n[0].cpu().detach().numpy().squeeze() # 解析出置信度
## 在图像上写中文
draw = ImageDraw.Draw(img_pil)
for i in range(len(confs)):
pred_class = idx_to_labels[pred_ids[i]]
# 写中文:文字坐标,中文字符串,字体,rgba颜色
text = '{:<15} {:>.3f}'.format(pred_class, confs[i]) # 中文字符串
draw.text((50, 100 + 50 * i), text, font=font, fill=(255, 0, 0, 1))
img_rgb = np.array(img_pil) # PIL 转 array
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) # RGB转BGR
# 记录该帧处理完毕的时间
end_time = time.time()
# 计算每秒处理图像帧数FPS
FPS = 1 / (end_time - start_time)
# 图片,添加的文字,左上角坐标,字体,字体大小,颜色,线宽,线型
img_bgr = cv2.putText(img_bgr, 'FPS ' + str(int(FPS)), (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 255), 4,
cv2.LINE_AA)
return img_bgr
# 获取摄像头,传入0表示获取系统默认摄像头
cap = cv2.VideoCapture(1)
# 打开cap
cap.open(0)
# 无限循环,直到break被触发
while cap.isOpened():
# 获取画面
success, frame = cap.read()
if not success: # 如果获取画面不成功,则退出
print('获取画面不成功,退出')
break
## 逐帧处理
frame = process_frame(frame)
# 展示处理后的三通道图像
cv2.imshow('my_window', frame)
key_pressed = cv2.waitKey(60) # 每隔多少毫秒毫秒,获取键盘哪个键被按下
# print('键盘上被按下的键:', key_pressed)
if key_pressed in [ord('q'), 27]: # 按键盘上的q或esc退出(在英文输入法下)
break
# 关闭摄像头
cap.release()
# 关闭图像窗口
cv2.destroyAllWindows()
使用q按钮退出!
3.1.3 视频离线分类处理
import os
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm # 进度条
import torch
import torch.nn.functional as F
from torchvision import transforms
import time
import onnxruntime
from PIL import Image, ImageFont, ImageDraw
import matplotlib.pyplot as plt
# 导入中文字体,指定字体大小
font = ImageFont.truetype('SimHei.ttf', 32)
ort_session = onnxruntime.InferenceSession('resnet18_imagenet.onnx')
# 载入ImageNet 1000图像分类标签
df = pd.read_csv('imagenet_class_index.csv')
idx_to_labels = {}
for idx, row in df.iterrows():
idx_to_labels[row['ID']] = row['Chinese']
# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(256),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
import cv2
import numpy as np
import time
from tqdm import tqdm
# 处理帧函数
def process_frame(img_bgr):
# 记录该帧开始处理的时间
start_time = time.time()
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) # BGR转RGB
img_pil = Image.fromarray(img_rgb) # array 转 PIL
## 预处理
input_img = test_transform(img_pil) # 预处理
input_tensor = input_img.unsqueeze(0).numpy()
## onnx runtime 预测
ort_inputs = {'input': input_tensor} # onnx runtime 输入
pred_logits = ort_session.run(['output'], ort_inputs)[0] # onnx runtime 输出
pred_logits = torch.tensor(pred_logits)
pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算
## 解析图像分类预测结果
n = 5
top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 个结果
pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出类别
confs = top_n[0].cpu().detach().numpy().squeeze() # 解析出置信度
## 在图像上写中文
draw = ImageDraw.Draw(img_pil)
for i in range(len(confs)):
pred_class = idx_to_labels[pred_ids[i]]
# 写中文:文字坐标,中文字符串,字体,rgba颜色
text = '{:<15} {:>.3f}'.format(pred_class, confs[i]) # 中文字符串
draw.text((50, 100 + 50 * i), text, font=font, fill=(255, 0, 0, 1))
img_rgb = np.array(img_pil) # PIL 转 array
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) # RGB转BGR
# 记录该帧处理完毕的时间
end_time = time.time()
# 计算每秒处理图像帧数FPS
FPS = 1 / (end_time - start_time)
# 图片,添加的文字,左上角坐标,字体,字体大小,颜色,线宽,线型
img_bgr = cv2.putText(img_bgr, 'FPS ' + str(int(FPS)), (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 255), 4,
cv2.LINE_AA)
return img_bgr
# 视频逐帧处理代码模板
# 不需修改任何代码,只需定义process_frame函数即可
# 同济子豪兄 2021-7-10
def generate_video(input_path='videos/robot.mp4'):
filehead = input_path.split('/')[-1]
output_path = "out-" + filehead
print('视频开始处理', input_path)
# 获取视频总帧数
cap = cv2.VideoCapture(input_path)
frame_count = 0
while (cap.isOpened()):
success, frame = cap.read()
frame_count += 1
if not success:
break
cap.release()
print('视频总帧数为', frame_count)
# cv2.namedWindow('Crack Detection and Measurement Video Processing')
cap = cv2.VideoCapture(input_path)
frame_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
# fourcc = cv2.VideoWriter_fourcc(*'XVID')
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fps = cap.get(cv2.CAP_PROP_FPS)
out = cv2.VideoWriter(output_path, fourcc, fps, (int(frame_size[0]), int(frame_size[1])))
# 进度条绑定视频总帧数
with tqdm(total=frame_count - 1) as pbar:
try:
while (cap.isOpened()):
success, frame = cap.read()
if not success:
break
# 处理帧
# frame_path = './temp_frame.png'
# cv2.imwrite(frame_path, frame)
try:
frame = process_frame(frame)
except:
print('报错!', error)
pass
if success == True:
# cv2.imshow('Video Processing', frame)
out.write(frame)
# 进度条更新一帧
pbar.update(1)
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
except:
print('中途中断')
pass
cv2.destroyAllWindows()
out.release()
cap.release()
print('视频已保存', output_path)
generate_video(input_path='video_4.mp4')
参考:https://www.bilibili.com/video/BV1AM4y187yR/?spm_id_from=333.788&vd_source=47e66af6a90e9c41c341fd3c692ced14