批量水下图像处理白平衡+融合

本文介绍了一个基于Python的类,用于处理水下图像的色彩补偿、白平衡调整、伽马校正和融合,包括简单白平衡、灰度世界法、锐化、权重计算以及多级金字塔融合方法。
摘要由CSDN通过智能技术生成

修改后的如下,不太对啊亲们,色彩补偿白平衡这块不对

from tkinter import NO
import os
import numpy as np
import cv2


# 封装成类
class Color_Balance_and_Fusion_for_Underwater_Image_Enhancement():
    def __init__(self):
        pass

    # 读图片
    def read_img_by_cv2(self, img_path):
        img = cv2.imread(img_path)
        # img=cv2.resize(img,(800,600))
        return img

    # 论文所提出的白平衡
    def simple_color_balance(self, img, alpha, blur_need=None):
        R = img[:, :, 2]
        G = img[:, :, 1]
        B = img[:, :, 0]
        # 三颜色通道均值再归一化,对应 I¯r I¯g I¯b
        Irm = np.mean(R) / 256.0
        Igm = np.mean(G) / 256.0
        Ibm = np.mean(B) / 256.0

        Irc = R + alpha * (Igm - Irm) * (1 - Irm) * G  # 补偿红色通道
        Irc = np.array(Irc.reshape(G.shape), np.uint8)
        if blur_need:
            Ibc = B + alpha * (Igm - Ibm) * (1 - Ibm) * G  # 补偿蓝色通道
            Ibc = np.array(Ibc.reshape(G.shape), np.uint8)
            img = cv2.merge([Ibc, G, Irc])
        else:
            img = cv2.merge([B, G, Irc])

        img = np.clip(img, 0, 255)
        img = img.astype(np.uint8)
        return img

    # 白平衡
    def gray_world(self, img):
        # 将图像转换为浮点格式
        img_float = img.astype(float)
        # 计算图像的各通道平均值
        avg_b = np.mean(img_float[:, :, 0])
        avg_g = np.mean(img_float[:, :, 1])
        avg_r = np.mean(img_float[:, :, 2])

        # 计算各通道的增益
        gain_b = avg_g / avg_b
        gain_r = avg_g / avg_r

        # 应用增益来进行白平衡
        balanced = cv2.merge([img_float[:, :, 0] * gain_b, img_float[:, :, 1], img_float[:, :, 2] * gain_r])

        # 将结果限制在0到255的范围内
        balanced = np.clip(balanced, 0, 255).astype(np.uint8)
        return balanced

    # 伽马校正
    def gamma_correct(self, img, gamma=1.2):
        # 进行伽马校正
        gamma_corrected = np.power(img / 255.0, gamma) * 255.0
        gamma_corrected = np.clip(gamma_corrected, 0, 255).astype(np.uint8)
        return gamma_corrected

    # 论文提出的锐化方法
    def sharp(self, image):
        # Convert the image to floating point format
        image = image.astype(np.float32)

        # Apply Gaussian smoothing
        im_blur = cv2.GaussianBlur(image, (7, 7), 0)

        # Calculate unsharp mask
        unsharp_mask = image - im_blur

        # Perform histogram stretching
        stretch_im = cv2.normalize(unsharp_mask, None, 0, 255, cv2.NORM_MINMAX)  # type: ignore

        # Combine original and stretched images
        result = (image + stretch_im) / 2

        # Convert result back to uint8 format
        result = np.clip(result, 0, 255).astype(np.uint8)

        return result

    # Laplacian contrast weight (WL)
    def laplacian_weight(self, img):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        w = cv2.Laplacian(img, cv2.CV_64F)
        w = cv2.convertScaleAbs(w)
        return w

    # Saliency weight (WS)
    def saliency_weight(self, img):
        gfrgb = cv2.GaussianBlur(img, (3, 3), 0)
        lab = cv2.cvtColor(gfrgb, cv2.COLOR_RGB2LAB)
        l = np.double(lab[:, :, 0])
        a = np.double(lab[:, :, 1])
        b = np.double(lab[:, :, 2])
        lm = np.mean(np.mean(l))
        am = np.mean(np.mean(a))
        bm = np.mean(np.mean(b))
        w = np.square(l - lm) + np.square(a - am) + np.square((b - bm))
        return w

    # Saturation weight (WSat)
    def saturation_weight(self, img):
        b, g, r = cv2.split(img)
        lum = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        w = np.sqrt((1 / 3) * ((r - lum) ** 2 + (g - lum) ** 2 + (b - lum) ** 2))
        return w

    # 论文中计算权重的方法
    def weight_maps(self, img):
        img.astype(np.float64)
        WL = self.laplacian_weight(img)
        WS = self.saliency_weight(img)
        WSat = self.saturation_weight(img)
        return WL, WS, WSat

    def gaussian_pyramid(self, image, levels):
        pyramid = [image]
        for i in range(levels - 1):
            image = cv2.pyrDown(image)
            pyramid.append(image)

        return pyramid

    def laplacian_pyramid(self, image, levels):
        gaussian_pyramid_list = self.gaussian_pyramid(image, levels)
        laplacian_pyramid = [gaussian_pyramid_list[-1]]

        for i in range(levels - 1, 0, -1):
            size = (gaussian_pyramid_list[i - 1].shape[1], gaussian_pyramid_list[i - 1].shape[0])
            expanded = cv2.pyrUp(gaussian_pyramid_list[i], dstsize=size)
            laplacian = cv2.subtract(gaussian_pyramid_list[i - 1], expanded)
            laplacian_pyramid.insert(0, laplacian)
        return laplacian_pyramid

    # 金字塔重建
    def pyramid_reconstruct(self, pyramid):
        level = len(pyramid)

        for i in range(level - 1, 0, -1):
            m, n, c = pyramid[i - 1].shape
            pyramid[i - 1] += pyramid[i - 1] + cv2.resize(pyramid[i], (n, m))
        output = pyramid[0]
        output = np.clip(output, 0, 255).astype(np.uint8)
        return output

    # 增强
    def Enhance(self, img, mode='naive', gamma=1.2, blur_need=False, level=3):

        img_balenced_color = self.simple_color_balance(img, alpha=1, blur_need=blur_need)

        img_balenced_white = self.gray_world(img_balenced_color)

        img_gamma_corrected = self.gamma_correct(img_balenced_white, gamma=gamma)
        img_sharpen = self.sharp(img_balenced_white)

        (WL1, WS1, WSat1) = self.weight_maps(img_gamma_corrected)
        (WL2, WS2, WSat2) = self.weight_maps(img_sharpen)
        W1 = (WL1 + WS1 + WSat1 + 0.1) / (WL1 + WS1 + WSat1 + WL2 + WS2 + WSat2 + 0.2)
        W2 = (WL2 + WS2 + WSat2 + 0.1) / (WL1 + WS1 + WSat1 + WL2 + WS2 + WSat2 + 0.2)

        W1 = np.repeat(W1[:, :, np.newaxis], 3, axis=-1)
        W2 = np.repeat(W2[:, :, np.newaxis], 3, axis=-1)

        if mode == 'naive':
            result = np.multiply(W1, img_gamma_corrected) + np.multiply(W2, img_sharpen)
            return result.astype(np.uint8)
        elif mode == 'multi':
            # 构造权值的高斯金字塔

            W1 = self.gaussian_pyramid(W1, level)
            W2 = self.gaussian_pyramid(W2, level)

            # 构造输入图像的拉普拉斯金字塔
            input1_laplacian_pyramid = self.laplacian_pyramid(img_gamma_corrected, level)
            input2_laplacian_pyramid = self.laplacian_pyramid(img_sharpen, level)

            # 权值融合
            result = []
            for i in range(level):
                fuse = np.multiply(W1[i], input1_laplacian_pyramid[i]) + np.multiply(W2[i], input2_laplacian_pyramid[i])
                result.append(fuse)

            # 图像重建
            reconstructed_image = self.pyramid_reconstruct(result)
            return reconstructed_image.astype(np.uint8)

        else:
            print('Please choose mode ("naive" or "multi")')

    def process_images_in_folder(self, input_folder, output_folder, mode='naive', gamma=1.2, blur_need=False, level=3):
        os.makedirs(output_folder, exist_ok=True)

        image_files = [f for f in os.listdir(input_folder) if f.endswith(('.png', '.jpg', '.jpeg'))]

        for image_file in image_files:
            image_path = os.path.join(input_folder, image_file)
            img = self.read_img_by_cv2(image_path)

            if mode == 'multi':
                img_enhanced = self.Enhance(img, mode='multi', gamma=gamma, blur_need=blur_need, level=level)
            else:
                print('Invalid mode.')
                return

            output_path = os.path.join(output_folder, f"enhanced_{image_file}")
            cv2.imwrite(output_path, img_enhanced)

        print("Image enhancement completed.")



# Example
# if __name__ == '__main__':
#     enhancer = Color_Balance_and_Fusion_for_Underwater_Image_Enhancement()
#     img_path = r'E:\Pycharm\PycharmProjects\ImageEn\000055.jpg'
#     img = cv2.imread(img_path)
#     img = cv2.resize(img, (600, 450))
#     img_enhanced_naive = enhancer.Enhance(img, mode='naive', gamma=2, blur_need=False)
#     img_enhanced_multi = enhancer.Enhance(img, mode='multi', gamma=2, blur_need=False, level=3)
#     cv2.imshow('result', np.hstack((img, img_enhanced_naive, img_enhanced_multi)))  # type: ignore
#     cv2.waitKey()
#     cv2.destroyAllWindows()





if __name__ == '__main__':
    enhancer = Color_Balance_and_Fusion_for_Underwater_Image_Enhancement()
    input_folder = r'E:\Pycharm\PycharmProjects\ImageEn\data'
    output_folder = r'E:\Pycharm\PycharmProjects\ImageEn\output'

    enhancer.process_images_in_folder(input_folder, output_folder, mode='multi', gamma=2, blur_need=False, level=3)

借鉴【论文复现】(python):《Color Balance and Fusion for Underwater Image Enhancement》-CSDN博客

  • 10
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值