如何使用yolov8训练使用 电力设备红外类 开关设备红外过热图像数据集,5500张图片,标注为voc(xml)格式,连接部分,主体,负荷开关,避雷器,电流互感器,电压互感器,实现评估及可视化模型

如何训练这套数据集呢?

开关设备红外过热图像数据集,总共5500左右张图片,标注为voc(xml)格式,总共8类,分别为核心,连接部分,主体,负荷开关,避雷器,电流互感器,电压互感器,塑料外壳式断路器在这里插入图片描述
开关设备红外过热图像数据集,总共5500左右张图片,标注为voc(xml)格式,
总共8类,
在这里插入图片描述

分别为核心,连接部分,主体,负荷开关,避雷器,电流互感器,电压互感器,塑料外壳式断路器在这里插入图片描述
在这里插入图片描述

完整项目结构

switch_device_overheat_detection/
├── main.py
├── train.py
├── evaluate.py
├── infer.py
├── datasets/
│   ├── switch_devices/
│   │   ├── Annotations/
│   │   ├── ImageSets/
│   │   │   └── Main/
│   │   │       ├── train.txt
│   │   │       └── val.txt
│   │   └── JPEGImages/
├── best_switch_device.pt
├── requirements.txt
└── data.yaml

文件内容

requirements.txt
opencv-python
torch==1.9
ultralytics
PyQt5
data.yaml
train: ./datasets/switch_devices/JPEGImages/train
val: ./datasets/switch_devices/JPEGImages/val
test: ./datasets/switch_devices/JPEGImages/test

nc: 8
names: ['core', 'connection_part', 'body', 'load_switch', 'arrester', 'current_transformer', 'voltage_transformer', 'plastic_enclosure_circuit_breaker']
convert_voc_to_yolo.py
import os
import xml.etree.ElementTree as ET
import shutil
import cv2

def xml_to_yolo(xml_file, image_width, image_height):
    yolo_labels = []
    tree = ET.parse(xml_file)
    root = tree.getroot()
    for obj in root.findall('object'):
        label = obj.find('name').text
        bbox = obj.find('bndbox')
        xmin = int(bbox.find('xmin').text)
        ymin = int(bbox.find('ymin').text)
        xmax = int(bbox.find('xmax').text)
        ymax = int(bbox.find('ymax').text)
        
        x_center = (xmin + xmax) / 2.0 / image_width
        y_center = (ymin + ymax) / 2.0 / image_height
        width = (xmax - xmin) / image_width
        height = (ymax - ymin) / image_height
        
        class_id = {
            'core': 0,
            'connection_part': 1,
            'body': 2,
            'load_switch': 3,
            'arrester': 4,
            'current_transformer': 5,
            'voltage_transformer': 6,
            'plastic_enclosure_circuit_breaker': 7
        }[label]
        yolo_labels.append(f"{class_id} {x_center} {y_center} {width} {height}")
    
    return '\n'.join(yolo_labels)

def split_dataset(image_dir, annotations_dir, output_dir, train_ratio=0.8):
    images = [f for f in os.listdir(image_dir) if f.endswith('.jpg')]
    num_train = int(len(images) * train_ratio)
    train_images = images[:num_train]
    val_images = images[num_train:]
    
    with open(os.path.join(output_dir, 'ImageSets/Main/train.txt'), 'w') as f:
        f.write('\n'.join([os.path.splitext(img)[0] for img in train_images]))
    
    with open(os.path.join(output_dir, 'ImageSets/Main/val.txt'), 'w') as f:
        f.write('\n'.join([os.path.splitext(img)[0] for img in val_images]))

def convert_dataset(voc_dir, yolo_dir):
    annotations_dir = os.path.join(voc_dir, 'Annotations')
    images_dir = os.path.join(voc_dir, 'JPEGImages')
    yolo_labels_dir = os.path.join(yolo_dir, 'labels')
    os.makedirs(yolo_labels_dir, exist_ok=True)
    os.makedirs(os.path.join(yolo_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(yolo_dir, 'images/train'), exist_ok=True)
    os.makedirs(os.path.join(yolo_dir, 'images/val'), exist_ok=True)
    os.makedirs(os.path.join(yolo_dir, 'ImageSets/Main'), exist_ok=True)
    
    split_dataset(images_dir, annotations_dir, yolo_dir)
    
    for filename in os.listdir(annotations_dir):
        if filename.endswith('.xml'):
            xml_file = os.path.join(annotations_dir, filename)
            image_filename = os.path.splitext(filename)[0] + '.jpg'
            image_path = os.path.join(images_dir, image_filename)
            image = cv2.imread(image_path)
            image_height, image_width, _ = image.shape
            yolo_label = xml_to_yolo(xml_file, image_width, image_height)
            txt_filename = os.path.splitext(filename)[0] + '.txt'
            txt_file = os.path.join(yolo_labels_dir, txt_filename)
            with open(txt_file, 'w') as f:
                f.write(yolo_label)
            
            # Copy image to YOLO directory
            base_image_dir = os.path.join(yolo_dir, 'images')
            if image_filename.split('.')[0] in [line.strip() for line in open(os.path.join(yolo_dir, 'ImageSets/Main/train.txt'))]:
                target_image_dir = os.path.join(base_image_dir, 'train')
            else:
                target_image_dir = os.path.join(base_image_dir, 'val')
            shutil.copy(image_path, target_image_dir)

# 示例用法
convert_dataset('./datasets/switch_devices', './datasets/switch_devices_yolo')
train.py
import torch
from ultralytics import YOLO

# 设置随机种子以保证可重复性
torch.manual_seed(42)

# 定义数据集路径
dataset_config = 'data.yaml'

# 加载预训练的YOLOv8n模型
model = YOLO('yolov8n.pt')

# 训练模型
results = model.train(
    data=dataset_config,
    epochs=50,
    imgsz=640,
    batch=16,
    name='switch_devices',
    project='runs/train'
)

# 评估模型
metrics = model.val()

# 保存最佳模型权重
best_model_weights = 'runs/train/switch_devices/weights/best.pt'
print(f"最佳模型权重已保存到 {best_model_weights}")
evaluate.py
from ultralytics import YOLO

# 初始化YOLOv8模型
model = YOLO('runs/train/switch_devices/weights/best.pt')

# 评估模型
metrics = model.val()

# 打印评估结果
print(metrics)
infer.py
import sys
import cv2
import numpy as np
from ultralytics import YOLO
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog, QMessageBox, QLabel, QPushButton
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtCore import QTimer

class MainWindow(QMainWindow):
    def __init__(self):
        super(MainWindow, self).__init__()
        self.setWindowTitle("开关设备红外过热检测")
        self.setGeometry(100, 100, 800, 600)
        
        # 初始化YOLOv8模型
        self.model = YOLO('runs/train/switch_devices/weights/best.pt')
        
        # 设置类别名称
        self.class_names = [
            'core', 'connection_part', 'body', 'load_switch', 
            'arrester', 'current_transformer', 'voltage_transformer', 'plastic_enclosure_circuit_breaker'
        ]
        
        # 创建界面元素
        self.label_display = QLabel(self)
        self.label_display.setGeometry(10, 10, 780, 400)
        
        self.button_select_image = QPushButton("选择图片", self)
        self.button_select_image.setGeometry(10, 420, 150, 30)
        self.button_select_image.clicked.connect(self.select_image)
        
        self.button_select_video = QPushButton("选择视频", self)
        self.button_select_video.setGeometry(170, 420, 150, 30)
        self.button_select_video.clicked.connect(self.select_video)
        
        self.button_start_camera = QPushButton("开始摄像头", self)
        self.button_start_camera.setGeometry(330, 420, 150, 30)
        self.button_start_camera.clicked.connect(self.start_camera)
        
        self.button_stop_camera = QPushButton("停止摄像头", self)
        self.button_stop_camera.setGeometry(490, 420, 150, 30)
        self.button_stop_camera.clicked.connect(self.stop_camera)
        
        self.timer = QTimer()
        self.timer.timeout.connect(self.update_frame)
        
        self.cap = None
        self.results = []

    def select_image(self):
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(self, "选择图片", "", "图片 (*.jpg *.jpeg *.png);;所有文件 (*)", options=options)
        if file_path:
            self.process_image(file_path)

    def process_image(self, image_path):
        frame = cv2.imread(image_path)
        results = self.model(frame)
        annotated_frame = self.draw_annotations(frame, results)
        self.display_image(annotated_frame)
        self.results.append((image_path, annotated_frame))

    def select_video(self):
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(self, "选择视频", "", "视频 (*.mp4 *.avi);;所有文件 (*)", options=options)
        if file_path:
            self.process_video(file_path)

    def process_video(self, video_path):
        self.cap = cv2.VideoCapture(video_path)
        while self.cap.isOpened():
            ret, frame = self.cap.read()
            if not ret:
                break
            results = self.model(frame)
            annotated_frame = self.draw_annotations(frame, results)
            self.display_image(annotated_frame)
            self.results.append((video_path, annotated_frame))
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        self.cap.release()

    def start_camera(self):
        self.cap = cv2.VideoCapture(0)
        self.timer.start(30)

    def stop_camera(self):
        self.timer.stop()
        if self.cap is not None:
            self.cap.release()
            self.label_display.clear()

    def update_frame(self):
        ret, frame = self.cap.read()
        if not ret:
            return
        results = self.model(frame)
        annotated_frame = self.draw_annotations(frame, results)
        self.display_image(annotated_frame)
        self.results.append(('camera', annotated_frame))

    def draw_annotations(self, frame, results):
        for result in results:
            boxes = result.boxes.cpu().numpy()
            for box in boxes:
                r = box.xyxy[0].astype(int)
                cls = int(box.cls[0])
                conf = box.conf[0]
                
                label = f"{self.class_names[cls]} {conf:.2f}"
                color = (0, 255, 0)
                cv2.rectangle(frame, (r[0], r[1]), (r[2], r[3]), color, 2)
                cv2.putText(frame, label, (r[0], r[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
        return frame

    def display_image(self, frame):
        rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        h, w, ch = rgb_image.shape
        bytes_per_line = ch * w
        qt_image = QImage(rgb_image.data, w, h, bytes_per_line, QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(qt_image)
        self.label_display.setPixmap(pixmap.scaled(self.label_display.width(), self.label_display.height()))

if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec_())

运行步骤总结

  1. 克隆项目仓库(如果有的话)

    git clone https://github.com/yourusername/switch_device_overheat_detection.git
    cd switch_device_overheat_detection
    
  2. 安装依赖项

    pip install -r requirements.txt
    
  3. 转换数据集格式

    python convert_voc_to_yolo.py
    
  4. 训练模型

    python train.py
    
  5. 评估模型

    python evaluate.py
    
  6. 运行推理界面

    python infer.py
    

操作界面

  • 选择图片进行检测。
  • 选择视频进行检测。
  • 使用摄像头进行实时检测。
  • 结果展示

你可以通过以下方式查看演示视频:

  • 用上述步骤运行 infer.py 并按照界面上的按钮操作。

希望这些详细的信息和代码能够帮助你顺利实施和优化你的开关设备红外过热检测系统。如果有其他需求或问题,请随时告知!

详细解释

requirements.txt

列出项目所需的所有Python包及其版本。

data.yaml

配置数据集路径和类别信息,用于YOLOv8模型训练。

convert_voc_to_yolo.py

将VOC格式的数据集转换为YOLO格式。读取XML标注文件并将其转换为YOLO所需的TXT标签格式。同时,将数据集分为训练集和验证集。

train.py

加载预训练的YOLOv8模型并使用自定义数据集进行训练。训练完成后评估模型并保存最佳模型权重。

evaluate.py

加载训练好的YOLOv8模型并对验证集进行评估,打印评估结果。

infer.py

创建一个GUI应用程序,支持选择图片、视频或使用摄像头进行实时检测,并显示检测结果。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值