How to Train: Electrical Equipment, Insulator String Detection, Insulator String Flashover/Pollution Detection, and Insulator String Damage Detection; 1,600 images total; YOLO, VOC, and JSON annotation formats supported; 2 GB of data
To implement insulator string detection, flashover/pollution detection, and damage detection, this walkthrough fine-tunes the Faster R-CNN detector bundled with torchvision (fasterrcnn_resnet50_fpn_v2); the dataset's YOLO-format labels are converted to the absolute coordinates the model expects while loading. The sections below give detailed steps and code, including the dataset definition, configuration file, and training script.
Directory structure
First, make sure your project directory is structured as follows:
/insulator_detection_project
    /datasets
        /train
            /images
                *.jpg
            /labels
                *.txt
        /valid
            /images
                *.jpg
            /labels
                *.txt
    /scripts
        train.py
        datasets.py
        config.yaml
    requirements.txt
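Since the annotations also ship in VOC XML (and labelme JSON), you may first need to convert them into the YOLO .txt files that datasets.py reads. Below is a minimal sketch for the VOC case; it assumes the standard PASCAL VOC tag names, and the hypothetical voc_to_yolo helper must use the same class order as config.yaml:

# voc_to_yolo.py: minimal sketch converting one PASCAL VOC XML file to a YOLO .txt label.
# Assumes standard VOC tags (<size>, <object>, <bndbox>); the class order must match config.yaml.
import xml.etree.ElementTree as ET

CLASSES = ['flashover_pollution', 'damage', 'insulator']  # same order as config.yaml

def voc_to_yolo(xml_path, txt_path):
    root = ET.parse(xml_path).getroot()
    img_w = float(root.find('size/width').text)
    img_h = float(root.find('size/height').text)
    lines = []
    for obj in root.findall('object'):
        cls = obj.find('name').text
        if cls not in CLASSES:
            continue  # skip classes we do not train on
        box = obj.find('bndbox')
        x1, y1 = float(box.find('xmin').text), float(box.find('ymin').text)
        x2, y2 = float(box.find('xmax').text), float(box.find('ymax').text)
        # YOLO format: class_id plus normalized box center and size
        cx, cy = (x1 + x2) / 2 / img_w, (y1 + y2) / 2 / img_h
        bw, bh = (x2 - x1) / img_w, (y2 - y1) / img_h
        lines.append(f"{CLASSES.index(cls)} {cx:.6f} {cy:.6f} {bw:.6f} {bh:.6f}")
    with open(txt_path, 'w') as f:
        f.write('\n'.join(lines))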
config.yaml
The configuration file contains the training parameters and data paths.
# config.yaml
train: ../datasets/train   # InsulatorDataset appends /images and /labels itself
val: ../datasets/valid
nc: 3
names: ['flashover_pollution', 'damage', 'insulator']
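Each image in /images has a matching .txt file in /labels using the YOLO convention: one object per line, with a class id followed by the box center and size, all normalized to [0, 1]. A hypothetical label file (the numbers here are purely illustrative) looks like:

# <class_id> <x_center> <y_center> <width> <height>, normalized to [0, 1]
2 0.513 0.472 0.118 0.640
0 0.250 0.301 0.095 0.210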
requirements.txt
Lists the Python packages to install (pip install -r requirements.txt). Note that fasterrcnn_resnet50_fpn_v2 requires torchvision 0.13 or newer.
torch>=1.12
torchvision>=0.13
numpy
pyyaml
pycocotools
opencv-python
matplotlib
albumentations
labelme2coco
datasets.py
Defines the dataset class that loads the insulator images with their YOLO-format labels and applies data augmentation. Because Faster R-CNN expects absolute [x1, y1, x2, y2] boxes, the normalized YOLO boxes are converted after augmentation. A smoke test follows the code.
from pathlib import Path

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset
class InsulatorDataset(Dataset):
    """Loads images and YOLO-format labels for insulator detection."""

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.img_files = sorted((self.root_dir / 'images').glob('*.jpg'))
        self.label_files = [Path(str(p).replace('images', 'labels').replace('.jpg', '.txt'))
                            for p in self.img_files]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        image = Image.open(self.img_files[idx]).convert("RGB")

        # Parse the YOLO label file: one "class_id cx cy w h" line per object,
        # with all coordinates normalized to [0, 1].
        boxes, labels = [], []
        with open(self.label_files[idx], 'r') as file:
            for line in file:
                class_id, x_center, y_center, width, height = map(float, line.strip().split())
                boxes.append([x_center, y_center, width, height])
                labels.append(int(class_id))

        # The transform is expected to return a CHW tensor (see ToTensorV2 below).
        if self.transform:
            transformed = self.transform(image=np.array(image), bboxes=boxes, class_labels=labels)
            image = transformed['image']
            boxes = transformed['bboxes']
            labels = list(transformed['class_labels'])

        # Faster R-CNN expects absolute [x1, y1, x2, y2] boxes, so convert the
        # normalized YOLO boxes using the post-transform image size.
        _, img_h, img_w = image.shape
        xyxy = [[(cx - w / 2) * img_w, (cy - h / 2) * img_h,
                 (cx + w / 2) * img_w, (cy + h / 2) * img_h]
                for cx, cy, w, h in boxes]

        target = {
            'boxes': torch.tensor(xyxy, dtype=torch.float32).reshape(-1, 4),
            # torchvision detection models reserve label 0 for the background class
            'labels': torch.tensor(labels, dtype=torch.int64) + 1,
        }
        return image, target
# Data augmentation pipelines. Note: torchvision detection models normalize
# images internally, so we only rescale pixels to [0, 1] here instead of
# applying ImageNet statistics a second time.
data_transforms = {
    'train': A.Compose([
        A.Resize(width=640, height=640),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=180, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
        A.ToFloat(max_value=255.0),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'])),
    'test': A.Compose([
        A.Resize(width=640, height=640),
        A.ToFloat(max_value=255.0),
        ToTensorV2(),
    ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'])),
}
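As a quick sanity check before training, a minimal sketch like the following (assuming it is run from the scripts/ directory so the relative paths resolve) loads one sample and prints its shapes:

# Smoke test: run from the scripts/ directory, e.g. python datasets.py
if __name__ == "__main__":
    dataset = InsulatorDataset(root_dir='../datasets/train', transform=data_transforms['train'])
    image, target = dataset[0]
    print(len(dataset), "samples")
    print("image:", tuple(image.shape))  # e.g. (3, 640, 640)
    print("boxes:", target['boxes'].shape, "labels:", target['labels'].tolist())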
train.py
The training script below fine-tunes a torchvision Faster R-CNN model on the insulator dataset.
import datetime
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist
import torch.optim as optim
import yaml
from torch.utils.data import DataLoader
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from datasets import InsulatorDataset, data_transforms

with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)
def collate_fn(batch):
    # Detection targets are dicts of varying size, so keep them in a list;
    # the images all share one size after Resize, so they can be stacked.
    images = torch.stack([item[0] for item in batch])
    targets = [item[1] for item in batch]
    return images, targets
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    header = f"Epoch: [{epoch}]"
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # In training mode, torchvision detection models return a dict of losses.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        metric_logger.update(loss=losses.item(), **loss_dict)
class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'MetricLogger' object has no attribute '{attr}'")

    def __str__(self):
        # Render every tracked meter, e.g. "loss: 0.4321 (0.4567)"
        return self.delimiter.join(f"{name}: {meter}" for name, meter in self.meters.items())
    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string, meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string, meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True
def main():
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    dataset_train = InsulatorDataset(root_dir=config['train'], transform=data_transforms['train'])
    dataset_val = InsulatorDataset(root_dir=config['val'], transform=data_transforms['test'])
    data_loader_train = DataLoader(dataset_train, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
    data_loader_val = DataLoader(dataset_val, batch_size=4, shuffle=False, num_workers=4, collate_fn=collate_fn)

    # Load a COCO-pretrained model and replace the box predictor head so it
    # outputs our classes (plus background, which torchvision reserves as class 0).
    model = fasterrcnn_resnet50_fpn_v2(weights='DEFAULT')
    num_classes = config['nc'] + 1  # background + number of classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    for epoch in range(10):  # number of epochs
        train_one_epoch(model, optimizer, data_loader_train, device=device, epoch=epoch, print_freq=10)
        # save a checkpoint every epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'model_epoch_{epoch}.pth')

if __name__ == "__main__":
    main()
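After training, you can load a saved checkpoint and run detection on a single image. The sketch below is illustrative only: the checkpoint name model_epoch_9.pth, the example image path, and the 0.5 score threshold are assumptions to adjust for your setup.

# inference.py: minimal sketch, not part of the training pipeline.
# Checkpoint name, image path, and score threshold are illustrative assumptions.
import torch
from PIL import Image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.transforms import functional as F

names = ['flashover_pollution', 'damage', 'insulator']  # same order as config.yaml

# Rebuild the architecture with our class count (+1 for background), then load the weights.
model = fasterrcnn_resnet50_fpn_v2(weights=None, num_classes=len(names) + 1)
ckpt = torch.load('model_epoch_9.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

image = Image.open('../datasets/valid/images/example.jpg').convert('RGB')
tensor = F.to_tensor(image)  # CHW float tensor in [0, 1]; the model normalizes internally

with torch.no_grad():
    pred = model([tensor])[0]

for box, label, score in zip(pred['boxes'], pred['labels'], pred['scores']):
    if score >= 0.5:
        # label 0 is background, so shift by one to index the class names
        print(names[int(label) - 1], round(score.item(), 3), [round(v, 1) for v in box.tolist()])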
Summary
The code above covers every step from data preparation to model training. Adjust the parameters in the configuration file as needed and run the training script (python train.py from inside the scripts/ directory) to start fine-tuning the Faster R-CNN model. Make sure your dataset directory matches the layout above and that all file paths are correct.