YOLOv8源码修改(1)- DataLoader增加负样本数据读取+平衡训练batch中的正负样本数

背景

实际工程中,仅仅使用采集到的数据(或者说单一的数据集)进行训练,会导致网络在真实未知的场景下,极其容易发生误检。因此,网络需要见识不同的场景数据才更具有鲁棒性。

YOLOv8官方代码

1.修改思路

改进点1:网络读取额外的数据集图片。

改进点2:调整每个epoch训练的图片数量。

改进点3:平衡总体数据集训练读取时,每一个batch的正负样本比例。

2.涉及修改相关的文件

模型配置文件:ultralytics/cfg/models/v8/yolov8.yaml

数据配置文件:ultralytics/cfg/datasets/*.yaml

数据集读取基类:ultralytics/data/base.py

数据集读取子类:ultralytics/data/dataset.py

3.配置文件修改

(1)模型配置文件:ultralytics/cfg/models/v8/yolov8.yaml

nc: 4        # number of classes 

仅需修改类别数。

(2)数据配置文件:ultralytics/cfg/datasets/*.yaml

# path: /path_to_root_dir/data_without_neg # dataset root dir
train: /path_to_train_data/data_without_neg/train/images # train images
val: /path_to_val_data/data_without_neg/val/images # val images
# test: test/images # test images

# 增加的额外参数用于控制负样本数据读取和比例
negative_setting:
  neg_ratio: 5    # 小于等于0时,按原始官方配置训练,大于0时,控制正负样本。
  use_extra_neg: True    # 是否使用额外数据
  extra_neg_sources: {
                          "/path_to_extra_data_1/COCO/train2014" : 20000,    
                          # "/path_to_extra_data_2/image_list.txt": 10
                      }  # 存储为字典(图片文件夹或图片列表文件),{路径:读取数量}
  fix_dataset_length: 10000  # 是否自定义每轮参与训练的图片数量


# number of classes
nc: 4

# Classes
names:
  0: ycj
  1: kx
  2: kx_dk
  3: money
  1. trainval指定训练和验证集文件。
  2. negative_setting下的参数用于控制Dataloader的负样本读取。
  3. neg_ratio的值表示一个batch里“负样本:正样本”的比值。如为5时,batchsize=16,一个batch大约3张正样本。
  4. use_extra_neg用于控制是否使用额外数据。
  5. extra_neg_source用于指定数据读取路径,后面数字表示该路径下采样的图片数量。
  6. fix_dataset_length表示训练时每一个epoch的数据集实际长度。如长度为10000,正样本读取1666张,负样本读取8334张,接近比值5。(实际会受线程数的影响,读取数量略有差异。)

4.dataloader修改

4.1 基类BaseDataset

YOLOv8的dataloader的具体实现由任务决定,但都继承自ultralytics/data/base.py下的基类BaseDataset(这部分不需要修改,修改会影响其他任务的数据读取,但是需要了解他的代码逻辑才能知道整个数据读取的实现方式)。

class BaseDataset(Dataset):
    def __init__(
        self,
        img_path,
        imgsz=640,
        cache=False,
        augment=True,
        hyp=DEFAULT_CFG,
        prefix="",
        rect=False,
        batch_size=16,
        stride=32,
        pad=0.5,
        single_cls=False,
        classes=None,
        fraction=1.0
    ):
        """Initialize BaseDataset with given configuration and options."""
        super().__init__()
        self.img_path = img_path        # 图像文件夹路径
        self.imgsz = imgsz              # 图像大小
        self.augment = augment          # 是否进行数据增强
        self.single_cls = single_cls    # 是否进行单类别训练
        self.prefix = prefix            # 输出日志信息的前缀
        self.fraction = fraction        # 数据集的使用比例
        self.im_files = self.get_img_files(self.img_path)   # 获取图像文件列表
        self.labels = self.get_labels()                     # 获取标签信息
        self.update_labels(include_class=classes)           # 更新标签信息,包括指定的类别
        self.ni = len(self.labels)                          # 数据集中图像的数量
        self.rect = rect                # 是否使用矩形训练
        self.batch_size = batch_size    # 批次大小
        self.stride = stride            # 步长
        self.pad = pad                  # 填充值
        if self.rect:                   # 如果使用矩形训练,确保批次大小不为None
            assert self.batch_size is not None
            self.set_rectangle()        # 设置矩形训练相关参数

        # Buffer thread for mosaic images
        self.buffer = []            # buffer size = batch size, 缓冲区大小为批次大小
        self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

        # Cache images (options are cache = True, False, None, "ram", "disk")
        self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
        self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
        self.cache = cache.lower() if isinstance(cache, str) else "ram" if cache is True else None
        if (self.cache == "ram" and self.check_cache_ram()) or self.cache == "disk":
            self.cache_images()     # 缓存图像到内存或磁盘

        # Transforms
        self.transforms = self.build_transforms(hyp=hyp)    # 构建图像变换函数

    def __getitem__(self, index):
        # 根据索引获取每一张图片,pytorch中继承得到重写
        return self.transforms(self.get_image_and_label(index))

    def __len__(self):
        # 获取实际数据集长度,也是pytorch中继承得到重写
        return len(self.labels)

    def build_transforms(self, hyp=None):
        # 图片转换方式,需要自己实现,不实现抛出异常
        raise NotImplementedError

    def get_labels(self):
        # 获取标签(实际是把图片+标签打包成字典方便读取),需要自己实现,不实现抛出异常
        raise NotImplementedError

根据基类的方法,知道了:

  1. 控制数据正负样本比在__getitem__中实现。
  2. 控制数据集长度在__len__中实现。
  3. 添加额外初始化信息在__init__中实现。
  4. 读取额外数据、获取数据信息在__get_labels__中实现。

4.2 子类YOLODataset

修改后的关键代码:

class YOLODataset(BaseDataset):
    def __init__(self, *args, data=None, task="detect", **kwargs):
        """Initializes the YOLODataset with optional configurations for segments and keypoints."""
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        self.data = data

        self.im_pos_index = []      # 正样本下标索引
        self.im_neg_index = []      # 负样本下标索引
        self.im_pos_num = 0         # 正样本数量
        self.im_neg_num = 0         # 负样本数量
        self.img_neg_path = ""      # 负样本路径
        self.im_neg_files = []      # 负样本文件列表

        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(*args, **kwargs)

    def __getitem__(self, index):
        """Returns transformed label information for given index."""
        # pid = os.getpid()
        if "train" in self.prefix.lower():
            if self.im_pos_num * self.data["negative_setting"]["neg_ratio"] >= self.im_neg_num:
                self.im_neg_num += 1
                index = random.choice(self.im_neg_index)
                # print(f"选择负样本,当前index为: {index}")
            else:
                self.im_pos_num += 1
                index = random.choice(self.im_pos_index)
                # print(f"选择正样本,当前index为: {index}")
        # print(f"当前线程:[{pid}],当前index为: {index},已经选择正样本数:{self.im_pos_num},负样本数{self.im_neg_num}")
        return self.transforms(self.get_image_and_label(index))

    def __len__(self):
        try:
            if "train" in self.prefix.lower() and self.data["negative_setting"]["fix_dataset_length"] > 0:
                return int(self.data["negative_setting"]["fix_dataset_length"])
        except (ValueError, KeyError, AttributeError) as e:
            print(f"INFO: 设置每个epoch长度失败,使用原始数据集长度。发生的错误为:{e}")

        return len(self.labels)

    def get_labels(self):
        """Returns dictionary of labels for YOLO training."""
        try:
            # 额外增加负样本,仅训练时增加
            if "train" in self.prefix.lower() and self.data["negative_setting"]["use_extra_neg"]:
                self.img_neg_path = self.data["negative_setting"]["extra_neg_sources"]         # 负样本文件列表
                for imp, imn in self.img_neg_path.items():
                    imp_neg_file = self.get_img_files(imp)      # 一个文件夹下的有效图片列表
                    imn_real = min(len(imp_neg_file), imn)      # 实际应该读取的文件数

                    print(f'INFO: 额外增加的负样本:[{imp}], 有[{len(imp_neg_file)}]张图片,应该抽取[{imn}]张图片,'
                          f'实际随机抽取[{imn_real}]张图片。')
                    imp_neg_file = random.sample(imp_neg_file, imn_real)    # 一个文件夹下实际采样的图片列表
                    self.im_neg_files += imp_neg_file

                print(f"INFO: 总共实际获取的负样本有:[{len(self.im_neg_files)}]张图片。")

        except (ValueError, KeyError, AttributeError) as e:
            print(f"INFO: 读取额外负样本数据失败,不增加负样本。配置文件中[negative_setting]存在错误:{e}")
            print(f"INFO: 总共实际获取的负样本有:[{len(self.im_neg_files)}]张图片。")

        self.im_files += self.im_neg_files
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError):
            cache, exists = self.cache_labels(cache_path), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            print(f"")
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache
        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
        labels = cache["labels"]
        if not labels:
            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments
        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")

        # 仅在train模式下,增加正负样本索引
        if "train" in self.prefix.lower():
            for i, label in enumerate(labels):
                if len(label['cls']) == 0:
                    self.im_neg_index.append(i)
                else:
                    self.im_pos_index.append(i)

        return labels

4.2.1 __init__修改

"""
这部分参数一定要加在super之前,因为BaseDataset的__init__中,
会调用self.labels = self.get_labels(),而所增加的额外信息,
在该方法中会使用到。

"""
self.im_pos_index = []      # 正样本下标索引
self.im_neg_index = []      # 负样本下标索引
self.im_pos_num = 0         # 正样本数量
self.im_neg_num = 0         # 负样本数量
self.img_neg_path = ""      # 负样本路径
self.im_neg_files = []      # 负样本文件列表


super().__init__(*args, **kwargs)
"""
继承时触发:
self.labels = self.get_labels()    # 获取标签信息
"""

4.2.2 get_labels修改

"""
注意只有在训练时才需要增加额外负样本,通过self.prefix(用于命令行输出信息的字符串),可以知道当前是train/val/test。
再通过基类方法self.get_img_files读取;根据设定的图片数量,用random.sample采样指定数量图片。
"""
try:
    # 额外增加负样本,仅训练时增加
    if "train" in self.prefix.lower() and self.data["negative_setting"]["use_extra_neg"]:
        self.img_neg_path = self.data["negative_setting"]["extra_neg_sources"]         # 负样本文件列表
        for imp, imn in self.img_neg_path.items():
            imp_neg_file = self.get_img_files(imp)      # 一个文件夹下的有效图片列表
            imn_real = min(len(imp_neg_file), imn)      # 实际应该读取的文件数

            print(f'INFO: 额外增加的负样本:[{imp}], 有[{len(imp_neg_file)}]张图片,应该抽取[{imn}]张图片,'
                  f'实际随机抽取[{imn_real}]张图片。')
            imp_neg_file = random.sample(imp_neg_file, imn_real)    # 一个文件夹下实际采样的图片列表
            self.im_neg_files += imp_neg_file

        print(f"INFO: 总共实际获取的负样本有:[{len(self.im_neg_files)}]张图片。")

except (ValueError, KeyError, AttributeError) as e:
    print(f"INFO: 读取额外负样本数据失败,不增加负样本。配置文件中[negative_setting]存在错误:{e}")
    print(f"INFO: 总共实际获取的负样本有:[{len(self.im_neg_files)}]张图片。")

"""
该索引用于__getitem__选择正样本还是负样本。
"""
# 仅在train模式下,增加正负样本索引
if "train" in self.prefix.lower():
    for i, label in enumerate(labels):
        if len(label['cls']) == 0:
            self.im_neg_index.append(i)
        else:
            self.im_pos_index.append(i)

4.2.3 __getitem__修改

def __getitem__(self, index):
"""Returns transformed label information for given index."""
"""
只有train时,才控制正负样本,维护正负样本数量self.im_pos_num和self.im_neg_num的比例关系。
注意:用乘法不要用除法,除法还要处理分母为0等问题。
多线程下,还要注意random.choice的伪随机,是否导致不同线程采样一样的问题。
"""
# pid = os.getpid()
if "train" in self.prefix.lower():
    if self.im_pos_num * self.data["negative_setting"]["neg_ratio"] >= self.im_neg_num:
        self.im_neg_num += 1
        index = random.choice(self.im_neg_index)
        # print(f"选择负样本,当前index为: {index}")
    else:
        self.im_pos_num += 1
        index = random.choice(self.im_pos_index)
        # print(f"选择正样本,当前index为: {index}")
# print(f"当前线程:[{pid}],当前index为: {index},已经选择正样本数:{self.im_pos_num},负样本数{self.im_neg_num}")
return self.transforms(self.get_image_and_label(index))

查看多线程正负样本选择情况,正负样本比为5时:

4.2.4 __len__修改

train模式下且设置了长度,返回该设置长度即可。

def __len__(self):
try:
    if "train" in self.prefix.lower() and self.data["negative_setting"]["fix_dataset_length"] > 0:
        return int(self.data["negative_setting"]["fix_dataset_length"])
except (ValueError, KeyError, AttributeError) as e:
    print(f"INFO: 设置每个epoch长度失败,使用原始数据集长度。发生的错误为:{e}")

return len(self.labels)

5.修改效果

5.1 修改前

查看train_batch.jpg,所有图片均有标签(原始数据集不含负样本,这里使用了马赛克增强,导致图片看起来混乱):

5.2 修改后

读取到的图片数量:

每个epoch实际训练使用的图片大小,设置trian的batchsize=32,fix_dataset_len=640:

额外加入COCO数据集作为负样本。

ratio值为5:

ratio值为1:

6.训练效果

使用的模型是yolov8s.pt,均使用预训练模型,训练轮次200,batchsize=32。

从P-R曲线可以看到,mAP略有提高。

从混淆矩阵可以看到,我的这个训练场景“漏检”和“误检”都有所降低,特别是“背景”误检为“money”的情况(因为一沓钱和一沓纸很容易搞混,加入负样本提高了识别准确率)。(实际情况,加入负样本一般只降低误检率,会提高漏检率。)

6.1 修改前

6.2 修改后

训练 YOLOv5 模型时,如果无法加载数据集,可能是由以下几个原因导致的: 1. 数据集路径错误:确保你在训练脚本正确指定了数据集的路径。检查数据集是否存在,并且路径是否正确。 2. 数据集格式错误:YOLOv5 要求数据集使用特定的格式,通常是 COCO 格式或类似的格式。确保你的数据集按照正确的格式进行组织,并且包含必要的标注和图像文件。 3. 数据集加载代码错误:检查训练脚本加载数据集的代码,确保没有语法错误或逻辑错误。常见的加载数据集的代码片段如下所示: ```python from torch.utils.data import DataLoader from torchvision import datasets # 加载数据集 dataset = datasets.CocoDetection(root='path/to/dataset', annFile='path/to/annotations.json', ...) dataloader = DataLoader(dataset, batch_size=..., shuffle=..., num_workers=...) ``` 请根据你的实际情况修改上述代码,确保正确加载数据集并设置适当的参数。 4. 数据集预处理错误:在加载数据集后,通常需要进行一些预处理操作,如图像缩放、归一化、数据增强等。确保你在训练脚本正确实现了这些预处理操作。 5. 数据集标注错误:如果数据集标注有问题,如格式错误、缺失标注或标注与图像不匹配等,可能会导致无法加载数据集。检查数据集标注文件的正确性,并确保每个图像都有相应的正确标注。 如果你仍然无法解决数据集加载的问题,你可以提供更多的细节和错误信息,以便我能够更具体地帮助你。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值