1. 单张图像预测
"""
==========================================
@author: Seaton
@Time: 2023/8/19:15:38
@IDE: PyCharm
@Summary:使用训练好的模型进行单张图像推理
==========================================
"""
import cv2
import matplotlib.pyplot as plt
import numpy as np
from mmengine import Config
from mmseg.apis import init_model, inference_model
# Load the model config and the trained checkpoint, then build the model on GPU 0
cfg = Config.fromfile('mmsegmentation/Zihao-Configs/ZihaoDataset_UNet_20230712.py')
checkpoint_path = 'mmsegmentation/checkpoint/myUNet.pth'
model = init_model(cfg, checkpoint_path, 'cuda:0')
# Original image (OpenCV loads it in BGR channel order)
img_path = 'mmsegmentation/Watermelon87_Semantic_Seg_Mask/img_dir/val/01bd15599c606aa801201794e1fa30.jpg'
img_bgr = cv2.imread(img_path)
plt.figure(figsize=(8, 8))
# Reverse the last (channel) axis so matplotlib receives RGB instead of BGR
plt.imshow(img_bgr[:, :, ::-1])
plt.show()
# Inference
result = inference_model(model, img_bgr)
# First entry of the predicted segmentation data, moved to the CPU as a NumPy array
pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
# Show the semantic segmentation result on top of the original image
plt.figure(figsize=(10, 8))
plt.imshow(img_bgr[:, :, ::-1])
plt.imshow(pred_mask, alpha=0.55) # alpha: opacity of the mask overlay; smaller values look closer to the original image
plt.axis('off')
plt.savefig('mmsegmentation/outputs/K1-1.jpg')
plt.show()
# Colour scheme of every class (BGR); the position in the list is the class ID
palette = [
    ['background', [127, 127, 127]],
    ['red', [0, 0, 200]],
    ['green', [0, 200, 0]],
    ['white', [144, 238, 144]],
    ['seed-black', [30, 30, 30]],
    ['seed-white', [8, 189, 251]],
]
# class ID -> BGR colour
palette_dict = {class_id: entry[1] for class_id, entry in enumerate(palette)}
opacity = 0.3  # weight of the original image in the blend; larger -> closer to the original
# Map every predicted integer ID to the colour of its class.
# The canvas is allocated as uint8 directly, so no float64 copy plus cast is needed.
pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype='uint8')
for class_id, bgr in palette_dict.items():
    pred_mask_bgr[pred_mask == class_id] = bgr
# Blend the coloured prediction with the original image
pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)
# Save next to the other outputs of this script; the previous 'outputs/K1-3.jpg'
# lacked the 'mmsegmentation/' prefix that every other path here uses.
cv2.imwrite('mmsegmentation/outputs/K1-3.jpg', pred_viz)
plt.figure(figsize=(8, 8))
# Reverse the channel axis (BGR -> RGB) for matplotlib
plt.imshow(pred_viz[:, :, ::-1])
plt.show()
# Compare the ground-truth label with the prediction
label_path = 'mmsegmentation/Watermelon87_Semantic_Seg_Mask/ann_dir/val/01bd15599c606aa801201794e1fa30.png'
label = cv2.imread(label_path)
# cv2.imread returns a 3-channel array here; only the first channel is used as the mask
label_mask = label[:, :, 0]
# Pixels that are class 1 (red flesh) in BOTH the label and the prediction,
# i.e. the intersection of the two masks (true positives for that class)
TP = (label_mask == 1) & (pred_mask == 1)
plt.imshow(TP)
plt.show()
# Build the confusion matrix
from sklearn.metrics import confusion_matrix
# Pass the full list of class IDs explicitly: without `labels`, sklearn only keeps
# the IDs that occur in this particular image, so the matrix could be smaller than
# the class list and the tick labels drawn later would no longer line up.
confusion_matrix_model = confusion_matrix(
    label_mask.flatten(),
    pred_mask.flatten(),
    labels=sorted(palette_dict),
)
import itertools
def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map, save it as PDF and show it.

    Args:
        cm: Confusion matrix; read through ``cm.shape``, ``cm.max()`` and ``cm[i, j]``.
        classes: Class names used as tick labels on both axes.
        cmap: Colormap forwarded to ``plt.imshow``.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    # plt.colorbar()  # colour bar (disabled)

    plt.title('Confusion Matrix', fontsize=30)
    plt.xlabel('Pred', fontsize=25, c='r')
    plt.ylabel('True', fontsize=25, c='r')
    plt.tick_params(labelsize=16)  # font size of the class names
    positions = np.arange(len(classes))
    plt.xticks(positions, classes, rotation=90)  # rotate the labels on the horizontal axis
    plt.yticks(positions, classes)

    # Write the count into every cell; cells above half of the maximum use white text
    half_max = cm.max() / 2.
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for row, col in itertools.product(range(n_rows), range(n_cols)):
        count = cm[row, col]
        text_color = "white" if count > half_max else "black"
        plt.text(col, row, count,
                 horizontalalignment="center",
                 color=text_color,
                 fontsize=12)

    plt.tight_layout()
    plt.savefig('mmsegmentation/outputs/K1-混淆矩阵.pdf', dpi=300)  # save the figure
    plt.show()
# Class names are read from the dataset's METAINFO and used as the axis tick labels
from mmseg.datasets import ZihaoDataset
classes = ZihaoDataset.METAINFO['classes']
# Draw, save and show the confusion matrix computed above
cnf_matrix_plotter(confusion_matrix_model, classes, cmap='Blues')
本节的代码整理如上,基本是对子豪兄的代码进行路径上的修改,也就是在路径最前面加上 `mmsegmentation/`。
没什么可展开讲的,主要流程可以总结如下:
- 定义 config 文件和 pth 文件的路径
- 基于 config 文件和 pth 文件,通过 `init_model` 函数建立模型
- 用各种方法来绘制原图与结果
- 绘制混淆矩阵
2. 视频预测
"""
==========================================
@author: Seaton
@Time: 2023/8/20:16:56
@IDE: PyCharm
@Summary:使用训练好的模型进行视频逐帧推理
==========================================
"""
import time
import numpy as np
from tqdm import tqdm
import cv2
import mmcv
from mmseg.apis import init_model, inference_model
# Model config and trained weights
config_file = 'mmsegmentation/Zihao-Configs/ZihaoDataset_UNet_20230712.py'
checkpoint_file = 'mmsegmentation/checkpoint/myUNet.pth'
from mmseg.apis import init_model
# Build the model on GPU 0
model = init_model(config_file, checkpoint_file, device='cuda:0')
# BGR colour of every class; the position in the list is the class ID
palette = [
    ['background', [127, 127, 127]],
    ['red', [0, 0, 200]],
    ['green', [0, 200, 0]],
    ['white', [144, 238, 144]],
    ['seed-black', [30, 30, 30]],
    ['seed-white', [8, 189, 251]],
]
# class ID -> BGR colour
palette_dict = {class_id: entry[1] for class_id, entry in enumerate(palette)}
opacity = 0.3  # weight of the original frame in the blend; larger -> closer to the original
# Per-frame processing function
def process_frame(img_bgr):
    """Segment one frame and return it blended with the class-colour mask.

    Args:
        img_bgr: The frame as an image array; it is passed to ``inference_model``
            and blended with ``cv2.addWeighted``.

    Returns:
        The blended image ``opacity * img_bgr + (1 - opacity) * colour_mask``.
    """
    # Semantic segmentation prediction
    result = inference_model(model, img_bgr)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

    # Map every predicted integer ID to the colour of its class.
    # Allocated as uint8 directly, so no float64 canvas plus cast per frame.
    pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype='uint8')
    for class_id, bgr in palette_dict.items():
        pred_mask_bgr[pred_mask == class_id] = bgr

    # Blend the coloured prediction with the original frame
    return cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)
# Template for frame-by-frame video processing.
# Nothing in here needs to be modified; only process_frame has to be defined.
# Original template by Zihao (Tongji), 2021-7-10
def generate_video(input_path='videos/robot.mp4'):
    """Run ``process_frame`` on every frame of a video and write the result.

    The output is written to ``"out-" + <file name of input_path>`` in the
    current working directory, encoded with the ``mp4v`` fourcc at the
    frame rate and frame size reported by the input.

    Args:
        input_path: Path of the input video ('/'-separated).
    """
    filehead = input_path.split('/')[-1]
    output_path = "out-" + filehead
    print('视频开始处理', input_path)

    # First pass: count the frames by decoding them. Only successfully read
    # frames are counted, so the printed total is the real number of frames.
    cap = cv2.VideoCapture(input_path)
    frame_count = 0
    while cap.isOpened():
        success, _ = cap.read()
        if not success:
            break
        frame_count += 1
    cap.release()
    print('视频总帧数为', frame_count)

    # Second pass: process and write
    cap = cv2.VideoCapture(input_path)
    frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                  int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = cap.get(cv2.CAP_PROP_FPS)
    out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)

    try:
        # The progress bar is bound to the total number of frames
        with tqdm(total=frame_count) as pbar:
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break
                try:
                    frame = process_frame(frame)
                except Exception as error:
                    # Best effort: report the failure and write the unprocessed frame
                    print('报错!', error)
                out.write(frame)
                pbar.update(1)
    except (Exception, KeyboardInterrupt) as error:
        # Interrupted midway (e.g. Ctrl-C): keep what has been written so far
        print('中途中断', repr(error))
    finally:
        # Always release the reader/writer so the output file is finalised
        cv2.destroyAllWindows()
        out.release()
        cap.release()
    print('视频已保存', output_path)
# Process the demo video frame by frame
generate_video(input_path='demo/test.mp4')
本节整理代码如上,基本原理与单张预测几乎一样,多了一步就是将视频拆成单帧,进行预测后再拼合成视频并保存。
3. 整个文件夹图片预测
"""
==========================================
@author: Seaton
@Time: 2023/8/20:18:37
@IDE: PyCharm
@Summary:使用训练好的模型进行文件夹下所有图像推理
==========================================
"""
import os
import numpy as np
import cv2
from tqdm import tqdm
from mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv
import matplotlib.pyplot as plt
# Model config file
config_file = 'mmsegmentation/Zihao-Configs/ZihaoDataset_UNet_20230712.py'
# Model weights file
checkpoint_file = 'mmsegmentation/checkpoint/myUNet.pth'
# Compute device
device = 'cuda:0'
model = init_model(config_file, checkpoint_file, device=device)
# BGR colour of every class; the position in the list is the class ID
palette = [
    ['background', [127, 127, 127]],
    ['red', [0, 0, 200]],
    ['green', [0, 200, 0]],
    ['white', [144, 238, 144]],
    ['seed-black', [30, 30, 30]],
    ['seed-white', [8, 189, 251]],
]
# class ID -> BGR colour
palette_dict = {class_id: entry[1] for class_id, entry in enumerate(palette)}
# Create the output directory. makedirs also creates a missing parent
# ('mmsegmentation/outputs'), where os.mkdir would raise FileNotFoundError.
os.makedirs('mmsegmentation/outputs/testset-pred', exist_ok=True)
PATH_IMAGE = 'mmsegmentation/Watermelon87_Semantic_Seg_Mask/img_dir/val'
opacity = 0.3  # weight of the original image in the blend; larger -> closer to the original
def process_single_img(img_path, save=False):
    """Segment one image file and optionally save the blended visualisation.

    Args:
        img_path: Path of the image to read with ``cv2.imread``.
        save: When true, write the visualisation as ``pred-<file name>`` into
            ``../../../outputs/testset-pred``, resolved relative to the
            current working directory.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``img_path``.
    """
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        # cv2.imread returns None instead of raising for missing/unreadable files
        raise FileNotFoundError(f'Cannot read image: {img_path}')

    # Semantic segmentation prediction
    result = inference_model(model, img_bgr)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

    # Map every predicted integer ID to the colour of its class
    pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype='uint8')
    for class_id, bgr in palette_dict.items():
        pred_mask_bgr[pred_mask == class_id] = bgr

    # Blend the coloured prediction with the original image
    pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)

    # Save the image into the outputs/testset-pred directory
    if save:
        save_path = os.path.join('../', '../', '../', 'outputs', 'testset-pred',
                                 'pred-' + os.path.basename(img_path))
        cv2.imwrite(save_path, pred_viz)
        print('已保存')
# Switch into the image directory
os.chdir(PATH_IMAGE)
# Batch inference is currently disabled; uncomment to run it on every file here:
# for each in tqdm(os.listdir()):
#     process_single_img(each, save=True)
# Batch visualisation: switch into the directory holding the saved predictions
os.chdir('../../../outputs/testset-pred')
# Show the first n * n files in an n-by-n grid
n = 4
fig, axes = plt.subplots(nrows=n, ncols=n, figsize=(16, 10))
# axes.flat walks the grid row by row, i.e. axes[i // n, i % n] for i = 0, 1, ...
for ax, file_name in zip(axes.flat, os.listdir()[:n ** 2]):
    img_bgr = cv2.imread(file_name)
    # Reverse the channel axis (BGR -> RGB) for matplotlib
    ax.imshow(img_bgr[:, :, ::-1])
    ax.axis('off')  # hide the axes
fig.suptitle('Semantic Segmentation Predictions', fontsize=30)
# plt.tight_layout()
plt.savefig('../K3.jpg')
plt.show()
本节也是照猫画虎,重点在于 os 库的应用。官方代码有一处需要修改,即第 79 行,将 `os.chdir('outputs/testset-pred')` 修改为 `os.chdir('../../../outputs/testset-pred')`。