基于Copy-paste的图像P图示例

原理:将待P图的所有目标抠取下来(目标可进行一定的数据增强:翻转、旋转、对比度变化等)进行一定比例缩放,粘贴到待P原图设置的区间范围内。

1、P图步骤

(1)搜集待P的物体目标;
(2)利用X-Anylabelimg等标注软件对目标进行标注,保存一个json文件格式的掩码信息;
(3)设置参数,包括:原图、待P目标、待P目标掩码信息、保存位置、目标缩放区间、目标贴图区间、生成个数。
注意:缩放系数与待P目标图分辨率有关,不同分辨率缩放系数不一样。

# pip install imgaug==0.4.0
# pip install opencv-pytho==4.8.1.78
# pip install numpy ==1.22.4

import argparse
import os
import numpy as np
import cv2
import json
import imgaug.augmenters as iaa


# 图间隔采样缩放,比resize好
def zoom_out_equal_interval(img, new_h, new_w, rescale_ratio=0.3, PS=False, x=0, y=0):
    '''
    图像缩小—等间隔采样
    :param img: 原图
    :param new_h, new_w: 缩放后的大小
    :param rescale_ratio: 行和列缩放比例
    :param PS: 是否进行位置移动
    :param x, y: 移动的位置
    :return:
    '''
    img_h, img_w, channels = img.shape  # 获取图像大小
    # print(new_h,new_w)
    zoom_img = np.zeros((new_h, new_w, channels), dtype=np.uint8)  # 创建缩放后的图片大小的矩阵
    sampling_interval_h = rescale_ratio
    sampling_interval_w = rescale_ratio
    if PS:
        for i in range(img_h):
            for j in range(img_w):
                h = min(new_h-1, int(i * sampling_interval_h - 0.5) + y)  # 按比例放大或者缩小
                w = min(new_w-1, int(j * sampling_interval_w - 0.5) + x)
                zoom_img[h][w] = img[i][j]
    else:
        for i in range(img_h):
            for j in range(img_w):
                h = min(new_h - 1, int(i * sampling_interval_h - 0.5))  # 按比例放大或者缩小
                w = min(new_w - 1, int(j * sampling_interval_w - 0.5))
                zoom_img[h][w] = img[i][j]
    return zoom_img


# 随机水平翻转+旋转+对比度
def random_flip_horizontal(img, mask_imgs=None, p=0.5):
    '''
    随机水平翻转,并处理bbox和seg标签,翻转概率默认为0.5
    img: 原始的完整图片
    mask_img: 当前图片中anns截取下来的图片的List
    '''
    # 一定概率
    if np.random.random() < p:
        new_mask_imgs = []
        
        # 原图
        if mask_imgs == None:
            img = img[:, ::-1, :]
            return img, mask_imgs
        
        # 掩码
        else:
            for i, mask_img in enumerate(mask_imgs):
                # 翻转
                mask_img = mask_img[:, ::-1, :]

                # 旋转
                rot1 = iaa.Affine(rotate=(-60, 60))
                mask_img = rot1.augment_image(mask_img)

                # 对比度
                contrast = iaa.GammaContrast((0.5, 2.0))
                mask_img = contrast.augment_image(mask_img)

                new_mask_imgs.append(mask_img)
            return img, new_mask_imgs
    else:
        return img, mask_imgs


def LSJ(img, mask_imgs, rescale_ratio):
    '''
    Large-Scaling-Jittering
    img: 原始的完整图片
    mask_img: 当前图片中anns截取下来的图片的List
    '''
    h, w, _ = img.shape
    # 得到新尺寸
    h_new, w_new = int(h * rescale_ratio), int(w * rescale_ratio)

    new_mask_imgs = []
    for i, mask_img in enumerate(mask_imgs):
        if rescale_ratio <= 1.0:
            mask_img = zoom_out_equal_interval(mask_img, h_new, w_new, rescale_ratio)  # 缩放++++++++++++
        else:
            mask_img = cv2.resize(mask_img, (w_new, h_new), interpolation=cv2.INTER_NEAREST)
        new_mask_imgs.append(mask_img)

    return new_mask_imgs


def add_mask_to_img(main_img, mask_imgs, x, y):
    '''
    将新的mask添加当img上
    输入:
    1.main_img: 在此图片上添加mask图片
    2.mask_imgs: list,其中每个元素为背景为黑色(0),只截取了目标部分的图像
    '''
    # 获取main_img尺寸
    if len(main_img.shape) == 3:
        h, w, c = main_img.shape
    else:
        h, w = main_img.shape
    # ---------------------------------------------#
    #   扣掉main_img对应位置的图像,然后贴上mask的图片
    # ---------------------------------------------#

    for i, mask_img in enumerate(mask_imgs):
        tmp = 0
        if i >= 1:
            tmp = int(np.random.uniform(0, 100))
        mask_img_pad = zoom_out_equal_interval(mask_img, h, w, rescale_ratio=1, PS=True, x=x + tmp, y=y + tmp)  # 缩放++++++++++++

        mask = cv2.cvtColor(mask_img_pad, cv2.COLOR_BGR2GRAY)  # 获得单通道灰度mask_img
        ret, mask = cv2.threshold(mask, 0.0000000001, 255, cv2.THRESH_BINARY)  # 二值化处理
        mask_inv = cv2.bitwise_not(mask)  # 非运算,mask取反,用于扣掉图片
        main_img = cv2.bitwise_and(main_img, main_img, mask=mask_inv)  # 删除了img中的mask_inv区域
        main_img = cv2.add(main_img, mask_img_pad)

    return main_img


def copy_paste(main_img, src_img, src_mask_imgs, rescale_ratio, x, y):
    '''
    Copy-Paste的主体实现,主要流程如下:
    1.random_flip
    2.Large Scale Jittering

    输入:1.main_img, src_img : 图片
         2.src_mask_imgs:list,存储各个依靠seg截取下来的目标的图像,各自都是独立的
    输出:copy_paste_img
    '''
    # random flip
    main_img, main_mask_imgs = random_flip_horizontal(main_img)
    src_img, src_mask_imgs = random_flip_horizontal(src_img, src_mask_imgs)

    # LSJ src大小缩放
    src_mask_imgs = LSJ(src_img, src_mask_imgs, rescale_ratio)

    # 将src_mask粘贴到main_img上
    main_img = add_mask_to_img(main_img, src_mask_imgs, x, y)

    return main_img


def main(args):
    '''
    Copy-Paste实现流程:
    1.读取数据
    2.随机选取src/main img以及src的annotation
    3.Copy-Paste
    4.后续处理图片、生成Json文件
    '''
    # 创建保存路径
    if not os.path.exists(args.output):
        os.makedirs(args.output)

    # --------------------------------------#
    #               1.读取数据
    # --------------------------------------#
    # 读取main的图片
    main_img = cv2.imread(args.main_path)

    # 读取src_img的
    src_img = cv2.imread(args.src_path)
    

    # 获取src中随机选取得到的目标截取下来
    # read json file
    with open(args.src_json, "r") as f:
        data = f.read()

    # convert str to json objs
    data = json.loads(data)
    # --------------------------------------#
    #               2.截取seg下来
    # --------------------------------------#
    src_mask_imgs = []
    # get the points
    for d in data["shapes"]:
        src_img_copy = src_img.copy()
        points = d["points"]
        if len(points) == 0:
            continue
        points = np.array(points, dtype=np.int32)

        # create a blank image
        mask = np.zeros_like(src_img_copy, dtype=np.uint8)

        # fill the contour with 255
        cv2.fillPoly(mask, [points], (255, 255, 255))
        region = mask == 0
        src_img_copy[region] = 0
        mask_img = src_img_copy  # 截取seg下来

        # cv2.imwrite('out0.jpg', mask_img)  # 保存待分割的目标
        src_mask_imgs.append(mask_img)
        
    # --------------------------------------#
    #               3.Copy-Paste
    # --------------------------------------#
    if len(src_mask_imgs) > 0:
        for i in range(int(max(1, args.nums))):
            print("正在生成第 {} 张图...".format(i+1))
            rescale_ratio = round(np.random.uniform(args.min_scale, args.max_scale), 2)
            x = int(np.random.uniform(args.x_scale[0], args.x_scale[1]))
            y = int(np.random.uniform(args.y_scale[0], args.y_scale[1]))
            print('贴图位置x:{}, y:{}, 缩放比率:{}'.format(x, y, rescale_ratio))
            copy_paste_img = copy_paste(main_img, src_img, src_mask_imgs, rescale_ratio, x, y)
            # print(args.main_path.split('/')[-1][0:-4])
            cv2.imwrite(os.path.join(args.output, args.main_path.split('/')[-1][:-4] + '_{}'.format(i+1) + '.jpg'), copy_paste_img)
        print("结束生成!!!")
    else:
        print("没有待P目标!!!") 


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--main_path", default="input/ori_img/test.jpg", type=str,
                        help="需要PS的img路径")
    parser.add_argument("--src_path", default="input/ps_img/p_002.jpg", type=str,
                        help="待贴目标的img路径")
    parser.add_argument("--src_json", default="input/ps_json/p_002.json", type=str,
                        help="待贴目标的json路径")
    parser.add_argument("--output", default="output/", type=str,
                        help="待保存的img路径")

    parser.add_argument("--min_scale", default=0.15, type=float,
                        help="随机缩放区间的最小倍数")
    parser.add_argument("--max_scale", default=0.15, type=float,
                        help="随机缩放区间的最大倍数")
    parser.add_argument("--x_scale", default=[500, 900],
                        help="随机所贴的x位置区间")
    parser.add_argument("--y_scale", default=[500, 500],
                        help="随机所贴的y位置区间")

    parser.add_argument("--nums", default=1, type=int,
                        help="随机生成的img数量")

    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    main(args)

2、结果

P图结果如下
在这里插入图片描述
:图中P的是只狗,P图前需要一个掩码标注文件,格式为json。

  • 11
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值