YOLO 数据增强 多尺度训练(将原来较大数据拆分为多个小的数据进行训练),实测效果较为显著

73 篇文章 0 订阅
43 篇文章 0 订阅
import os
import cv2
from tqdm import tqdm


def get_imgs_pos(img_w, img_h, cut_w, cut_h, w_stride, h_stride):
    imgs_pos = []
    for beg_w in range(0, img_w, w_stride):
        for beg_h in range(0, img_h, h_stride):
            x0, y0 = beg_w, beg_h  # 左上角的点
            x1, y1 = beg_w + cut_w, beg_h + cut_h
            if x1 > img_w:  # x轴上超出图像边界
                x1 = img_w
                x0 = img_w - cut_w
            if y1 > img_h: # y轴上超出图像边界
                y1 = img_h
                y0 = img_h - cut_h
            imgs_pos.append([x0, y0, x1, y1])
            if y1 == img_h:  # 如果超出边界
                break
    return imgs_pos


def save_subimg(cv_img, pos, img_save_dir, img_name, idx):
    x0, y0, x1, y1 = pos
    crop_img = cv_img[y0:y1, x0:x1]
    cv2.imwrite(os.path.join(img_save_dir, img_name[0:-4] + "_" + "{:04d}".format(idx) + ".jpg"), crop_img)


def save_sublabs(sub_labels, label_save_dir, img_name, idx):
    lab_path = os.path.join(label_save_dir, img_name[0:-4] + "_" + "{:04d}".format(idx) + ".txt")
    with open(lab_path, 'w') as fw:
        for lab in sub_labels:
            line = " ".join(str(num) for num in lab)
            fw.write(line + "\n")



def read_labels(txt_path):
    pos = []
    with open(txt_path, 'r') as file_to_read:
        while True:
            lines = file_to_read.readline()  # 整行读取数据
            if not lines:
                break
                pass
            p_tmp = [float(i) for i in lines.split(' ')] 
            pos.append(p_tmp)  # 添加新读取的数据
            pass
    return pos


def get_sublabels(pos, labels, img_w, img_h, cut_w, cut_h):
    x0, y0, x1, y1 = pos  # 得到该子图在大图上的位置,左上角和右下角的坐标
    sub_labs = []
    for lab in labels:
        cx, cy, w, h = lab[1] * img_w, lab[2] * img_h, lab[3] * img_w, lab[4] * img_h  # 换算得到真实的中心点及宽高,注意第一个是标签的类别
        if x0 < cx < x1 and y0 < cy < y1:  # 如果该标签的中心点落到了子图上
            # 如果当前的标签落到了子图像的边界上, 处理该标签在子图上的宽的问题
            if cx - x0 < w / 2: 
                w = w / 2 + (cx - x0)
            if x1 - cx < w / 2:
                w = w / 2 + (x1 - cx)
            # 如果当前的标签落到了子图像的边界上, 处理该标签在子图上的高的问题
            if cy - y0 < h / 2:
                h = h / 2 + (cy - y0)
            if y1 - cy < h / 2:
                h = h / 2 + (y1 - cy)
            cx, cy = cx - x0, cy - y0  #  将当前的坐标换算到子图上(宽高不变,只是中心点的位置发生了改变)
            sub_labs.append([int(lab[0]), cx / cut_w, cy / cut_h, w / cut_w, h / cut_h])  #重新归一化  
    return sub_labs

if __name__ == '__main__':
    img_dir = "/data/jjg/codes/datasets/poppy/images/train/"
    img_list = os.listdir(img_dir)
    img_save_dir = "/data/jjg/codes/datasets/poppy-ms/images/train/"
    label_dir = "/data/jjg/codes/datasets/poppy/labels/train/"
    label_save_dir = "/data/jjg/codes/datasets/poppy-ms/labels/train/"
    cut_w = 1280
    cut_h = 720
    w_stride = 1000  # 注意这里的设置可能存在点技巧,多根据自己的数据进行修改和尝试
    h_stride = 500  # 注意这里的设置可能存在点技巧,多根据自己的数据进行修改和尝试
    count = 0
    for img_name in tqdm(img_list):
        if img_name.endswith((".jpg", ".JPG")):
            img_path = os.path.join(img_dir, img_name)
            cv_img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            img_w, img_h = cv_img.shape[1], cv_img.shape[0]
            if img_w > cut_w and img_h > cut_h: # 如果原图的大小是大于需要裁剪的图像大小
                imgs_pos = get_imgs_pos(img_w, img_h, cut_w, cut_h, w_stride, h_stride)
                if len(imgs_pos):  # 如果原图像被拆分为了多个子图像
                    labels = read_labels(os.path.join(label_dir, img_name[0:-4] + ".txt"))
                    # print(labels)
                    for idx, pos in enumerate(imgs_pos):  # 逐个对所有子图像寻找其图上的子lables
                        sub_labels = get_sublabels(pos, labels, img_w, img_h, cut_w, cut_h)  # 找到当前子图对应的所有lables
                        if len(sub_labels):  # 如果该子图像中存在目标标签
                            count += 1
                            save_subimg(cv_img, pos, img_save_dir, img_name, idx)  # 保存该子图像
                            save_sublabs(sub_labels, label_save_dir, img_name, idx)  # 保存该子图像的标签
    print("has generated", count, "images.")

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值