[Deep Learning][Original] Making YOLOv6-0.1.0 support YOLOv5's txt-list dataset format

Meituan's new YOLOv6 framework looks quite promising, but since it was released only recently, a lot is still rough around the edges. While training on my own dataset today I found that the framework only accepts datasets laid out like this:

custom_dataset
├── images
│   ├── train
│   │   ├── train0.jpg
│   │   └── train1.jpg
│   ├── val
│   │   ├── val0.jpg
│   │   └── val1.jpg
│   └── test
│       ├── test0.jpg
│       └── test1.jpg
└── labels
    ├── train
    │   ├── train0.txt
    │   └── train1.txt
    ├── val
    │   ├── val0.txt
    │   └── val1.txt
    └── test
        ├── test0.txt
        └── test1.txt

I prefer the YOLOv5 layout instead (of course, YOLOv5 also supports the layout above):

images
├── 1.jpg
├── 2.jpg
└── ......

labels
├── 1.txt
├── 2.txt
└── .......
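For reference, each label .txt follows the standard YOLO format: one object per line, with the class index followed by the normalized center x, center y, width, and height (all in the 0-1 range). With the class names used in the config below, 1.txt might look like this (illustrative values):

0 0.512 0.430 0.340 0.510
1 0.220 0.650 0.120 0.180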

Then put the train/val split lists into txt files:

train.txt:
/home/fut/data/images/1.jpg
/home/fut/data/images/2.jpg
....

val.txt:
/home/fut/data/images/6.jpg
/home/fut/data/images/7.jpg
....
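If you don't have these lists yet, a minimal sketch like the following can generate them. The image directory /home/fut/data/images, the 0.9 train ratio, and the output file names are assumptions; adjust them to your data:

import glob
import os
import random

IMG_DIR = "/home/fut/data/images"   # assumption: your flat image folder
IMG_FORMATS = ("jpg", "jpeg", "png", "bmp")
TRAIN_RATIO = 0.9                   # assumption: 90/10 train/val split

paths = sorted(
    p for p in glob.glob(os.path.join(IMG_DIR, "*"))
    if p.rsplit(".", 1)[-1].lower() in IMG_FORMATS
)
random.seed(0)  # fixed seed so the split is reproducible
random.shuffle(paths)
n_train = int(len(paths) * TRAIN_RATIO)

# one absolute image path per line, exactly what datasets.py expects
with open("train.txt", "w") as f:
    f.write("\n".join(paths[:n_train]) + "\n")
with open("val.txt", "w") as f:
    f.write("\n".join(paths[n_train:]) + "\n")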

Then point the data config file at those lists:

train: myproj/config/train.txt
val: myproj/config/val.txt
nc: 2
# whether it is coco dataset, only coco dataset should be set to True.
is_coco: False

# class names
names: ['dog','cat']
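Before launching training, you can sanity-check this yaml yourself (PyYAML is already a YOLOv6 dependency; the data.yaml file name here is an assumption):

import os
import yaml

with open("myproj/config/data.yaml") as f:  # hypothetical config file name
    data = yaml.safe_load(f)

# the split lists must exist, and nc must match the number of class names
assert os.path.isfile(data["train"]), data["train"]
assert os.path.isfile(data["val"]), data["val"]
assert data["nc"] == len(data["names"])
print(data)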

This way there is no need to re-split the data into four folders every time. Without further ado, let's modify the code: open YOLOv6-0.1.0/yolov6/data/datasets.py and change the loading logic of the get_imgs_labels(self, img_dir) function. The complete modified function is below:
    def get_imgs_labels(self, img_dir):
        NUM_THREADS = min(8, os.cpu_count())
        if os.path.isdir(img_dir):
            valid_img_record = osp.join(
                osp.dirname(img_dir), "." + osp.basename(img_dir) + ".json"
            )


            img_paths = glob.glob(osp.join(img_dir, "*"), recursive=True)
            img_paths = sorted(
                p for p in img_paths if p.split(".")[-1].lower() in IMG_FORMATS
            )
            assert img_paths, f"No images found in {img_dir}."
        else:
            # txt mode: img_dir is a txt file listing one image path per line
            with open(img_dir, "r") as f:
                img_paths = sorted(p for p in f.read().splitlines() if p)
            assert img_paths, f"No image paths found in {img_dir}."
            # cache sits next to the txt, e.g. train.txt -> .train.json
            valid_img_record = osp.join(osp.dirname(img_dir), "." + osp.splitext(osp.basename(img_dir))[0] + ".json")
        img_hash = self.get_hash(img_paths)
        if osp.exists(valid_img_record):
            with open(valid_img_record, "r") as f:
                cache_info = json.load(f)
                if "image_hash" in cache_info and cache_info["image_hash"] == img_hash:
                    img_info = cache_info["information"]
                else:
                    self.check_images = True
        else:
            self.check_images = True

        # check images
        if self.check_images and self.main_process:
            img_info = {}
            nc, msgs = 0, []  # number corrupt, messages
            LOGGER.info(
                f"{self.task}: Checking formats of images with {NUM_THREADS} process(es): "
            )
            with Pool(NUM_THREADS) as pool:
                pbar = tqdm(
                    pool.imap(TrainValDataset.check_image, img_paths),
                    total=len(img_paths),
                )
                for img_path, shape_per_img, nc_per_img, msg in pbar:
                    if nc_per_img == 0:  # not corrupted
                        img_info[img_path] = {"shape": shape_per_img}
                    nc += nc_per_img
                    if msg:
                        msgs.append(msg)
                    pbar.desc = f"{nc} image(s) corrupted"
            pbar.close()
            if msgs:
                LOGGER.info("\n".join(msgs))

            cache_info = {"information": img_info, "image_hash": img_hash}
            # save valid image paths.
            with open(valid_img_record, "w") as f:
                json.dump(cache_info, f)
       
        # # check and load anns
        # label_dir = osp.join(
        #     osp.dirname(osp.dirname(img_dir)), "coco", osp.basename(img_dir)
        # )
        # assert osp.exists(label_dir), f"{label_dir} is an invalid directory path!"

        img_paths = list(img_info.keys())
        # NOTE: assumes each image sits under an "images" directory, with a label
        # of the same basename under a sibling "labels" directory
        label_dir = osp.dirname(img_paths[0]).replace("images", "labels")
        # build label_paths in the same order as img_paths; re-sorting here could
        # pair images with the wrong labels in the zip() below
        label_paths = [
            osp.join(label_dir, osp.splitext(osp.basename(p))[0] + ".txt")
            for p in img_paths
        ]
        label_hash = self.get_hash(label_paths)
        if "label_hash" not in cache_info or cache_info["label_hash"] != label_hash:
            self.check_labels = True

        if self.check_labels:
            cache_info["label_hash"] = label_hash
            nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt; messages
            LOGGER.info(
                f"{self.task}: Checking formats of labels with {NUM_THREADS} process(es): "
            )
            with Pool(NUM_THREADS) as pool:
                pbar = pool.imap(
                    TrainValDataset.check_label_files, zip(img_paths, label_paths)
                )
                pbar = tqdm(pbar, total=len(label_paths)) if self.main_process else pbar
                for (
                    img_path,
                    labels_per_file,
                    nc_per_file,
                    nm_per_file,
                    nf_per_file,
                    ne_per_file,
                    msg,
                ) in pbar:
                    if nc_per_file == 0:
                        img_info[img_path]["labels"] = labels_per_file
                    else:
                        img_info.pop(img_path)
                    nc += nc_per_file
                    nm += nm_per_file
                    nf += nf_per_file
                    ne += ne_per_file
                    if msg:
                        msgs.append(msg)
                    if self.main_process:
                        pbar.desc = f"{nf} label(s) found, {nm} label(s) missing, {ne} label(s) empty, {nc} invalid label files"
            if self.main_process:
                pbar.close()
                with open(valid_img_record, "w") as f:
                    json.dump(cache_info, f)
            if msgs:
                LOGGER.info("\n".join(msgs))
            if nf == 0:
                LOGGER.warning(
                    f"WARNING: No labels found in {osp.dirname(img_paths[0])}. "
                )

        if self.task.lower() == "val":
            if self.data_dict.get("is_coco", False): # use original json file when evaluating on coco dataset.
                assert osp.exists(self.data_dict["anno_path"]), "Eval on coco dataset must provide valid path of the annotation file in config file: data/coco.yaml"
            else:
                assert (
                    self.class_names
                ), "Class names is required when converting labels to coco format for evaluating."
                save_dir = osp.join(osp.dirname(osp.dirname(img_dir)), "annotations")
                if not osp.exists(save_dir):
                    os.mkdir(save_dir)
                save_path = osp.join(
                    save_dir, "instances_" + osp.basename(img_dir) + ".json"
                )
                TrainValDataset.generate_coco_format_labels(
                    img_info, self.class_names, save_path
                )

        img_paths, labels = list(
            zip(
                *[
                    (
                        img_path,
                        np.array(info["labels"], dtype=np.float32)
                        if info["labels"]
                        else np.zeros((0, 5), dtype=np.float32),
                    )
                    for img_path, info in img_info.items()
                ]
            )
        )
        self.img_info = img_info
        LOGGER.info(
            f"{self.task}: Final numbers of valid images: {len(img_paths)}/ labels: {len(labels)}. "
        )
        return img_paths, labels
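The key change is the branch at the top of the function: if os.path.isdir(img_dir) is true the original folder mode is kept, otherwise the path is treated as a YOLOv5-style txt list. A minimal standalone sketch of that idea (the paths are hypothetical):

import os
import os.path as osp

IMG_FORMATS = ("jpg", "jpeg", "png", "bmp")

def resolve_image_paths(img_dir):
    """Return image paths from either a directory or a YOLOv5-style txt list."""
    if os.path.isdir(img_dir):
        return sorted(
            osp.join(img_dir, n) for n in os.listdir(img_dir)
            if n.rsplit(".", 1)[-1].lower() in IMG_FORMATS
        )
    # txt file: one absolute image path per line
    with open(img_dir) as f:
        return sorted(p for p in f.read().splitlines() if p)

def label_path_for(img_path):
    """Map .../images/x.jpg -> .../labels/x.txt, same convention as datasets.py."""
    label_dir = osp.dirname(img_path).replace("images", "labels")
    return osp.join(label_dir, osp.splitext(osp.basename(img_path))[0] + ".txt")

print(label_path_for("/home/fut/data/images/1.jpg"))  # -> /home/fut/data/labels/1.txt

With this change in place, train and val in the data yaml can point to either a folder or a txt list, and both modes work.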
