基于 电力系统红外分割数据集_变电站分割数据集
从配置数据集、加载预训练模型、训练模型、评估模型到加载训练好的模型进行推理
文章目录
以下文字及代码仅供参考。
文章目录
1. 数据准备
1.1 数据集结构
假设的数据集目录结构如下:
dataset/
├── images/
│ ├── train/
│ │ ├── img1.jpg
│ │ └── ...
│ └── val/
│ ├── img1.jpg
│ └── ...
└── annotations/
├── train/
│ ├── img1.json
│ └── ...
└── val/
├── img1.json
└── ...
每个 JSON 文件包含图像的标注信息(如多边形分割掩码)。类别标签为:
- 避雷器 (Lightning Arrester)
- 绝缘子 (Insulator)
- 电流互感器 (Current Transformer)
- 套管 (Bushing)
- 电压互感器 (Voltage Transformer)
2. 安装依赖
确保安装了必要的库:
pip install torch torchvision opencv-python matplotlib pycocotools albumentations
3. 配置数据集
我们将使用 PyTorch 和 COCO 格式的标注文件。需要将 JSON 转换为 COCO 格式,或者直接使用自定义的数据加载器。
3.1 自定义数据加载器
import os
import json
import cv2
import numpy as np
from torch.utils.data import Dataset
class InfraredSegmentationDataset(Dataset):
    """Infrared substation segmentation dataset.

    Pairs each image in ``image_dir`` with a JSON annotation of the same
    basename in ``annotation_dir``; the annotation's ``objects`` list holds
    ``category_id`` + ``polygon`` entries that are rasterized into a single
    class-index mask.

    Args:
        image_dir: Directory containing the input images.
        annotation_dir: Directory containing per-image ``.json`` annotations.
        transform: Optional albumentations-style callable taking
            ``image=...`` / ``mask=...`` keyword arguments.
    """

    # Only treat these as images; annotation JSONs or stray files in the
    # directory must not become dataset entries.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(self, image_dir, annotation_dir, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.transform = transform
        # Sort for a deterministic ordering — os.listdir order is
        # filesystem-dependent.
        self.image_files = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        # Index 0 is background; indices 1..5 are the five equipment classes.
        self.class_names = ["background", "Lightning Arrester", "Insulator",
                            "Current Transformer", "Bushing",
                            "Voltage Transformer"]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_name = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_name)
        annotation_path = os.path.join(
            self.annotation_dir, os.path.splitext(image_name)[0] + ".json")
        # Read image; cv2.imread silently returns None on failure, so fail
        # loudly instead of crashing later in cvtColor.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Cannot read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Read annotation and rasterize each polygon with its class id.
        with open(annotation_path, "r", encoding="utf-8") as f:
            annotation = json.load(f)
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for obj in annotation["objects"]:
            polygon = np.array(obj["polygon"], dtype=np.int32)
            cv2.fillPoly(mask, [polygon], obj["category_id"])
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]
        return image, mask
4. 模型选择与训练
我们选择 DeepLabV3+ 模型进行分割任务。
4.1 加载预训练模型
import torch
import torch.nn as nn
import torchvision.models.segmentation as segmentation_models
def load_model(num_classes):
    """Build a DeepLabV3-ResNet50 pretrained on COCO, re-headed for our classes.

    Args:
        num_classes: Number of output classes (including background).

    Returns:
        A ``torchvision`` DeepLabV3 model whose main (and auxiliary)
        classifier heads emit ``num_classes`` channels.
    """
    # ``pretrained=True`` is deprecated (removed in torchvision >= 0.15);
    # ``weights="DEFAULT"`` is the supported way to get COCO weights.
    model = segmentation_models.deeplabv3_resnet50(weights="DEFAULT")
    model.classifier[-1] = nn.Conv2d(256, num_classes, kernel_size=1)
    # The pretrained model also ships an auxiliary classifier head (21 COCO
    # classes); re-head it too so every output matches num_classes.
    if model.aux_classifier is not None:
        model.aux_classifier[-1] = nn.Conv2d(256, num_classes, kernel_size=1)
    return model
4.2 训练代码
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
def train_model(model, train_loader, val_loader, num_epochs=20, lr=0.001, device='cuda'):
    """Train a segmentation model with Adam + cross-entropy.

    Validates after every epoch and checkpoints the weights with the best
    validation loss to ``best_model.pth``.

    Args:
        model: Segmentation model whose forward returns a dict with key ``'out'``.
        train_loader: DataLoader yielding ``(images, masks)`` batches.
        val_loader: DataLoader used by :func:`evaluate_model`.
        num_epochs: Number of training epochs.
        lr: Adam learning rate.
        device: Target device; falls back to CPU when CUDA is unavailable.
    """
    # Don't crash on CUDA-less machines when the default device is requested.
    if device == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    best_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for images, masks in tqdm(train_loader):
            images = images.to(device)
            # CrossEntropyLoss requires integer (long) class-index targets;
            # masks arrive as uint8 from the dataset.
            masks = masks.to(device).long()
            optimizer.zero_grad()
            outputs = model(images)['out']
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss/len(train_loader):.4f}")
        # 验证模型
        val_loss = evaluate_model(model, val_loader, criterion, device)
        print(f"Validation Loss: {val_loss:.4f}")
        # 保存最佳模型
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), "best_model.pth")
            print("Best model saved!")
def evaluate_model(model, val_loader, criterion, device):
    """Return the mean per-batch validation loss (no gradients).

    Args:
        model: Model whose forward returns a dict with key ``'out'``.
        val_loader: Iterable of ``(images, masks)`` batches.
        criterion: Loss taking ``(logits, target_indices)``.
        device: Device to run evaluation on.

    Returns:
        Average loss over the loader's batches (0.0 for an empty loader).
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            # CrossEntropyLoss requires long-dtype class-index targets.
            masks = masks.to(device).long()
            outputs = model(images)['out']
            loss = criterion(outputs, masks)
            total_loss += loss.item()
    # Guard against an empty loader to avoid ZeroDivisionError.
    return total_loss / max(len(val_loader), 1)
5. 推理与可视化
5.1 加载训练好的模型
def load_trained_model(num_classes, model_path="best_model.pth"):
    """Rebuild the architecture and load trained weights for inference.

    Args:
        num_classes: Must match the class count used during training.
        model_path: Path to the saved ``state_dict`` checkpoint.

    Returns:
        The model in eval mode with the checkpoint weights loaded.
    """
    model = load_model(num_classes)
    # map_location lets a GPU-trained checkpoint load on CPU-only machines;
    # the caller moves the model to the desired device afterwards.
    state = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    return model
5.2 推理与显示结果
import matplotlib.pyplot as plt
def visualize_segmentation(image, mask, predicted_mask):
    """Display the original image, optional ground truth, and prediction.

    Args:
        image: RGB image array.
        mask: Ground-truth class-index mask, or ``None`` when unavailable
            (e.g. during pure inference — :func:`inference` passes ``None``).
        predicted_mask: Predicted class-index mask.
    """
    # Drop the ground-truth panel when no mask is provided; the original
    # code crashed on imshow(None).
    n_panels = 2 if mask is None else 3
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))
    axes[0].imshow(image)
    axes[0].set_title("Original Image")
    axes[0].axis("off")
    if mask is not None:
        axes[1].imshow(mask, cmap="jet")
        axes[1].set_title("Ground Truth Mask")
        axes[1].axis("off")
    axes[-1].imshow(predicted_mask, cmap="jet")
    axes[-1].set_title("Predicted Mask")
    axes[-1].axis("off")
    plt.show()
def inference(model, image_path, transform=None, device='cuda'):
    """Segment a single image and display the prediction.

    Args:
        model: Trained segmentation model (forward returns dict with ``'out'``).
        image_path: Path to the image to segment.
        transform: Optional albumentations transform applied before inference.
        device: Target device; falls back to CPU when CUDA is unavailable.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Cannot read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Keep the un-normalized image for visualization; `image` below may be
    # normalized by the transform and would render incorrectly.
    display_image = image
    if transform:
        augmented = transform(image=image)
        image = augmented['image']
    # Without ToTensorV2 the transform yields an HWC numpy array; convert to
    # a CHW float tensor (torch.unsqueeze on a numpy array would fail).
    if isinstance(image, np.ndarray):
        tensor = torch.from_numpy(image.transpose(2, 0, 1)).float()
    else:
        tensor = image.float()
    image_tensor = tensor.unsqueeze(0).to(device)
    model.to(device)
    model.eval()
    with torch.no_grad():
        output = model(image_tensor)['out'][0]
    predicted_mask = output.argmax(dim=0).cpu().numpy()
    visualize_segmentation(display_image, None, predicted_mask)
6. 主程序
from albumentations import Compose, Resize, Normalize
if __name__ == "__main__":
    # ToTensorV2 converts the augmented HWC numpy image into a CHW tensor
    # (and the mask into a tensor) so DataLoader batches are model-ready.
    from albumentations.pytorch import ToTensorV2

    # 数据增强: resize + ImageNet normalization + tensor conversion
    transform = Compose([
        Resize(512, 512),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    # 数据集与数据加载器
    train_dataset = InfraredSegmentationDataset(
        image_dir="dataset/images/train",
        annotation_dir="dataset/annotations/train",
        transform=transform,
    )
    val_dataset = InfraredSegmentationDataset(
        image_dir="dataset/images/val",
        annotation_dir="dataset/annotations/val",
        transform=transform,
    )
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
    # Derive the class count from the dataset instead of hard-coding it
    # (5 equipment classes + background = 6).
    num_classes = len(train_dataset.class_names)
    # 加载模型并训练
    model = load_model(num_classes=num_classes)
    train_model(model, train_loader, val_loader, num_epochs=20)
    # 推理
    trained_model = load_trained_model(num_classes=num_classes)
    inference(trained_model, "dataset/images/test/sample.jpg", transform=transform)
7. 功能总结
- 数据预处理:将红外分割数据集转换为 COCO 格式,或使用自定义数据加载器。
- 模型训练:使用 DeepLabV3+ 进行训练,并保存最佳模型。
- 推理与可视化:加载训练好的模型,对测试图像进行分割并显示结果。