Mosaic Data Augmentation (from YOLOv4): Python Implementations

This post walks through three implementations of Mosaic data augmentation for object detection. Mosaic stitches four images into one training sample, enlarging the training data and improving model generalization. The first method simply rescales the four images and stitches them; the second wraps the augmentation in a class and adds HSV color jitter and random flipping; the third builds it into a PyTorch DataLoader pipeline. All three methods can produce misaligned labels at an output size of 1024x1024, but work well at other sizes.


I have been studying data augmentation recently and needed Mosaic for my own work. After reviewing several public implementations and adapting them to my data, I now have three working versions of Mosaic, each with some remaining issues.
Mosaic augmentation reads four images at random from the dataset, rescales or crops each one, lays them out in the four quadrants (top-left, bottom-left, bottom-right, top-right), stitches them into one new image, and remaps the label data to match.
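The stitching step itself reduces to choosing a split point, resizing each image into its quadrant, and shifting the box coordinates by the quadrant offset. Here is a minimal sketch of that core idea, separate from the three full implementations below (illustrative only: it assumes four BGR arrays with absolute pixel-coordinate [xmin, ymin, xmax, ymax] boxes, and uses a fixed split point rather than a random one):

import cv2
import numpy as np

def mosaic_sketch(imgs, boxes_list, out_size=640, split=0.5):
    # imgs: list of 4 BGR arrays; boxes_list: one list of [xmin, ymin, xmax, ymax] per image
    cut = int(out_size * split)  # split point shared by both axes
    canvas = np.zeros((out_size, out_size, 3), np.uint8)
    # (x offset, y offset, quadrant width, quadrant height)
    quads = [(0, 0, cut, cut),
             (cut, 0, out_size - cut, cut),
             (0, cut, cut, out_size - cut),
             (cut, cut, out_size - cut, out_size - cut)]
    new_boxes = []
    for img, boxes, (dx, dy, qw, qh) in zip(imgs, boxes_list, quads):
        h, w = img.shape[:2]
        canvas[dy:dy + qh, dx:dx + qw] = cv2.resize(img, (qw, qh))
        for x1, y1, x2, y2 in boxes:
            # rescale to the quadrant, then shift by the quadrant offset
            new_boxes.append([x1 * qw / w + dx, y1 * qh / h + dy,
                              x2 * qw / w + dx, y2 * qh / h + dy])
    return canvas, new_boxes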

Mosaic's potential benefits: four images are consumed at once, so the augmentation is less sensitive to batch size; it effectively enlarges the training set; and the composites mix different backgrounds, giving richer spatial-semantic context. All of these can improve the model's generalization.

Below are the three versions I debugged end to end. The core mosaic functions were found online; I added my own data loading so they run and produce output, and each result is visualized. The main differences between the three are the input dataset format and how the four images are read at a time; see the code for details.

Method 1

The four images are only rescaled before being stitched. The data is in VOC format; the script returns the merged image with its annotations and visualizes the result.
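For reference, a minimal VOC-style annotation file that parseXmlFiles below can read might look like this (the file name, class name, and coordinates are made up for illustration):

<annotation>
  <filename>0001.jpg</filename>
  <object>
    <name>blade</name>
    <bndbox>
      <xmin>48</xmin>
      <ymin>240</ymin>
      <xmax>195</xmax>
      <ymax>371</ymax>
    </bndbox>
  </object>
</annotation>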

import random

import cv2
import os
import glob
import numpy as np
from PIL import Image
from lxml import etree
# from ipdb import set_trace  # optional: only needed for interactive debugging

OUTPUT_SIZE = (1024, 1024)  # (height, width)
SCALE_RANGE = (0.5, 0.5)  # (min, max); with both ends at 0.5 the split point is always the center
FILTER_TINY_SCALE = 1 / 50  # drop a box if its height or width falls below this fraction


# VOC-format dataset: ANNO_DIR holds the XML annotations, IMG_DIR the matching JPG images
ANNO_DIR = './dataset/wind_turbine/600/Annotations'
IMG_DIR = './dataset/wind_turbine/600/Images'



def main():
    img_paths, annos = get_dataset(ANNO_DIR, IMG_DIR)

    # pick 4 random indices into the annotation list
    idxs = random.sample(range(len(annos)), 4)

    new_image, new_annos = update_image_and_anno(img_paths, annos,
                                                 idxs,
                                                 OUTPUT_SIZE, SCALE_RANGE,
                                                 filter_scale=FILTER_TINY_SCALE)
    # save the composite, then draw the remapped boxes on it
    cv2.imwrite('./img/wind_output.jpg', new_image)
    # annotations are normalized to [0, 1], so scale by OUTPUT_SIZE when drawing
    for anno in new_annos:
        start_point = (int(anno[1] * OUTPUT_SIZE[1]), int(anno[2] * OUTPUT_SIZE[0]))  # top-left corner
        end_point = (int(anno[3] * OUTPUT_SIZE[1]), int(anno[4] * OUTPUT_SIZE[0]))  # bottom-right corner
        cv2.rectangle(new_image, start_point, end_point, (0, 255, 0), 1, cv2.LINE_AA)  # one box per annotation
    
    cv2.imwrite('./img/wind_output_box.jpg', new_image)

    new_image = cv2.cvtColor(new_image, cv2.COLOR_BGR2RGB)
    new_image = Image.fromarray(new_image.astype(np.uint8))
    # new_image.show()

def update_image_and_anno(all_img_list, all_annos, idxs, output_size, scale_range, filter_scale=0.):
    output_img = np.zeros([output_size[0], output_size[1], 3], dtype=np.uint8)
    scale_x = scale_range[0] + random.random() * (scale_range[1] - scale_range[0])
    scale_y = scale_range[0] + random.random() * (scale_range[1] - scale_range[0])
    divid_point_x = int(scale_x * output_size[1])
    divid_point_y = int(scale_y * output_size[0])

    new_anno = []
    for i, idx in enumerate(idxs):
        path = all_img_list[idx]
        img_annos = all_annos[idx]

        img = cv2.imread(path)
        if i == 0:  # top-left
            img = cv2.resize(img, (divid_point_x, divid_point_y))
            output_img[:divid_point_y, :divid_point_x, :] = img
            for bbox in img_annos:
                xmin = bbox[1] * scale_x
                ymin = bbox[2] * scale_y
                xmax = bbox[3] * scale_x
                ymax = bbox[4] * scale_y
                new_anno.append([bbox[0], xmin, ymin, xmax, ymax])

        elif i == 1:  # top-right
            img = cv2.resize(img, (output_size[1] - divid_point_x, divid_point_y))
            output_img[:divid_point_y, divid_point_x:output_size[1], :] = img
            for bbox in img_annos:
                xmin = scale_x + bbox[1] * (1 - scale_x)
                ymin = bbox[2] * scale_y
                xmax = scale_x + bbox[3] * (1 - scale_x)
                ymax = bbox[4] * scale_y
                new_anno.append([bbox[0], xmin, ymin, xmax, ymax])
        elif i == 2:  # bottom-left
            img = cv2.resize(img, (divid_point_x, output_size[0] - divid_point_y))
            output_img[divid_point_y:output_size[0], :divid_point_x, :] = img
            for bbox in img_annos:
                xmin = bbox[1] * scale_x
                ymin = scale_y + bbox[2] * (1 - scale_y)
                xmax = bbox[3] * scale_x
                ymax = scale_y + bbox[4] * (1 - scale_y)
                new_anno.append([bbox[0], xmin, ymin, xmax, ymax])
        else:  # bottom-right
            img = cv2.resize(img, (output_size[1] - divid_point_x, output_size[0] - divid_point_y))
            output_img[divid_point_y:output_size[0], divid_point_x:output_size[1], :] = img
            for bbox in img_annos:
                xmin = scale_x + bbox[1] * (1 - scale_x)
                ymin = scale_y + bbox[2] * (1 - scale_y)
                xmax = scale_x + bbox[3] * (1 - scale_x)
                ymax = scale_y + bbox[4] * (1 - scale_y)
                new_anno.append([bbox[0], xmin, ymin, xmax, ymax])

    # drop boxes whose normalized width or height falls below filter_scale
    # (the parameter was previously accepted but never used)
    if filter_scale > 0:
        new_anno = [a for a in new_anno
                    if a[3] - a[1] > filter_scale and a[4] - a[2] > filter_scale]

    return output_img, new_anno


def get_dataset(anno_dir, img_dir):
    img_paths = []
    annos = []
    for anno_file in glob.glob(os.path.join(anno_dir, '*.xml')):
        # derive the image path from the XML file name
        # (the original split the name on the letter 'x', which breaks on file names containing an x)
        anno_id = os.path.splitext(os.path.basename(anno_file))[0]
        img_path = os.path.join(img_dir, f'{anno_id}.jpg')
        print(img_path)

        img = cv2.imread(img_path)
        img_height, img_width, _ = img.shape
        print(img.shape)
        del img

        boxes = []
        bnd_box = parseXmlFiles(anno_file)
        print(bnd_box)
        for bnd_id, box in enumerate(bnd_box):
            categories_id = box[0]
            # clamp to the image and normalize to [0, 1]
            xmin = max(int(box[1]), 0) / img_width
            ymin = max(int(box[2]), 0) / img_height
            xmax = min(int(box[3]), img_width) / img_width
            ymax = min(int(box[4]), img_height) / img_height
            boxes.append([categories_id, xmin, ymin, xmax, ymax])

        # skip images without any boxes (this check used to sit inside the loop, where it never triggered)
        if not boxes:
            continue

        img_paths.append(img_path)
        annos.append(boxes)
    print("annos (normalized coordinates for every image):", annos)
    print(img_paths)
    return img_paths, annos

def parseXmlFiles(anno_dir):
    tree = etree.parse(anno_dir)
    root = tree.getroot()
    objects = root.findall('.//object')
    bnd_box = []
    for obj in objects:  # avoid shadowing the built-in name `object`
        name = obj.find("name").text
        bndbox = obj.find("bndbox")
        xmin = float(bndbox.find("xmin").text)
        xmax = float(bndbox.find("xmax").text)
        ymin = float(bndbox.find("ymin").text)
        ymax = float(bndbox.find("ymax").text)
        # keep [name, xmin, ymin, xmax, ymax] order to match get_dataset
        bnd_box.append([name, xmin, ymin, xmax, ymax])
    return bnd_box


if __name__ == '__main__':
    main()

The results: the composite image, and the composite with the labels drawn on it.
[images]

Method 2

The data-loading function is slightly modified; the mosaic is invoked through a class that also applies HSV color jitter and random flips, and finally outputs the composite image together with its labels.
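The class below expects each entry of four_annotation_lines to be a list whose last element is the image path and whose preceding elements are [xmin, ymin, xmax, ymax, category_id] boxes, which is what the get_dataset at the end of this section returns. For illustration (the path and numbers are made up), one entry looks like:

annotation_line = [
    [48.0, 240.0, 195.0, 371.0, 1],   # [xmin, ymin, xmax, ymax, category_id]
    [300.0, 120.0, 420.0, 260.0, 1],
    './dataset/wind_turbine/Images/0001.jpg',  # the image path comes last
]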

# -*- coding: utf-8 -*-
import random
import os
import cv2
import glob
from lxml import etree
from PIL import Image
import numpy as np
from ipdb import set_trace


def merge_bboxes(bboxes, cutx, cuty):
    # clip each sub-image's boxes to its quadrant; i indexes the quadrant:
    # 0 top-left, 1 bottom-left, 2 bottom-right, 3 top-right
    merge_bbox = []
    for i in range(len(bboxes)):
        for box in bboxes[i]:
            tmp_box = []
            x1, y1, x2, y2 = box[0], box[1], box[2], box[3]

            if i == 0:
                if y1 > cuty or x1 > cutx:
                    continue
                if y2 >= cuty and y1 <= cuty:
                    y2 = cuty
                    if y2 - y1 < 5:
                        continue
                if x2 >= cutx and x1 <= cutx:
                    x2 = cutx
                    if x2 - x1 < 5:
                        continue

            if i == 1:
                if y2 < cuty or x1 > cutx:
                    continue

                if y2 >= cuty and y1 <= cuty:
                    y1 = cuty
                    if y2 - y1 < 5:
                        continue

                if x2 >= cutx and x1 <= cutx:
                    x2 = cutx
                    if x2 - x1 < 5:
                        continue

            if i == 2:
                if y2 < cuty or x2 < cutx:
                    continue

                if y2 >= cuty and y1 <= cuty:
                    y1 = cuty
                    if y2 - y1 < 5:
                        continue

                if x2 >= cutx and x1 <= cutx:
                    x1 = cutx
                    if x2 - x1 < 5:
                        continue

            if i == 3:
                if y1 > cuty or x2 < cutx:
                    continue

                if y2 >= cuty and y1 <= cuty:
                    y2 = cuty
                    if y2 - y1 < 5:
                        continue

                if x2 >= cutx and x1 <= cutx:
                    x1 = cutx
                    if x2 - x1 < 5:
                        continue

            tmp_box.append(x1)
            tmp_box.append(y1)
            tmp_box.append(x2)
            tmp_box.append(y2)
            tmp_box.append(box[-1])
            merge_bbox.append(tmp_box)
    return merge_bbox

class Data_augmentation_with_Mosaic():
    def __init__(self,
                 four_annotation_lines,
                 input_shape,
                 max_boxes=100,
                 hue=.1,
                 sat=1.5,
                 val=1.5,
                 visual=True):

        self.four_annotation_lines = four_annotation_lines

        self.h, self.w = input_shape
        self.min_offset_x = 0.4
        self.min_offset_y = 0.4
        self.scale_low = 1 - min(self.min_offset_x, self.min_offset_y) - 0.3  # modified from the original
        self.scale_high = self.scale_low + 0.2 + 0.3  # modified from the original
        # print(self.scale_low, self.scale_high)


        self.max_boxes = max_boxes
        self.hue = hue
        self.sat = sat
        self.val = val

        self.place_x = [0, 0, int(self.w * self.min_offset_x), int(self.w * self.min_offset_x)]
        self.place_y = [0, int(self.h * self.min_offset_y), int(self.h * self.min_offset_y), 0]

        self.visual = visual


    def main(self):
        image_datas = []
        box_datas = []
        index = 0

        for annotation_line in self.four_annotation_lines:
            # the image path is the last element; everything before it is a box
            image = Image.open(annotation_line[-1])
            image = image.convert("RGB")
            # original image size
            iw, ih = image.size
            # collect the boxes into an array
            box = np.array([np.array(elm) for elm in annotation_line[:-1]])

            # randomly flip the image horizontally
            flip = np.random.rand() < .5
            if flip and len(box) > 0:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                box[:, [0, 2]] = iw - box[:, [2, 0]]

            # rescale the input image
            new_ar = self.w / self.h
            scale = np.random.uniform(self.scale_low, self.scale_high)
            if new_ar < 1:
                nh = int(scale * self.h)
                nw = int(nh * new_ar)
            else:
                nw = int(scale * self.w)
                nh = int(nw / new_ar)
            image = image.resize((nw, nh), Image.BICUBIC)

            # HSV color jitter (OpenCV float32 HSV has H in [0, 360], S and V in [0, 1])
            hue = np.random.uniform(-self.hue, self.hue)
            sat = np.random.uniform(1, self.sat) if np.random.uniform() < .5 else 1 / np.random.uniform(1, self.sat)
            val = np.random.uniform(1, self.val) if np.random.uniform() < .5 else 1 / np.random.uniform(1, self.val)
            x = cv2.cvtColor(np.array(image, np.float32) / 255, cv2.COLOR_RGB2HSV)
            x[..., 0] += hue * 360
            x[..., 0][x[..., 0] > 360] -= 360  # wrap hue (the original compared against 1, which is wrong for a 0-360 channel)
            x[..., 0][x[..., 0] < 0] += 360
            x[..., 1] *= sat
            x[..., 2] *= val
            x[:, :, 1:][x[:, :, 1:] > 1] = 1
            x[x < 0] = 0
            image = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)  # numpy array in [0, 1]

            image = Image.fromarray((image * 255).astype(np.uint8))
            # paste the image at its quadrant position (one of the four placements)
            dx = self.place_x[index]
            dy = self.place_y[index]


            new_image = Image.new('RGB', (self.w, self.h), (128, 128, 128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image) / 255

            if self.visual:
                image_ = np.array(image_data * 255., dtype=np.uint8)
                image_ = cv2.cvtColor(image_, cv2.COLOR_RGB2BGR)
                cv2.imwrite('Image.jpg', image_)  # overwritten on each of the four iterations

            index = index + 1
            box_data = []
            # remap the boxes to the scaled and shifted image
            if len(box) > 0:
                np.random.shuffle(box)
                box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
                box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
                box[:, 0:2][box[:, 0:2] < 0] = 0
                box[:, 2][box[:, 2] > self.w] = self.w
                box[:, 3][box[:, 3] > self.h] = self.h
                box_w = box[:, 2] - box[:, 0]
                box_h = box[:, 3] - box[:, 1]
                box = box[np.logical_and(box_w > 1, box_h > 1)]
                box_data = np.zeros((len(box), 5))
                box_data[:len(box)] = box

            image_datas.append(image_data)
            box_datas.append(box_data)

        # choose the cut point and stitch the four images together
        cutx = np.random.randint(int(self.w * self.min_offset_x), int(self.w * (1 - self.min_offset_x)))
        cuty = np.random.randint(int(self.h * self.min_offset_y), int(self.h * (1 - self.min_offset_y)))

        new_image = np.zeros([self.h, self.w, 3])
        new_image[:cuty, :cutx, :] = image_datas[0][:cuty, :cutx, :]
        new_image[cuty:, :cutx, :] = image_datas[1][cuty:, :cutx, :]
        new_image[cuty:, cutx:, :] = image_datas[2][cuty:, cutx:, :]
        new_image[:cuty, cutx:, :] = image_datas[3][:cuty, cutx:, :]

        # clip the boxes to their quadrants
        new_boxes = merge_bboxes(box_datas, cutx, cuty)

        # pad the boxes into a fixed-size array
        box_data = np.zeros((self.max_boxes, 5))
        if len(new_boxes) > 0:
            if len(new_boxes) > self.max_boxes: new_boxes = new_boxes[:self.max_boxes]
            box_data[:len(new_boxes)] = new_boxes

        if self.visual:
            new_image = np.array(new_image * 255., dtype=np.uint8)

            for box in box_data:
                box = [int(b) for b in box]
                if box[2] <= box[0]:  # skip the all-zero padding rows
                    continue
                cv2.rectangle(new_image, (box[0], box[1]), (box[2], box[3]), color=(255, 255, 255), thickness=1)

            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
            cv2.imwrite('output_Image.jpg', new_image)

        return new_image, box_data

def get_dataset(img_dir, anno_dir):
    annos = []
    for anno_file in glob.glob(os.path.join(anno_dir, '*.xml')):
        # derive the image path from the XML file name
        anno_id = os.path.splitext(os.path.basename(anno_file))[0]
        img_path = os.path.join(img_dir, f'{anno_id}.jpg')

        tree = etree.parse(anno_file)
        root = tree.getroot()
        objects = root.findall('.//object')
        bnd_box = []
        for obj in objects:
            category_id = 1  # single-class dataset: every object maps to category 1
            bndbox = obj.find("bndbox")
            xmin = float(bndbox.find("xmin").text)
            xmax = float(bndbox.find("xmax").text)
            ymin = float(bndbox.find("ymin").text)
            ymax = float(bndbox.find("ymax").text)

            bnd_box.append([xmin, ymin, xmax, ymax, category_id])

        # the image path goes last, after all of the boxes
        bnd_box.append(img_path)
        annos.append(bnd_box)

    return annos


if __name__ == "__main__":
    img_dir = '../../dataset/wind_turbine/Images'
    img_anno = '../../dataset/wind_turbine/Annotations'
    annotation_lines = get_dataset(img_dir, img_anno)
    # print(len(annotation_lines))

    four_annotation_lines = []
    for i, line in enumerate(annotation_lines):
        four_annotation_lines.append(line)
        if (i+1) % 4==0:
            mosaic_aug = Data_augmentation_with_Mosaic(four_annotation_lines,
                                                         input_shape=(1024, 1024),
                                                         max_boxes=100,
                                                         hue=.1,
                                                         sat=1.5,
                                                         val=1.5,
                                                         visual=True)
            image_data, box_data = mosaic_aug.main()
            for box in box_data:
                box = [int(b) for b in box]
                if box[2] <= box[0]:  # skip the all-zero padding rows
                    continue
                cv2.rectangle(image_data, (box[0], box[1]), (box[2], box[3]), color=(255, 255, 255), thickness=1)

           
            cv2.imwrite(str(i)+'.jpg', image_data)
            four_annotation_lines = []

The results: the cropped regions do not completely fill the output canvas; the leftover area stays gray.
[images]
Method 3

The mosaic is built into a PyTorch Dataset and consumed through a DataLoader; the annotations come from a YOLO-style text file (see the format note after the code).

from torch.utils.data import DataLoader
from PIL import Image
from torch.utils.data import Dataset
from torchvision import ops
import torchvision.transforms as tsf
import numpy as np
import torch
import random
import cv2

def collator(batch):
    # stack the images into one tensor; keep the annotations as a per-image list
    # (swap in your own batching logic here if your pipeline differs)
    imgs = [img[0].unsqueeze(0) for img in batch]
    annos = [anno[1] for anno in batch]
    imgs = torch.cat(imgs)
    return imgs, annos

class YOLODataset(Dataset):
    """
    yolo_line: img_path xmin,ymin,xmax,ymax,label xmin,ymin,...
    output_size: int for square output
    : probability
    color_jitter: transforms.ColorJitter(brightness, contrast, saturation, hue)
    """
    def __init__(
            self,
            yolo_line_list,
            output_size = (1024,1024),
            batch_size = 4,
            normalization:list = None,
            resize_mode = 'stretch',
            grayscale = False,
            size_augmentation:float = None,
            center_crop_ratio:tuple = None,
            random_flip = None,
            color_jitter = None,
            mosaic = False
        ):

        self.data = yolo_line_list
        self.size = output_size
        self.grayscale = grayscale
        self.crop_r = center_crop_ratio
        self.resize_mode = resize_mode
        self.size_aug = size_augmentation
        self.random_flip = random_flip
        self.color_jitter = color_jitter
        self.normalization = normalization
        self.mosaic = mosaic
        self.batch_size = batch_size
        
        self.counter = 0  # for counting a batch
        self.sign = None  # for size augmentation


    def decode_line(self, yolo_line):
        yolo_line = yolo_line.split()
        yolo_line[1:] = [s.split(',') for s in yolo_line[1:]]
        return yolo_line


    def normalize_bbox(self, tensor_bboxes, size):
        w, h = size
        tensor_bboxes[:,[0,2]] = tensor_bboxes[:,[0,2]] / w
        tensor_bboxes[:,[1,3]] = tensor_bboxes[:,[1,3]] / h
        return tensor_bboxes


    def center_crop(self, ratio:tuple, tensor_img, tensor_bboxes, rm_thres=16):
        """
        tensor_bboxes: [left, top, right, bottom]
        """
        _, h, w = tensor_img.size()
        r = ratio[0] / ratio[1]
        size = (int(min(w/r, h)), int(min(h*r, w)))  # (h, w)
        # img
        tensor_img = tsf.CenterCrop(size)(tensor_img)
        # align offset bbox
        x_offset, y_offset = (w-size[1])/2, (h-size[0])/2
        offsets = torch.Tensor([x_offset, y_offset, x_offset, y_offset])

        if len(tensor_bboxes):
            tensor_bboxes[:,[0,1,2,3]] = tensor_bboxes[:,[0,1,2,3]] - offsets
            # align bboxes out of the image
            tensor_bboxes[:,0][tensor_bboxes[:,0] < 0] = 0  # left side
            tensor_bboxes[:,1][tensor_bboxes[:,1] < 0] = 0  # top side
            tensor_bboxes[:,2][tensor_bboxes[:,2] > size[1]] = size[1]  # right side
            tensor_bboxes[:,3][tensor_bboxes[:,3] > size[0]] = size[0]  # bottom side
            # remove bboxes out of the image
            keep_x = tensor_bboxes[:,2] - tensor_bboxes[:,0] > rm_thres
            keep_y = tensor_bboxes[:,3] - tensor_bboxes[:,1] > rm_thres
            keep_mask = torch.logical_and(keep_x, keep_y)
            tensor_bboxes = tensor_bboxes[keep_mask]
        return tensor_img, tensor_bboxes


    def resize(self, size:tuple, tensor_img, tensor_bboxes, mode='stretch'):
        """
        size: tuple in (w, h)
        """
        assert mode in ['stretch', 'pad']
        _, h, w = tensor_img.size()

        if mode=='stretch' and len(tensor_bboxes):
            # bbox size and size align
            tensor_bboxes[:,[0,2]] = tensor_bboxes[:,[0,2]]*size[0] / w  # x resize
            tensor_bboxes[:,[1,3]] = tensor_bboxes[:,[1,3]]*size[1] / h  # y resize
        elif mode == 'pad':
            ratio = min(size[0]/w, size[1]/h)
            # pad img
            pads = (int(size[0]/ratio-w)//2, int(size[1]/ratio-h)//2)  # (w, h)
            tensor_img = tsf.Pad(pads)(tensor_img)
            
            if len(tensor_bboxes):
                # bbox center align
                tensor_bboxes[:,0] = tensor_bboxes[:,0] + pads[0]
                tensor_bboxes[:,1] = tensor_bboxes[:,1] + pads[1]
                
                # bbox size align
                tensor_bboxes[:,[0,2]] = tensor_bboxes[:,[0,2]]*ratio  # x resize
                tensor_bboxes[:,[1,3]] = tensor_bboxes[:,[1,3]]*ratio  # y resize

        # resize img
        tensor_img = tsf.Resize((size[1],size[0]))(tensor_img)  # (h, w)
        return tensor_img, tensor_bboxes
    
    
    def flip_randomly(self, tensor_img, tensor_bboxes, mode='all'):
        _, h, w = tensor_img.size()
        assert mode in ['all', 'horizontal', 'vertical']

        # horizontal 0.5
        if mode in ['all', 'horizontal']:
            if random.getrandbits(1):
                tensor_img = tsf.RandomHorizontalFlip(1)(tensor_img)
                # bbox center align
                if len(tensor_bboxes):
                    tensor_bboxes[:,0] = w - tensor_bboxes[:,0]
            
        # vertical 0.5
        if mode in ['all', 'vertical']:
            if random.getrandbits(1):
                tensor_img = tsf.RandomVerticalFlip(1)(tensor_img)
                # bbox center align
                if len(tensor_bboxes):
                    tensor_bboxes[:,1] = h - tensor_bboxes[:,1]

        return tensor_img, tensor_bboxes


    def get_mosaic(self, imgs:list, bboxes:list, rm_thres=16):
        """
        Every image should have the same size.
        imgs: list of 4 PIL images
        bboxes: list of 4 bbox lists
        rm_thres: threshold for removing too-small bboxes
        """
        # use the first image's size as the output size
        new_size = imgs[0].size
        rs = np.random.uniform(0.5, 1.5, [2])  # randomly shift the center point
        center = (int(new_size[0]*rs[0]/2), int(new_size[1]*rs[1]/2))
        
        # crop each image
        imgs[1] = imgs[1].crop((center[0], 0, new_size[0], center[1]))
        imgs[2] = imgs[2].crop((0, center[1], center[0], new_size[1]))
        imgs[3] = imgs[3].crop((center[0], center[1], new_size[0], new_size[1]))
        
        # paste other image to main image
        imgs[0].paste(imgs[1], (center[0],0))
        imgs[0].paste(imgs[2], (0,center[1]))
        imgs[0].paste(imgs[3], (center[0],center[1]))

        # align bboxes
        b0 = np.array(bboxes[0])
        b1 = np.array(bboxes[1])
        b2 = np.array(bboxes[2])
        b3 = np.array(bboxes[3])

        box_list = list()
        # bbox 0
        if b0.any():
            b0[:,0][b0[:,0] > center[0]] = center[0]
            b0[:,2][b0[:,2] > center[0]] = center[0]
            b0[:,1][b0[:,1] > center[1]] = center[1]
            b0[:,3][b0[:,3] > center[1]] = center[1]
            box_list.append(b0)
        # bbox 1
        if b1.any():
            b1[:,0][b1[:,0] < center[0]] = center[0]
            b1[:,2][b1[:,2] < center[0]] = center[0]
            b1[:,1][b1[:,1] > center[1]] = center[1]
            b1[:,3][b1[:,3] > center[1]] = center[1]
            box_list.append(b1)
        # bbox 2
        if b2.any():
            b2[:,0][b2[:,0] > center[0]] = center[0]
            b2[:,2][b2[:,2] > center[0]] = center[0]
            b2[:,1][b2[:,1] < center[1]] = center[1]
            b2[:,3][b2[:,3] < center[1]] = center[1]
            box_list.append(b2)
        # bbox 3
        if b3.any():
            b3[:,0][b3[:,0] < center[0]] = center[0]
            b3[:,2][b3[:,2] < center[0]] = center[0]
            b3[:,1][b3[:,1] < center[1]] = center[1]
            b3[:,3][b3[:,3] < center[1]] = center[1]
            box_list.append(b3)

        if box_list:
            mosaic_bboxes = np.concatenate(box_list)
            # remove no area bbox
            keep_x = mosaic_bboxes[:,2] - mosaic_bboxes[:,0] > rm_thres
            keep_y = mosaic_bboxes[:,3] - mosaic_bboxes[:,1] > rm_thres
            keep_mask = np.logical_and(keep_x, keep_y)
            mosaic_bboxes = mosaic_bboxes[keep_mask]
        else:
            mosaic_bboxes = list()
        return imgs[0], mosaic_bboxes


    def __getitem__(self, index:int): 
        anno = self.decode_line(self.data[index])
        
        img_path = anno[0]  # read img
        img = Image.open(img_path).convert('RGB')  # make sure it is 3 channels
        
        bboxes = [list(map(int,map(float,b))) for b in anno[1:]]  # get bboxes
        
        # ------
        # Mosaic augmentation
        # ------
        
        if self.mosaic:
            # randomly get 3 others
            lines = random.sample(self.data, 3)  # get lines
            annos_list = list(map(self.decode_line, lines))  # decode lines
            # img1 + other imgs
            img_list = [img]
            bbox_list = [bboxes]
            for anno in annos_list:
                img_list.append(Image.open(anno[0]).convert('RGB'))  # img
                bbox_list.append([list(map(int,map(float,b))) for b in anno[1:]])
            
            # get mosaic from 4 imgs
            img, bboxes = self.get_mosaic(img_list, bbox_list, rm_thres=16)

        # ------
        # To tensor, optional grayscale
        # ------
        img = tsf.ToTensor()(img)  # to tensor
        if self.grayscale:
            img = tsf.Grayscale(3)(img)  # the original dropped the assignment, so grayscale never applied
        
        bboxes = torch.Tensor(bboxes)  # to tensors
        
        # ------
        # Center Crop by aspect ratio
        # ------
        if self.crop_r is not None:
            img, bboxes = self.center_crop(self.crop_r, img, bboxes)

        # ------
        # change bbox format
        # from (left, top, right, bottom) to (x, y, w, h)
        # ------

        if len(bboxes):
            bboxes[:, :4] = ops.box_convert(bboxes[:,:4], in_fmt='xyxy', out_fmt='cxcywh')

        # ------
        # Resize and size augmentation
        # size augmentation shifts the output size by +/-32 with the given probability
        # ------
        if self.size_aug is not None:
            p = self.size_aug
            if self.counter%self.batch_size == 0:
                [self.sign] = random.choices([0,1,-1], [1-p,p/2,p/2])
            size = torch.Tensor(self.size).long() + self.sign*32
            self.counter += 1
        else:
            size = self.size
        img, bboxes = self.resize(size, img, bboxes, self.resize_mode)

        # ------
        # Flip and color jitter
        # ------
        if self.random_flip is not None:
            img, bboxes = self.flip_randomly(img, bboxes, self.random_flip)
        
        if self.color_jitter is not None:
            img = self.color_jitter(img)
        
        # ------
        # Normalization
        # ------
        if self.normalization is not None:
            img = tsf.Normalize(*self.normalization)(img)
        
        if len(bboxes):
            bboxes = self.normalize_bbox(bboxes, size)

        return img, bboxes


    def __len__(self):
        return len(self.data)

def decode_line(yolo_line):
    yolo_line = yolo_line.split()
    yolo_line[1:] = [s.split(',') for s in yolo_line[1:]]
    return yolo_line


if __name__ == "__main__":

    data_dict = "./mosaic_trainval_600.txt"
    with open(data_dict) as tf:
        yolo_lines = tf.readlines()

    np.random.shuffle(yolo_lines)    

    train_data= YOLODataset(
        yolo_lines,
        output_size = (1024, 1024),
        batch_size = 4,
        # normalization:list = None,
        resize_mode = 'stretch',
        grayscale = False,
        
        random_flip = None,
        color_jitter = None,
        mosaic = True
    )
   
    # this DataLoader is how the dataset would be consumed during training;
    # the demo below calls get_mosaic directly instead
    train_loader = DataLoader(
        train_data,
        batch_size = 4,
        shuffle = True,
        collate_fn = collator,
        num_workers = 0
    )

    lines = random.sample(yolo_lines,4)  # get lines
    annos_list = list(map(decode_line, lines))  # decode lines
    # collect the four images and their boxes
    img_list = []
    bbox_list = []
    for anno in annos_list:
        img_list.append(Image.open(anno[0]).convert('RGB'))  # img
        bbox_list.append([list(map(int,map(float,b))) for b in anno[1:]])
    
    
    image_data, box_data = train_data.get_mosaic(img_list, bbox_list, rm_thres=16)
    # PIL gives RGB, but cv2.imwrite expects BGR, so convert before writing
    mosaic_img = cv2.cvtColor(np.array(image_data), cv2.COLOR_RGB2BGR)
    for box in box_data:
        box = [int(b) for b in box]
        cv2.rectangle(mosaic_img, (box[0], box[1]), (box[2], box[3]), color=(0, 255, 255), thickness=1)

    cv2.imwrite('output_mosaic.jpg', mosaic_img)
    

The dataset format is a plain-text file with one line per image: the image path followed by space-separated boxes, each written as xmin,ymin,xmax,ymax,label.
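Since the original screenshot of the annotation file is not reproduced here, an illustrative line (the path and numbers are made up) would be:

./dataset/wind_turbine/Images/0001.jpg 48,240,195,371,0 300,120,420,260,0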
Result: the whole output canvas is filled by the scaled or cropped image tiles.
[image]

The problem:

All three methods share the same issue, and I have not pinned down the cause: when the input size is 1024x1024, the generated mosaic image and its labels do not line up, while other input sizes produce the results shown above:
[image]
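One sanity check worth running (a sketch, written for method 1's normalized-coordinate output) is to print each remapped box next to its pixel coordinates at the target size and flag anything out of range; this at least shows whether the mismatch comes from the coordinate remapping or from the drawing step:

def check_annos(new_annos, output_size=(1024, 1024)):
    # new_annos: [class, xmin, ymin, xmax, ymax] with coordinates normalized to [0, 1]
    h, w = output_size
    for anno in new_annos:
        xmin, ymin, xmax, ymax = anno[1:5]
        px = (int(xmin * w), int(ymin * h), int(xmax * w), int(ymax * h))
        ok = 0 <= xmin < xmax <= 1 and 0 <= ymin < ymax <= 1
        print(anno[0], anno[1:5], '->', px, 'OK' if ok else 'OUT OF RANGE')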
