模拟光照数据增强,高效版本

1.前言

在训练神经网络模型时,经常会遇到因为光照、车灯、各种环境灯的影响导致目标漏检,而收集、标定相应的素材非常困难。使用数据增强的手段扩增训练集是一种简单、有效的方法。
参考链接

2. 参考代码

缺点:通过两层for循环遍历计算的方法,会随着图片尺寸的增加而导致耗时骤增,从而导致训练速度低效

#coding:utf-8
import cv2
import math
import numpy as np
import matplotlib.pyplot as plt
#读取原始图像
img = cv2.imread('000100000910000000098.jpg')
 
#获取图像行和列
rows, cols = img.shape[:2]
 
#设置中心点
centerX = rows / 2
centerY = cols / 2
print (centerX, centerY)
radius = min(centerX, centerY)
print (radius)
 
#设置光照强度
strength = 200
 
#图像光照特效
for i in range(rows):
    for j in range(cols):
        #计算当前点到光照中心距离(平面坐标系中两点之间的距离)
        distance = math.pow((centerY-j), 2) + math.pow((centerX-i), 2)
        #获取原始图像
        B =  img[i,j][0]
        G =  img[i,j][1]
        R = img[i,j][2]
        if (distance < radius*radius):
            #按照距离大小计算增强的光照值
            result = (int)(strength*( 1.0 - math.sqrt(distance) / radius ))
            B = img[i,j][0] + result
            G = img[i,j][1] + result
            R = img[i,j][2] + result
            #判断边界 防止越界
            B = min(255, max(0, B))
            G = min(255, max(0, G))
            R = min(255, max(0, R))
            img[i,j] = np.uint8((B, G, R))
        else:
            img[i,j] = np.uint8((B, G, R))
        
#显示图像
cv2.imwrite('test.jpg', img)
plt.imshow(img)
plt.show()

3.代码改进

改进思路是使用opencv实现并行化,提高代码执行速度。
该代码还有改进的空间,但对速度影响不大。

def light_enhance_improve(img, strength):
    rows, cols = img.shape[:2]
    centerX = rows / 2
    centerY = cols / 2
    radius = min(centerX, centerY)
    arr_rows = np.arange(rows).reshape(rows, 1)
    arr_cols = np.arange(cols).reshape(1, cols)
    arr = 1 - np.sqrt(((arr_cols - centerY) ** 2) + ((arr_rows - centerX) ** 2)) / radius
    arr_B = np.maximum(np.int16(strength[0] * arr), 0)
    arr_G = np.maximum(np.int16(strength[1] * arr), 0)
    arr_R = np.maximum(np.int16(strength[2] * arr), 0)
    img = np.int16(img)
    img[:,:,0]+=arr_B
    img[:,:,1]+=arr_G
    img[:,:,2]+=arr_R
    img = np.maximum(img, 0)
    img = np.minimum(img, 255)
    img = np.uint8(img)
    return img

4.整体代码

import numpy as np
import cv2
import time
import math


def get_time(f):
    def inner(*args, **kwargs):
        s_time = time.time()
        res = f(*args, **kwargs)
        e_time = time.time()
        print()
        print('{} function used time: {} (second)'.format(f.__name__, e_time - s_time))
        return res
    return inner


@get_time
def light_enhance_origin(img, strength):
    rows, cols = img.shape[:2]
    # 设置中心点
    centerX = rows / 2
    centerY = cols / 2
    radius = min(centerX, centerY)
    for i in range(rows):
        for j in range(cols):
            # 计算当前点到光照中心距离(平面坐标系中两点之间的距离)
            distance = math.pow((centerY - j), 2) + math.pow((centerX - i), 2)
            # 获取原始图像
            B = img[i, j][0]
            G = img[i, j][1]
            R = img[i, j][2]
            if (distance < radius * radius):
                # 按照距离大小计算增强的光照值
                result1= (int)(strength[0] * (1.0 - math.sqrt(distance) / radius))
                result2 = (int)(strength[1] * (1.0 - math.sqrt(distance) / radius))
                result3 = (int)(strength[2] * (1.0 - math.sqrt(distance) / radius))
                B = img[i, j][0] + result1
                G = img[i, j][1] + result2
                R = img[i, j][2] + result3
                # 判断边界 防止越界
                B = min(255, max(0, B))
                G = min(255, max(0, G))
                R = min(255, max(0, R))
                img[i, j] = np.uint8((B, G, R))
            else:
                img[i, j] = np.uint8((B, G, R))
    return img


@get_time
def light_enhance_improve(img, strength):
    rows, cols = img.shape[:2]
    centerX = rows / 2
    centerY = cols / 2
    radius = min(centerX, centerY)
    arr_rows = np.arange(rows).reshape(rows, 1)
    arr_cols = np.arange(cols).reshape(1, cols)
    arr = 1 - np.sqrt(((arr_cols - centerY) ** 2) + ((arr_rows - centerX) ** 2)) / radius
    arr_B = np.maximum(np.int16(strength[0] * arr), 0)
    arr_G = np.maximum(np.int16(strength[1] * arr), 0)
    arr_R = np.maximum(np.int16(strength[2] * arr), 0)
    img = np.int16(img)
    img[:,:,0]+=arr_B
    img[:,:,1]+=arr_G
    img[:,:,2]+=arr_R
    img = np.maximum(img, 0)
    img = np.minimum(img, 255)
    img = np.uint8(img)
    return img


if __name__ == '__main__':
    img_path = r'./2007_002094.jpg'
    img_ori = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_COLOR)
    cv2.imshow('img_ori', img_ori)
    cv2.waitKey(1)

    img_tmp = img_ori.copy()

    h,w,c = img_ori.shape
    cx = w //2
    cy = h // 2
    cw = w
    ch = h

    value_B = np.random.randint(255, 510)
    value_G = np.random.randint(255, 510)
    value_R = np.random.randint(255, 510)



    img_ori[cy - ch // 2:cy + ch // 2, cx - cw // 2:cx + cw // 2, :] = light_enhance_origin(
        img_ori[cy - ch // 2:cy + ch // 2, cx - cw // 2:cx + cw // 2, :], strength=[value_B, value_G, value_R])
    cv2.imshow('light_enhance_origin', img_ori)
    cv2.waitKey(1)

    img_tmp[cy - ch // 2:cy + ch // 2, cx - cw // 2:cx + cw // 2, :] = light_enhance_improve(
        img_tmp[cy - ch // 2:cy + ch // 2, cx - cw // 2:cx + cw // 2, :], strength=[value_B, value_G, value_R])
    cv2.imshow('light_enhance_improve', img_tmp)
    cv2.waitKey()

5.效果展示

两个代码在最终效果上一致,在耗时方面提高100多倍。
在这里插入图片描述

6.总结

改进点1:并行化,提高执行速度。
改进点2:改进后的代码,会随机生成各种颜色的模拟灯光,不再是单调的白色。

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值