目录
前言
(62条消息) 手把手教你如何配置mmdetection并训练(win10环境)_mmdetection 环境配置_PolarisFTL的博客-CSDN博客上一篇博客教大家如何配置mmdetection工具箱以及模型如何训练,今天这篇博客教大家如何使用已经训练好的模型进行评估以及图片预测
一、mmdetection训练结果可视化
这里以centernet网络为例,我已经事先将configs中对其中一个配置文件的名称修改为centernet.py方便寻找,训练开始时会在work_dirs中生成对应的文件夹,这里生成了centernet文件夹,打开之后可以看到模型训练相关文件
打开之后可以看到文件夹中的文件命名格式以训练时间为前缀,文件夹中还存在一个centernet.py文件,这个文件是模型开始训练时生成的一个完整的配置文件,大家可以打开自行查看,这里文件不同于原本配置文件是因为配置文件大部分都是需要调用其他功能模块来实现,而训练生成的配置文件是一个完整的,不需要再去调用其他模块,对之后测试比较方便
下面对相关文件进行详细描述:
20230714_210738.log:这是模型训练时的日志文件,包含了网络结构、训练时间、学习率、loss等,大家如果需要查看模型结构可以通过这个文件进行查看。
20230714_210738.json:这里是规范化之后的日志文件,可以通过这个文件来生成模型训练的图像,损失函数图像以及mAP曲线
下面代码为模型训练结果可视化代码,源代码来源于网络,在此基础上进行了部分修改,大家使用可视化代码时要注意json文件位置,然后根据自己训练模型的loss进行修改,可以参考json文件中的参数进行修改
import json
import matplotlib.pyplot as plt
from collections import OrderedDict
class visualize_mmdetection():
def __init__(self, path, *args):
self.log = open(path)
self.dict_list = list()
self.AP_list = list()
self.loss_dict = {}
self.ap_dict = {}
self.outname = path.split('/')[-1].split('.')[0]
for i in args:
if 'AP' in i:
self.ap_dict[i] = list()
else:
self.loss_dict[i] = list()
def load_data(self):
for row, line in enumerate(self.log):
if 'mAP' not in line:
info = json.loads(line)
self.dict_list.append(info)
if 'mAP' in line:
info = json.loads(line)
self.AP_list.append(info)
for i in range(1, len(self.dict_list)):
for key in self.loss_dict.keys():
self.loss_dict[key].append(dict(self.dict_list[i])[key])
for i in range(1, len(self.AP_list)):
for key in self.ap_dict.keys():
self.ap_dict[key].append(dict(self.AP_list[i])[key])
for key in self.loss_dict.keys():
self.loss_dict[key] = list(OrderedDict.fromkeys(self.loss_dict[key]))
for key in self.ap_dict.keys():
self.ap_dict[key] = list(OrderedDict.fromkeys(self.ap_dict[key]))
def show_chart(self):
plt.rcParams.update({'font.size': 15})
plt.figure(figsize=(30, 50))
num = len(self.loss_dict.keys()) + len(self.ap_dict.keys())
col = 2
import math
line = math.ceil(num / col)
ind = 0
for key in self.loss_dict.keys():
ind += 1
plt.subplot(line, col, ind, title=key, ylabel='Loss')
plt.xlabel('Step')
plt.plot(self.loss_dict[key])
for key in self.ap_dict.keys():
ind += 1
plt.subplot(line, col, ind, title=key, ylabel='mAP')
plt.xlabel('Epoch')
plt.plot(self.ap_dict[key])
plt.suptitle((self.outname + "\n training result"), fontsize=30)
plt.savefig((self.outname + '_result.png'))
if __name__ == '__main__':
x = visualize_mmdetection('../centernet/20230714_210738/vis_data/20230714_210738.json',
'loss_cls', 'loss_bbox', 'loss_centerness', 'loss', 'lr', 'pascal_voc/mAP')
# x = visualize_mmdetection('vis_data/20230712_202526.json',
# 'loss_rpn_cls', 'loss_rpn_bbox', 'loss_cls', 'loss_bbox', 'loss',
# 'coco/bbox_mAP', 'coco/bbox_mAP_50', 'coco/bbox_mAP_75', 'coco/bbox_mAP_s',
# 'coco/bbox_mAP_m', 'coco/bbox_mAP_l',)
x.load_data()
x.show_chart()
训练结果可视化展示,会生成每一个loss的曲线图以及mAP曲线图
二、mmdetection模型预测
首先,在mmdetection目录下新建一个文件夹images,将需要预测的图片放入文件夹中
然后,在mmdetection目录下新建一个py文件命名为predict.py,具体代码如下:
config_file = 'work_dirs/centernet/centernet.py'
checkpoint_file = 'work_dirs/centernet/epoch_25.pth'
这两个文件分别是配置文件和训练好的模型的地址,大家根据自己的文件目录进行设置,直接运行代码即可,预测图片会存放在images目录下
import os
import cv2
import mmcv
from mmdet.registry import VISUALIZERS
from mmdet.apis import init_detector, inference_detector
# 指定模型的配置文件和 checkpoint 文件路径
config_file = 'work_dirs/centernet/centernet.py'
checkpoint_file = 'work_dirs/centernet/epoch_25.pth'
model = init_detector(config_file, checkpoint_file, device='cuda:0')
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta
def detect_image(model, visualizer, img_path):
img = mmcv.imread(img_path)
result = inference_detector(model, img)
img = mmcv.imconvert(img, 'bgr', 'rgb')
visualizer.add_datasample(
'result',
img,
data_sample=result,
draw_gt=False,
show=False)
img_with_bbox = visualizer.get_image()
save_path = img_path.replace('.jpg', '_result.jpg')
cv2.imwrite(save_path, img_with_bbox[:, :, ::-1])
def detect_images_in_folder(model, visualizer, folder_path):
image_paths = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith('.jpg')]
for img_path in image_paths:
detect_image(model, visualizer, img_path)
def main():
# 图片预测
folder_path = 'images/'
detect_images_in_folder(model, visualizer, folder_path)
if __name__ == '__main__':
main()
总结
制作不易,喜欢的小伙伴可以收藏点赞,有什么问题都可以在评论区留言