如何训练这套数据集呢?
开关设备红外过热图像数据集:总共约 5500 张图片,标注为 VOC(xml)格式,
共 8 类,
分别为:核心、连接部分、主体、负荷开关、避雷器、电流互感器、电压互感器、塑料外壳式断路器。
完整项目结构
switch_device_overheat_detection/
├── main.py
├── train.py
├── evaluate.py
├── infer.py
├── convert_voc_to_yolo.py
├── datasets/
│   ├── switch_devices/
│   │   ├── Annotations/
│   │   ├── ImageSets/
│   │   │   └── Main/
│   │   │       ├── train.txt
│   │   │       └── val.txt
│   │   └── JPEGImages/
├── best_switch_device.pt
├── requirements.txt
└── data.yaml
文件内容
requirements.txt
opencv-python
torch>=1.8
ultralytics
PyQt5
data.yaml
train: ./datasets/switch_devices_yolo/images/train
val: ./datasets/switch_devices_yolo/images/val
test: ./datasets/switch_devices_yolo/images/val  # 转换脚本未划分独立测试集,暂复用验证集
nc: 8
names: ['core', 'connection_part', 'body', 'load_switch', 'arrester', 'current_transformer', 'voltage_transformer', 'plastic_enclosure_circuit_breaker']
convert_voc_to_yolo.py
import os
import xml.etree.ElementTree as ET
import shutil
import cv2
def xml_to_yolo(xml_file, image_width, image_height):
yolo_labels = []
tree = ET.parse(xml_file)
root = tree.getroot()
for obj in root.findall('object'):
label = obj.find('name').text
bbox = obj.find('bndbox')
xmin = int(bbox.find('xmin').text)
ymin = int(bbox.find('ymin').text)
xmax = int(bbox.find('xmax').text)
ymax = int(bbox.find('ymax').text)
x_center = (xmin + xmax) / 2.0 / image_width
y_center = (ymin + ymax) / 2.0 / image_height
width = (xmax - xmin) / image_width
height = (ymax - ymin) / image_height
class_id = {
'core': 0,
'connection_part': 1,
'body': 2,
'load_switch': 3,
'arrester': 4,
'current_transformer': 5,
'voltage_transformer': 6,
'plastic_enclosure_circuit_breaker': 7
}[label]
yolo_labels.append(f"{class_id} {x_center} {y_center} {width} {height}")
return '\n'.join(yolo_labels)
def split_dataset(image_dir, annotations_dir, output_dir, train_ratio=0.8):
    """Split images into train/val lists written as VOC ImageSets files.

    Args:
        image_dir: directory containing the ``.jpg`` images.
        annotations_dir: unused here; kept for interface compatibility.
        output_dir: dataset root; ``ImageSets/Main/{train,val}.txt`` are
            written beneath it, one image basename (no extension) per line.
        train_ratio: fraction of images assigned to the train split.
    """
    # sorted() makes the split deterministic: os.listdir returns entries in
    # arbitrary, filesystem-dependent order, which previously produced a
    # different train/val split on every machine.
    images = sorted(
        f for f in os.listdir(image_dir) if f.lower().endswith('.jpg')
    )
    num_train = int(len(images) * train_ratio)
    # Ensure the target directory exists so the function works standalone too.
    sets_dir = os.path.join(output_dir, 'ImageSets/Main')
    os.makedirs(sets_dir, exist_ok=True)
    splits = {'train.txt': images[:num_train], 'val.txt': images[num_train:]}
    for filename, subset in splits.items():
        with open(os.path.join(sets_dir, filename), 'w') as f:
            f.write('\n'.join(os.path.splitext(img)[0] for img in subset))
def convert_dataset(voc_dir, yolo_dir):
    """Convert a VOC dataset into the YOLO directory layout.

    Produces under ``yolo_dir``::

        images/{train,val}/   copied .jpg files
        labels/{train,val}/   YOLO .txt labels with matching basenames
        ImageSets/Main/{train,val}.txt

    Labels are written into ``labels/train`` and ``labels/val`` mirroring the
    image split — the layout Ultralytics resolves by swapping ``images`` for
    ``labels`` in each image path.  The previous version dumped all labels
    into a flat ``labels/`` directory, which the trainer cannot pair with the
    split images.
    """
    annotations_dir = os.path.join(voc_dir, 'Annotations')
    images_dir = os.path.join(voc_dir, 'JPEGImages')
    for sub in ('images/train', 'images/val',
                'labels/train', 'labels/val', 'ImageSets/Main'):
        os.makedirs(os.path.join(yolo_dir, sub), exist_ok=True)
    split_dataset(images_dir, annotations_dir, yolo_dir)
    # Read the train split once into a set; the previous version re-opened
    # and re-scanned train.txt for every single image (O(n^2) file I/O).
    with open(os.path.join(yolo_dir, 'ImageSets/Main/train.txt')) as f:
        train_ids = {line.strip() for line in f if line.strip()}
    for filename in os.listdir(annotations_dir):
        if not filename.endswith('.xml'):
            continue
        stem = os.path.splitext(filename)[0]
        xml_file = os.path.join(annotations_dir, filename)
        image_path = os.path.join(images_dir, stem + '.jpg')
        width, height = _image_size(xml_file, image_path)
        split = 'train' if stem in train_ids else 'val'
        yolo_label = xml_to_yolo(xml_file, width, height)
        with open(os.path.join(yolo_dir, 'labels', split, stem + '.txt'), 'w') as f:
            f.write(yolo_label)
        shutil.copy(image_path, os.path.join(yolo_dir, 'images', split))


def _image_size(xml_file, image_path):
    """Return (width, height) for an annotated image.

    Prefers the VOC ``<size>`` element (no image decode needed); falls back
    to OpenCV only when the annotation lacks usable size info.  Raises a
    clear error for a missing or undecodable image — ``cv2.imread`` returns
    None silently, which previously crashed later with an opaque message.
    """
    size = ET.parse(xml_file).getroot().find('size')
    if size is not None:
        width = int(float(size.find('width').text))
        height = int(float(size.find('height').text))
        if width > 0 and height > 0:
            return width, height
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found for annotation: {image_path}")
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not decode image: {image_path}")
    height, width = image.shape[:2]
    return width, height
# Example usage — guarded so importing this module does not trigger a full
# dataset conversion as an import side effect.
if __name__ == '__main__':
    convert_dataset('./datasets/switch_devices', './datasets/switch_devices_yolo')
train.py
import torch
from ultralytics import YOLO


def main():
    """Train YOLOv8n on the switch-device infrared overheating dataset."""
    # Fixed seed so runs are reproducible.
    torch.manual_seed(42)
    # Dataset paths and class names are declared in data.yaml.
    dataset_config = 'data.yaml'
    # COCO-pretrained YOLOv8n weights; downloaded on first use.
    model = YOLO('yolov8n.pt')
    model.train(
        data=dataset_config,
        epochs=50,
        imgsz=640,
        batch=16,
        name='switch_devices',
        project='runs/train',
    )
    # Validate on the val split declared in data.yaml.
    metrics = model.val()
    # NOTE(review): Ultralytics appends a numeric suffix (switch_devices2, …)
    # when the run directory already exists, so this path is only correct for
    # the first run — confirm before relying on it.
    best_model_weights = 'runs/train/switch_devices/weights/best.pt'
    print(f"最佳模型权重已保存到 {best_model_weights}")


# Guarded so importing this module does not start a training run.
if __name__ == '__main__':
    main()
evaluate.py
from ultralytics import YOLO


def main():
    """Evaluate the best trained checkpoint on the validation split."""
    # Best weights produced by train.py.
    model = YOLO('runs/train/switch_devices/weights/best.pt')
    # Uses the val split from the data.yaml stored in the checkpoint's run.
    metrics = model.val()
    print(metrics)


# Guarded so importing this module does not trigger an evaluation run.
if __name__ == '__main__':
    main()
infer.py
import sys
import cv2
import numpy as np
from ultralytics import YOLO
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog, QMessageBox, QLabel, QPushButton
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtCore import QTimer
class MainWindow(QMainWindow):
    """Main window for the switch-device infrared overheating detector.

    Offers three input sources — a still image, a video file, or the default
    camera — and renders YOLOv8 detections onto each displayed frame.
    """

    # Upper bound on kept (source, annotated_frame) pairs.  The original
    # appended every frame forever, which exhausts memory on long videos.
    MAX_RESULTS = 100

    def __init__(self):
        super(MainWindow, self).__init__()
        self.setWindowTitle("开关设备红外过热检测")
        self.setGeometry(100, 100, 800, 600)
        # Best weights produced by train.py.
        self.model = YOLO('runs/train/switch_devices/weights/best.pt')
        # Index order must match the class ids in data.yaml.
        self.class_names = [
            'core', 'connection_part', 'body', 'load_switch',
            'arrester', 'current_transformer', 'voltage_transformer', 'plastic_enclosure_circuit_breaker'
        ]
        # Preview area for annotated frames.
        self.label_display = QLabel(self)
        self.label_display.setGeometry(10, 10, 780, 400)
        self.button_select_image = QPushButton("选择图片", self)
        self.button_select_image.setGeometry(10, 420, 150, 30)
        self.button_select_image.clicked.connect(self.select_image)
        self.button_select_video = QPushButton("选择视频", self)
        self.button_select_video.setGeometry(170, 420, 150, 30)
        self.button_select_video.clicked.connect(self.select_video)
        self.button_start_camera = QPushButton("开始摄像头", self)
        self.button_start_camera.setGeometry(330, 420, 150, 30)
        self.button_start_camera.clicked.connect(self.start_camera)
        self.button_stop_camera = QPushButton("停止摄像头", self)
        self.button_stop_camera.setGeometry(490, 420, 150, 30)
        self.button_stop_camera.clicked.connect(self.stop_camera)
        # One timer drives both video playback and the live camera, so the
        # Qt event loop is never blocked (the original played video files in
        # a busy while-loop, freezing the whole GUI).
        self.timer = QTimer()
        self.timer.timeout.connect(self.update_frame)
        self.cap = None
        self.source_name = None
        self.results = []

    def select_image(self):
        """Ask the user for an image file and run detection on it."""
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(self, "选择图片", "", "图片 (*.jpg *.jpeg *.png);;所有文件 (*)", options=options)
        if file_path:
            self.process_image(file_path)

    def process_image(self, image_path):
        """Run the model on a single image and display the annotated result."""
        frame = cv2.imread(image_path)
        if frame is None:
            # cv2.imread returns None instead of raising on unreadable files.
            QMessageBox.warning(self, "错误", f"无法读取图片: {image_path}")
            return
        results = self.model(frame)
        annotated_frame = self.draw_annotations(frame, results)
        self.display_image(annotated_frame)
        self._record(image_path, annotated_frame)

    def select_video(self):
        """Ask the user for a video file and start playing it with detection."""
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(self, "选择视频", "", "视频 (*.mp4 *.avi);;所有文件 (*)", options=options)
        if file_path:
            self.process_video(file_path)

    def process_video(self, video_path):
        """Play a video through the timer-driven pipeline (non-blocking)."""
        self.stop_camera()  # release any previously active source first
        self.cap = cv2.VideoCapture(video_path)
        if not self.cap.isOpened():
            QMessageBox.warning(self, "错误", f"无法打开视频: {video_path}")
            self.cap = None
            return
        self.source_name = video_path
        self.timer.start(30)

    def start_camera(self):
        """Start live detection from the default camera (index 0)."""
        self.stop_camera()
        self.cap = cv2.VideoCapture(0)
        if not self.cap.isOpened():
            QMessageBox.warning(self, "错误", "无法打开摄像头")
            self.cap = None
            return
        self.source_name = 'camera'
        self.timer.start(30)

    def stop_camera(self):
        """Stop the timer and release whichever capture source is active."""
        self.timer.stop()
        if self.cap is not None:
            self.cap.release()
            self.cap = None
        self.label_display.clear()

    def update_frame(self):
        """Timer callback: grab, detect, annotate and display one frame."""
        if self.cap is None:
            # Timer may fire after the source was stopped.
            return
        ret, frame = self.cap.read()
        if not ret:
            # End of video or camera failure: shut down cleanly instead of
            # spinning forever on a dead capture.
            self.stop_camera()
            return
        results = self.model(frame)
        annotated_frame = self.draw_annotations(frame, results)
        self.display_image(annotated_frame)
        self._record(self.source_name, annotated_frame)

    def _record(self, source, frame):
        """Remember a (source, annotated frame) pair, bounded by MAX_RESULTS."""
        self.results.append((source, frame))
        if len(self.results) > self.MAX_RESULTS:
            del self.results[:-self.MAX_RESULTS]

    def draw_annotations(self, frame, results):
        """Draw detection boxes and 'label conf' captions onto frame.

        Mutates and returns the same frame (BGR, as OpenCV delivers it).
        """
        for result in results:
            boxes = result.boxes.cpu().numpy()
            for box in boxes:
                r = box.xyxy[0].astype(int)
                cls = int(box.cls[0])
                conf = box.conf[0]
                label = f"{self.class_names[cls]} {conf:.2f}"
                color = (0, 255, 0)  # green in BGR
                cv2.rectangle(frame, (r[0], r[1]), (r[2], r[3]), color, 2)
                cv2.putText(frame, label, (r[0], r[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
        return frame

    def display_image(self, frame):
        """Show a BGR frame scaled into the preview QLabel."""
        rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        h, w, ch = rgb_image.shape
        bytes_per_line = ch * w
        # .copy(): QImage does not take ownership of the numpy buffer, so
        # without the copy the pixmap can end up referencing freed memory
        # once rgb_image is garbage-collected.
        qt_image = QImage(rgb_image.data, w, h, bytes_per_line, QImage.Format_RGB888).copy()
        pixmap = QPixmap.fromImage(qt_image)
        self.label_display.setPixmap(pixmap.scaled(self.label_display.width(), self.label_display.height()))
if __name__ == "__main__":
    # Application entry point: build the Qt application, show the main
    # window, and hand control to the event loop until the user quits.
    app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    sys.exit(app.exec_())
运行步骤总结
-
克隆项目仓库(如果有的话):
git clone https://github.com/yourusername/switch_device_overheat_detection.git
cd switch_device_overheat_detection
-
安装依赖项:
pip install -r requirements.txt
-
转换数据集格式:
python convert_voc_to_yolo.py
-
训练模型:
python train.py
-
评估模型:
python evaluate.py
-
运行推理界面:
python infer.py
操作界面
- 选择图片进行检测。
- 选择视频进行检测。
- 使用摄像头进行实时检测。
- 结果展示
你可以通过以下方式查看演示视频:
- 用上述步骤运行
infer.py
并按照界面上的按钮操作。
希望这些详细的信息和代码能够帮助你顺利实施和优化你的开关设备红外过热检测系统。如果有其他需求或问题,请随时告知!
详细解释
requirements.txt
列出项目所需的所有Python包及其版本。
data.yaml
配置数据集路径和类别信息,用于YOLOv8模型训练。
convert_voc_to_yolo.py
将VOC格式的数据集转换为YOLO格式。读取XML标注文件并将其转换为YOLO所需的TXT标签格式。同时,将数据集分为训练集和验证集。
train.py
加载预训练的YOLOv8模型并使用自定义数据集进行训练。训练完成后评估模型并保存最佳模型权重。
evaluate.py
加载训练好的YOLOv8模型并对验证集进行评估,打印评估结果。
infer.py
创建一个GUI应用程序,支持选择图片、视频或使用摄像头进行实时检测,并显示检测结果。