1. 数据处理
先处理数据。布匹瑕疵数据集共有 14 个类别(编号 0–13),其中第 0 类为无瑕疵。
本次训练只使用塞网瑕疵数据,即第 2 类。
根据 train.json 文件,从 sample 文件夹的所有图片中挑出塞网瑕疵图;同时只为塞网瑕疵生成 YOLO 训练用的 txt 标签文件。
1. YOLO 用的 txt 只需要类别 2。先处理 json 文件中的标注,即形如 {"bbox": [], "id": 1930, "image_id": 1883, "area": -1, "segmentation": [], "iscrowd": 0, "category_id": 0} 的条目,只保留 "category_id": 2 的标注,代码如下:
import json
# Keep only the annotations of one category (2 = plugged mesh) from train.json and
# save the result to a new file. Every other top-level key of the JSON (e.g. the
# image list) is written back unchanged; only 'annotations' is replaced.
TARGET_CATEGORY_ID = 2

# Load the JSON data from the file.
# Explicit UTF-8: without it, open() falls back to the platform locale encoding
# (e.g. GBK on Chinese Windows), which can fail on or mis-decode a UTF-8 JSON file.
file_path = '/mnt/data/train.json'
with open(file_path, 'r', encoding='utf-8') as file:
    data = json.load(file)

# Drop every annotation whose category_id is not the target category.
filtered_annotations = [obj for obj in data['annotations'] if obj['category_id'] == TARGET_CATEGORY_ID]
data['annotations'] = filtered_annotations

# Save the modified data to a new file (the original train.json is left untouched).
filtered_file_path = '/mnt/data/train_filtered.json'
with open(filtered_file_path, 'w', encoding='utf-8') as file:
    json.dump(data, file, indent=4)
2. sample 文件夹中包含所有类型的瑕疵图,但这里只需要第 2 类瑕疵图。下面的代码会新建一个文件夹,只把含第 2 类瑕疵的图片复制进去:
import json
import os
import shutil
# Copy every image that has at least one category-2 annotation from
# source_folder into destination_folder, reporting each copy or miss.
json_file_path = 'train.json'            # path to the annotation JSON (replace with yours)
source_folder = 'sample'                 # folder holding all original images
destination_folder = 'sampleCategory2'   # output folder for the category-2 subset

# Create the output folder if it is not there yet.
if not os.path.exists(destination_folder):
    os.makedirs(destination_folder)

with open(json_file_path, 'r', encoding='utf-8') as json_file:
    data = json.load(json_file)

# Ids of the images that carry a category-2 annotation (a set for O(1) lookups).
category_2_ids = set()
for annotation in data['annotations']:
    if annotation['category_id'] == 2:
        category_2_ids.add(annotation['image_id'])

# File names of those images, in the order they appear in data['images'].
filtered_image_filenames = []
for image_info in data['images']:
    if image_info['id'] in category_2_ids:
        filtered_image_filenames.append(image_info['file_name'])

for filename in filtered_image_filenames:
    source_path = os.path.join(source_folder, filename)
    # Guard clause: report images listed in the JSON but absent on disk.
    if not os.path.exists(source_path):
        print(f"未找到图片 {filename}, 请检查源文件夹。")
        continue
    shutil.copy(source_path, os.path.join(destination_folder, filename))
    print(f"图片 {filename} 已被复制到 {destination_folder}")
3. 根据json文件,生成txt文件给yolo用
# -*- coding: gbk -*-
import os
import json
# Convert the annotations in train.json into YOLO label files: one
# "<class> <x_center> <y_center> <width> <height>" line per box, with coordinates
# normalised by the image size, and one .txt file per image, written to a
# 'label' folder next to this script.

# Load the JSON file (explicit UTF-8 rather than the platform locale encoding).
with open('train.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

annotations = data['annotations']
# Look images up by id in O(1) instead of scanning data['images'] per annotation.
images_by_id = {image['id']: image for image in data['images']}

# All images used here are 400x400 pixels; boxes are normalised by this size.
width = 400
height = 400

# Collect the label lines per output file first. An image can have several
# annotations; opening its label file with 'w' once per annotation would keep
# only the last box.
labels_by_file = {}
for annotation in annotations:
    image_id = annotation['image_id']
    image = images_by_id.get(image_id)
    if image is None:
        print(f"Image id not found in 'images': {image_id}")
        continue
    # Same base name as the image, with a .txt extension (works for any image extension).
    label_name = os.path.splitext(image['file_name'])[0] + '.txt'
    category_id = annotation['category_id']
    bbox = annotation['bbox']
    if len(bbox) == 0:
        # Empty bbox means no defect in the image: still create an (empty) label file.
        labels_by_file.setdefault(label_name, [])
    elif len(bbox) == 4:
        # bbox is [left, top, box width, box height] in pixels.
        left, top, box_w, box_h = bbox
        x_center = (left + box_w / 2) / width
        y_center = (top + box_h / 2) / height
        box_width = box_w / width
        box_height = box_h / height
        # The category id is written directly as the YOLO class index.
        labels_by_file.setdefault(label_name, []).append(
            f"{category_id} {x_center} {y_center} {box_width} {box_height}"
        )
    else:
        # Otherwise, skip this annotation.
        print(f"Invalid bbox format: {bbox}")

# Write one label file per image into the 'label' folder next to this script.
base_dir = os.path.dirname(os.path.abspath(__file__))
label_folder = os.path.join(base_dir, 'label')
os.makedirs(label_folder, exist_ok=True)
for label_name, lines in labels_by_file.items():
    with open(os.path.join(label_folder, label_name), 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in lines)
2. yolov8环境配置
在mistgpu上跑
先用conda创建虚拟环境,避免包与包之间的冲突。环境命名为yolov8,python版本指定为3.9。
# Create an isolated conda environment named "yolov8" with Python 3.9
# so that its packages do not conflict with other projects.
conda create -n yolov8 python=3.9
# Switch the current shell into that environment.
conda activate yolov8
# Install the Ultralytics package, which provides the YOLO class used by the scripts below.
pip install ultralytics
然后新建数据集配置文件,路径为 ./ultralytics/cfg/datasets/mydata.yaml,内容如下(注意:后面训练代码里写的是 data='mydata224.yaml',文件名要与之对应):
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Fabric-defect dataset config (adapted from the Ultralytics coco128.yaml template)
# Example usage: yolo train data=mydata.yaml
# Expected layout (YOLO label .txt files mirror the images tree):
# newDataset
# ├── images
# │   ├── train
# │   └── val
# └── labels
#     ├── train
#     └── val
# NOTE(review): a relative 'path' may be resolved against Ultralytics' datasets_dir
# setting rather than this file's folder — check with `yolo settings` or use an absolute path.

# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ../newDataset # dataset root dir
train: images/train # train images (relative to 'path')
val: images/val # val images (relative to 'path')
test: # test images (optional)

# Classes: the key is the category_id written into the label files (0 = no defect).
# Entries must be indented under 'names', otherwise 'names' parses as null.
names:
  0: Flawless
  1: Sewing heads
  2: Plugged mesh
  3: Anterior and backward chromatic aberration
  4: False color
  5: Depulping
  6: White strips
  7: Stains
  8: Decolorization
  9: Dirt
  10: Blurred prints
  11: Holes
  12: Water stains
  13: Color crossing
3. 开始训练
新建一个 demotrain.py 文件来试跑训练:
from ultralytics import YOLO
# Guard the entry point: DataLoader worker processes are started with the
# "spawn" method on macOS and Windows and re-import this file. Without the
# guard each worker would start training again or fail during bootstrapping.
if __name__ == '__main__':
    # Load the starting weights (best224.pt; per its name, a checkpoint from an
    # earlier 224x224 run — confirm) to continue training from.
    model = YOLO('best224.pt')
    # Train on macOS using the Apple GPU backend (device='mps').
    results = model.train(data='mydata224.yaml', epochs=200, imgsz=224, device='mps')
模型的默认训练参数,包括训练多少个 epoch、早停的 patience(连续多少个 epoch 验证指标没有提升就停止训练)等,都在 ./ultralytics/cfg/default.yaml 中。
只用塞网数据,图片大小为 400×400。由于塞网瑕疵大多位于图片正中间,所以做了一个数据增强:把图片随机裁剪成 224×224,让瑕疵不总是在正中间。
先跑了 400×400 的,再跑了 224×224 的,结果都一般。224 的数据还有一堆标签被判为 corrupt:有的面积为负,有的框的边界点不在图内。后来学长说面积为负数的是无瑕疵样本,当时忘了把它置为 0;他已经把处理好的数据发给我了,之后再重新跑。
400×400 的 mAP50 大约到 0.24,224×224 的 mAP50 到 0.21;处理好数据后的 224×224 的 mAP50 明天再看。
4. 测试
参考:超详细||YOLOv8基础教程(环境搭建,训练,测试,部署看一篇就够)(在推理视频中添加FPS信息)_yolov8安装-CSDN博客
预测用 predict.py。可以把 NMS 的 iou 参数(下面代码里是 0.6)调低到 0.3,这样可以减少重叠的框。注意要设置 save_conf=False,否则保存的 label 每行会多一列置信度,格式就和真实标签对不上了。
from ultralytics import YOLO
if __name__ == '__main__':
    # Our own trained detector weights (not a stock YOLOv8n checkpoint).
    model = YOLO('best224.pt')

    # All prediction settings in one place; passed to predict() as keyword arguments.
    predict_options = {
        # --- input / inference ---
        'source': '/Users/lixiang/PycharmProjects/dataset224new/images/test',  # folder of test images
        'save': True,              # save annotated result images
        'imgsz': 224,              # inference size, same as the training imgsz
        'conf': 0.25,              # minimum object confidence to keep a detection
        'iou': 0.6,                # NMS IoU threshold; lower it (e.g. 0.3) to suppress more overlapping boxes
        'show': True,              # display results if a display is available
        # --- output location: results go to <project>/<name> ---
        'project': 'outnms06',
        'name': '',
        # --- label files ---
        'save_txt': True,          # write YOLO .txt labels for each image
        'save_conf': False,        # keep False: otherwise each label line gets an extra confidence column
        'save_crop': False,        # do not save cropped detections
        # --- plotting ---
        'show_labels': True,       # draw class labels on plots
        'show_conf': True,         # draw confidence scores on plots
        'vid_stride': 1,           # video frame-rate stride
        'line_width': 1,           # bounding box thickness (pixels)
        # --- misc ---
        'visualize': False,        # do not dump feature maps
        'augment': False,          # no test-time augmentation
        'agnostic_nms': False,     # NMS per class rather than across classes
        'retina_masks': False,     # high-resolution masks (segmentation only)
        'boxes': True,             # show boxes in segmentation predictions
    }
    model.predict(**predict_options)
eval.py 可以根据真实标签(gt)和预测结果的 labels 文件夹,计算 precision、recall 和 F1 分数。
import os
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
def parse_yolo_label(file_path):
    """Parse a YOLO-format label file.

    Each non-blank line ``class x_center y_center width height`` (optionally
    followed by a confidence column) becomes a list of floats.

    Args:
        file_path: path to the .txt label file.

    Returns:
        A list with one float list per non-blank line, or an empty list if the
        file does not exist (no label file means no boxes for that image).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            # Skip blank lines: a background label file may contain just "\n",
            # which would otherwise turn into an empty "box" [].
            return [list(map(float, line.split())) for line in file if line.strip()]
    except FileNotFoundError:
        return []  # missing label file -> no boxes
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two boxes.

    Both boxes are ``(x_center, y_center, width, height)``. The result is 0
    when the union area is not positive (e.g. two zero-area boxes).
    """
    (cx_a, cy_a, w_a, h_a), (cx_b, cy_b, w_b, h_b) = box1, box2
    # Signed overlap along each axis; negative means the boxes are apart.
    overlap_w = min(cx_a + w_a / 2, cx_b + w_b / 2) - max(cx_a - w_a / 2, cx_b - w_b / 2)
    overlap_h = min(cy_a + h_a / 2, cy_b + h_b / 2) - max(cy_a - h_a / 2, cy_b - h_b / 2)
    intersection = max(0, overlap_w) * max(0, overlap_h)
    union = w_a * h_a + w_b * h_b - intersection
    if union <= 0:
        return 0
    return intersection / union
def evaluate_predictions(gt_folder, pred_folder, iou_threshold=0.5):
    """Compute detection precision, recall and F1 from two folders of YOLO label files.

    Every ``*.txt`` name found in either folder is evaluated. Predictions on images
    that have no ground-truth file therefore still count as false positives.
    A label file missing from one folder contributes no boxes.

    Within each image, every ground-truth box is matched to the still-unmatched
    prediction of the same class with the highest IoU, if that IoU is at least
    ``iou_threshold``. Matched pairs are true positives, unmatched ground truths
    are false negatives and unmatched predictions are false positives.

    Args:
        gt_folder: folder with ground-truth label files (must exist).
        pred_folder: folder with predicted label files; treated as empty if absent.
        iou_threshold: minimum IoU for a prediction to match a ground-truth box.

    Returns:
        ``(precision, recall, f1)``; each is 0 when its denominator is 0.
    """
    def _label_names(folder):
        # Only YOLO .txt files; skips stray entries such as macOS .DS_Store.
        return {name for name in os.listdir(folder) if name.endswith('.txt')}

    def _boxes(path):
        # A usable row needs class + 4 coordinates; drops blank/short rows.
        return [label for label in parse_yolo_label(path) if len(label) >= 5]

    names = _label_names(gt_folder)
    if os.path.isdir(pred_folder):
        names |= _label_names(pred_folder)

    tp = 0  # true positives
    fp = 0  # false positives
    fn = 0  # false negatives
    for name in sorted(names):
        gt_labels = _boxes(os.path.join(gt_folder, name))
        pred_labels = _boxes(os.path.join(pred_folder, name))
        matched = [False] * len(pred_labels)
        for gt in gt_labels:
            best_index, best_iou = None, -1.0
            for i, pred in enumerate(pred_labels):
                # A prediction matches at most one ground truth, and only of its own class.
                if matched[i] or pred[0] != gt[0]:
                    continue
                # [1:5] ignores an optional trailing confidence column.
                iou = calculate_iou(gt[1:5], pred[1:5])
                if iou >= iou_threshold and iou > best_iou:
                    best_index, best_iou = i, iou
            if best_index is None:
                fn += 1
            else:
                matched[best_index] = True
                tp += 1
        fp += matched.count(False)  # every unmatched prediction is a false positive

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1
# Example usage: score saved predictions against the test-set ground truth.
gt_folder = '/Users/lixiang/PycharmProjects/dataset224new/labels/test'
pred_folder = '/Users/lixiang/PycharmProjects/ultralytics-main/outnms03/predict2/labels'
precision, recall, f1 = evaluate_predictions(gt_folder=gt_folder, pred_folder=pred_folder)
print('Precision: {}, Recall: {}, F1 Score: {}'.format(precision, recall, f1))