MMSegmentation Notes 06: Inference

1. Single-Image Prediction

"""
==========================================
@author: Seaton
@Time: 2023/8/19:15:38
@IDE: PyCharm
@Summary: Run inference on a single image with the trained model
==========================================
"""

import cv2
import matplotlib.pyplot as plt
import numpy as np
from mmengine import Config

from mmseg.apis import init_model, inference_model

cfg = Config.fromfile('mmsegmentation/Zihao-Configs/ZihaoDataset_UNet_20230712.py')
checkpoint_path = 'mmsegmentation/checkpoint/myUNet.pth'
model = init_model(cfg, checkpoint_path, 'cuda:0')

# Original image
img_path = 'mmsegmentation/Watermelon87_Semantic_Seg_Mask/img_dir/val/01bd15599c606aa801201794e1fa30.jpg'
img_bgr = cv2.imread(img_path)
plt.figure(figsize=(8, 8))
plt.imshow(img_bgr[:, :, ::-1])
plt.show()

# Inference
result = inference_model(model, img_bgr)
pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

# Display the semantic segmentation result
plt.figure(figsize=(10, 8))
plt.imshow(img_bgr[:, :, ::-1])
plt.imshow(pred_mask, alpha=0.55)  # alpha: opacity of the overlay; smaller values show more of the original image
plt.axis('off')
plt.savefig('mmsegmentation/outputs/K1-1.jpg')
plt.show()

# Color scheme for each class (BGR)
palette = [
    ['background', [127, 127, 127]],
    ['red', [0, 0, 200]],
    ['green', [0, 200, 0]],
    ['white', [144, 238, 144]],
    ['seed-black', [30, 30, 30]],
    ['seed-white', [8, 189, 251]]
]

palette_dict = {}
for idx, each in enumerate(palette):
    palette_dict[idx] = each[1]
opacity = 0.3  # weight of the original image in the blend; larger values look closer to the original
# Map predicted integer class IDs to their class colors
pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
for idx in palette_dict.keys():
    pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
pred_mask_bgr = pred_mask_bgr.astype('uint8')

# Overlay the segmentation prediction on the original image
pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)

cv2.imwrite('mmsegmentation/outputs/K1-3.jpg', pred_viz)  # mmsegmentation/ prefix added, consistent with the other paths
plt.figure(figsize=(8, 8))
plt.imshow(pred_viz[:, :, ::-1])
plt.show()
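As an aside, the per-class loop above can be replaced by a NumPy lookup table that colors every pixel in one indexing operation; a minimal sketch reusing the palette and pred_mask defined above:

# Build an (N, 3) table of BGR colors, one row per class ID, then index it
# with the (H, W) integer mask to get an (H, W, 3) color image directly.
lut = np.array([color for _, color in palette], dtype=np.uint8)
pred_mask_bgr = lut[pred_mask]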

# Compare the label with the prediction
label_path = 'mmsegmentation/Watermelon87_Semantic_Seg_Mask/ann_dir/val/01bd15599c606aa801201794e1fa30.png'
label = cv2.imread(label_path)
label_mask = label[:, :, 0]  # the mask is single-channel; keep one channel
# True positives for class 1 (watermelon red flesh): the intersection
# (not union) of the ground-truth mask and the predicted mask
TP = (label_mask == 1) & (pred_mask == 1)
plt.imshow(TP)
plt.show()

# Plot the confusion matrix
from sklearn.metrics import confusion_matrix

# labels pins the matrix to all 6 classes, even if one is absent from this image
confusion_matrix_model = confusion_matrix(label_mask.flatten(), pred_mask.flatten(),
                                          labels=list(range(len(palette))))
import itertools


def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues):
    """
    传入混淆矩阵和标签名称列表,绘制混淆矩阵
    """
    plt.figure(figsize=(10, 10))

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    # plt.colorbar()  # color bar
    tick_marks = np.arange(len(classes))

    plt.title('Confusion Matrix', fontsize=30)
    plt.xlabel('Pred', fontsize=25, c='r')
    plt.ylabel('True', fontsize=25, c='r')
    plt.tick_params(labelsize=16)  # tick label font size
    plt.xticks(tick_marks, classes, rotation=90)  # rotate x-axis class names
    plt.yticks(tick_marks, classes)

    # Write the count into each cell
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > threshold else "black",
                 fontsize=12)

    plt.tight_layout()

    plt.savefig('mmsegmentation/outputs/K1-混淆矩阵.pdf', dpi=300)  # save the figure
    plt.show()


from mmseg.datasets import ZihaoDataset

classes = ZihaoDataset.METAINFO['classes']
cnf_matrix_plotter(confusion_matrix_model, classes, cmap='Blues')
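As a quick numeric check, per-class IoU can be read straight off the same confusion matrix (true positives sit on the diagonal; false positives and false negatives are the off-diagonal column and row sums); a minimal sketch using the confusion_matrix_model and classes defined above:

cm = confusion_matrix_model.astype(float)
tp = np.diag(cm)                 # correctly classified pixels per class
fp = cm.sum(axis=0) - tp         # predicted as this class but actually another
fn = cm.sum(axis=1) - tp         # actually this class but predicted as another
iou = tp / np.maximum(tp + fp + fn, 1)  # guard against empty classes
for name, score in zip(classes, iou):
    print(f'{name}: IoU = {score:.4f}')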

The code for this section is collected above; it is essentially 子豪兄's code with the paths adjusted, i.e. with mmsegmentation/ prepended to each path.

There is not much to expand on; the main workflow can be summarized as follows (a condensed sketch follows the list):

  • Define the paths to the config file and the .pth checkpoint file

  • Build the model from them with the init_model function

  • Visualize the original image and the prediction in various ways

  • Plot the confusion matrix
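Stripped of all visualization, the core of the workflow is only a few lines; a condensed sketch assuming the same config and checkpoint paths as above (the image path is a placeholder):

from mmseg.apis import init_model, inference_model

# Paths assumed from the code above
config_path = 'mmsegmentation/Zihao-Configs/ZihaoDataset_UNet_20230712.py'
checkpoint_path = 'mmsegmentation/checkpoint/myUNet.pth'

model = init_model(config_path, checkpoint_path, 'cuda:0')
result = inference_model(model, 'some_image.jpg')  # accepts a path or a loaded BGR array
pred_mask = result.pred_sem_seg.data[0].cpu().numpy()  # (H, W) array of integer class IDs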

2. Video Prediction

"""
==========================================
@author: Seaton
@Time: 2023/8/20:16:56
@IDE: PyCharm
@Summary: Run inference on a video with the trained model
==========================================
"""
import time
import numpy as np
from tqdm import tqdm
import cv2

import mmcv
from mmseg.apis import init_model, inference_model

config_file = 'mmsegmentation/Zihao-Configs/ZihaoDataset_UNet_20230712.py'
checkpoint_file = 'mmsegmentation/checkpoint/myUNet.pth'

model = init_model(config_file, checkpoint_file, device='cuda:0')

palette = [
    ['background', [127, 127, 127]],
    ['red', [0, 0, 200]],
    ['green', [0, 200, 0]],
    ['white', [144, 238, 144]],
    ['seed-black', [30, 30, 30]],
    ['seed-white', [8, 189, 251]]
]
palette_dict = {}
for idx, each in enumerate(palette):
    palette_dict[idx] = each[1]

opacity = 0.3  # weight of the original image in the blend; larger values look closer to the original


# Per-frame processing function
def process_frame(img_bgr):
    # Record when processing of this frame started (kept from the original template; unused here)
    start_time = time.time()

    # Semantic segmentation inference
    result = inference_model(model, img_bgr)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

    # Map predicted integer class IDs to their class colors
    pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
    for idx in palette_dict.keys():
        pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
    pred_mask_bgr = pred_mask_bgr.astype('uint8')

    # Overlay the segmentation prediction on the original image
    pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)

    return pred_viz
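Before running a whole video, it is worth sanity-checking process_frame on a single image; a quick sketch (the paths are placeholders):

test_frame = cv2.imread('demo/test_frame.jpg')  # placeholder path
viz = process_frame(test_frame)
cv2.imwrite('mmsegmentation/outputs/frame-check.jpg', viz)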


# Template for frame-by-frame video processing
# No code changes needed; only the process_frame function above has to be defined
# 同济子豪兄 2021-7-10

def generate_video(input_path='videos/robot.mp4'):
    filehead = input_path.split('/')[-1]
    output_path = "out-" + filehead

    print('Processing video:', input_path)

    # Count the total number of frames
    cap = cv2.VideoCapture(input_path)
    frame_count = 0
    while cap.isOpened():
        success, frame = cap.read()
        frame_count += 1  # increments once past the last frame, hence frame_count - 1 below
        if not success:
            break
    cap.release()
    print('Total number of frames:', frame_count)

    # cv2.namedWindow('Crack Detection and Measurement Video Processing')
    cap = cv2.VideoCapture(input_path)
    frame_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
    # fourcc = cv2.VideoWriter_fourcc(*'XVID')
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = cap.get(cv2.CAP_PROP_FPS)

    out = cv2.VideoWriter(output_path, fourcc, fps, (int(frame_size[0]), int(frame_size[1])))

    # Bind the progress bar to the total frame count
    with tqdm(total=frame_count - 1) as pbar:
        try:
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break

                # Process the frame; on failure, fall back to writing it unmodified
                try:
                    frame = process_frame(frame)
                except Exception as error:
                    print('Frame processing failed:', error)

                out.write(frame)

                # Advance the progress bar by one frame
                pbar.update(1)

        except (Exception, KeyboardInterrupt) as error:
            print('Interrupted midway:', error)

    cv2.destroyAllWindows()
    out.release()
    cap.release()
    print('Video saved to', output_path)


generate_video(input_path='demo/test.mp4')

The code for this section is above. The principle is almost identical to single-image prediction, with one extra step: the video is split into individual frames, each frame is predicted, and the frames are reassembled into a video and saved. A shorter variant of the same loop, using mmcv's video reader, is sketched below.
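mmcv's VideoReader exposes fps, frame size, and frame count directly and iterates over BGR frames, so the manual frame-counting pass above becomes unnecessary; a minimal sketch reusing process_frame from above (the input/output paths are placeholders):

import cv2
import mmcv

video = mmcv.VideoReader('demo/test.mp4')  # placeholder path
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter('out-test.mp4', fourcc, video.fps,
                      (video.width, video.height))

for frame in tqdm(video, total=video.frame_cnt):  # iterates over BGR frames
    out.write(process_frame(frame))
out.release()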

3. Predicting All Images in a Folder

"""
==========================================
@author: Seaton
@Time: 2023/8/20:18:37
@IDE: PyCharm
@Summary: Run inference on all images in a folder with the trained model
==========================================
"""
import os
import numpy as np
import cv2
from tqdm import tqdm

from mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv

import matplotlib.pyplot as plt

# Model config file
config_file = 'mmsegmentation/Zihao-Configs/ZihaoDataset_UNet_20230712.py'
# Model checkpoint (weights) file
checkpoint_file = 'mmsegmentation/checkpoint/myUNet.pth'

# Compute device
device = 'cuda:0'

model = init_model(config_file, checkpoint_file, device=device)

# BGR color for each class
palette = [
    ['background', [127, 127, 127]],
    ['red', [0, 0, 200]],
    ['green', [0, 200, 0]],
    ['white', [144, 238, 144]],
    ['seed-black', [30, 30, 30]],
    ['seed-white', [8, 189, 251]]
]

palette_dict = {}
for idx, each in enumerate(palette):
    palette_dict[idx] = each[1]

os.makedirs('mmsegmentation/outputs/testset-pred', exist_ok=True)

PATH_IMAGE = 'mmsegmentation/Watermelon87_Semantic_Seg_Mask/img_dir/val'
opacity = 0.3  # weight of the original image in the blend; larger values look closer to the original


def process_single_img(img_path, save=False):
    img_bgr = cv2.imread(img_path)

    # Semantic segmentation inference
    result = inference_model(model, img_bgr)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

    # Map predicted integer class IDs to their class colors
    pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
    for idx in palette_dict.keys():
        pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
    pred_mask_bgr = pred_mask_bgr.astype('uint8')

    # Overlay the segmentation prediction on the original image
    pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)

    # Save the image to the outputs/testset-pred directory
    if save:
        save_path = os.path.join('../', '../', '../', 'outputs', 'testset-pred', 'pred-' + img_path.split('/')[-1])
        cv2.imwrite(save_path, pred_viz)
        print('Saved', save_path)


os.chdir(PATH_IMAGE)
for each in tqdm(os.listdir()):
    process_single_img(each, save=True)


# Batch visualization
os.chdir('../../../outputs/testset-pred')
# Visualize in an n x n grid
n = 4

fig, axes = plt.subplots(nrows=n, ncols=n, figsize=(16, 10))

for i, file_name in enumerate(os.listdir()[:n ** 2]):
    img_bgr = cv2.imread(file_name)

    # Show the prediction
    axes[i // n, i % n].imshow(img_bgr[:, :, ::-1])
    axes[i // n, i % n].axis('off')  # hide the axes
fig.suptitle('Semantic Segmentation Predictions', fontsize=30)
# plt.tight_layout()
plt.savefig('../K3.jpg')
plt.show()

This section again follows the same pattern; the key point is the use of the os module. One line of the official code needs to be changed, namely line 79: os.chdir('outputs/testset-pred') must become os.chdir('../../../outputs/testset-pred'). A chdir-free variant of the batch loop is sketched below.
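Navigating with os.chdir and relative ../ paths is fragile; a minimal chdir-free sketch of the same batch loop, reusing model, inference_model, palette_dict, and opacity from above (directory names taken from the code above):

import os
import cv2
import numpy as np
from tqdm import tqdm

img_dir = 'mmsegmentation/Watermelon87_Semantic_Seg_Mask/img_dir/val'
out_dir = 'mmsegmentation/outputs/testset-pred'
os.makedirs(out_dir, exist_ok=True)

for file_name in tqdm(os.listdir(img_dir)):
    img_bgr = cv2.imread(os.path.join(img_dir, file_name))
    result = inference_model(model, img_bgr)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
    # Color the mask, blend it with the original, and save with absolute paths
    pred_mask_bgr = np.zeros((*pred_mask.shape, 3), dtype='uint8')
    for idx, color in palette_dict.items():
        pred_mask_bgr[pred_mask == idx] = color
    pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)
    cv2.imwrite(os.path.join(out_dir, 'pred-' + file_name), pred_viz)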
