Frankle_Mccan_Retinex与Mccann99_Retinex python版本

最近看了Retinex相关的文献,然后网上大多都是原作者的Matlab版本,然后我用python将其复现了一遍,如有错误,还请指教。
这里建议大家直接看论文,网上的博客大多都是对原文的翻译,还有很多彼此之间的无脑搬运,很影响阅读体验。相应的论文以及文件已经上传至我的资源里,感兴趣的可以免费下载。

import cv2
import numpy as np
from numpy import array,log2,exp2,log,exp,ceil,mean 

# Retinex_Frankle_Mccan初代版本

class Retinex_Frankle_Mccan:
    #   img:彩色图像    iterations:迭代次数
    def __init__(self,img):
        self.height,self.width,self.channels=img.shape
        # 初始化
        # RR:原始图像(log化后) OP(上一次迭代的结果,初始值未RR中的最大值)
        # NP:(当前图像的结果,即physical image(log化后)) IP(中间运行结果,初始值与OP相同)
        self.RR=log(img+np.finfo(np.float).eps)
        self.maxnuim=array([np.amax(self.RR[:,:,t]) for t in range(self.channels)])
        self.OP=np.zeros_like(self.RR)+self.maxnuim
        self.IP=self.OP.copy()
        self.NP=np.zeros_like(self.OP)
    
    def retinex_frankle_mccan(self,iterations=4):
        # 对B,G,R三个通道进行操作
        shift=int(np.exp2(round(np.log2(min(self.height,self.width)))-1))
        while(abs(shift)>=1):
            for i in range(iterations):
                # 水平操作
                self.compare_with(0,shift)
                # 垂直操作
                self.compare_with(shift,0)               
            shift=int(-shift/2)
        # 反向操作将图片恢复成8位图像
        return self.recover_image()
                    
    def compare_with(self,s_row,s_col):
        # ratio and product 操作
        if s_row+s_col>0:
            self.IP[s_row:-1,s_col:-1]=self.OP[:-1-s_row,:-1-s_col]+\
                self.RR[s_row:-1,s_col:-1]-self.RR[:-1-s_row,:-1-s_col]
        else:
            self.IP[:-1+s_row,:-1+s_col]=self.OP[-s_row:-1,-s_col:-1]+\
                self.RR[:-1+s_row,:-1+s_col]-self.RR[-s_row:-1,-s_col:-1]
        # reset操作
        mask=self.IP>self.maxnuim
        for t in range(self.channels):
            self.IP[:,:,t][mask[:,:,t]]=self.maxnuim[t]
        # average操作
        self.NP=(self.IP+self.OP)/2
        self.OP=self.NP 

    def recover_image(self):
        # 版本一
        res=np.exp(self.NP)
        miniuim=np.amin(res)
        maxnuim=np.amax(res)
        res=(res-miniuim)/(maxnuim-miniuim)*255
        return np.array(res,dtype=np.uint8)    

class Retinex_Mccann99():
    def __init__(self,src_img):
        # 如果对尺寸有要求
        # 对图像进行修改
        self.src_height,self.src_width,self.channels=src_img.shape
        height,width=self.ComputeSize(self.src_height,self.src_width)
        self.nlayers=self.ComputeLayers(height,width)
        self.nrows=int(height/(np.exp2(self.nlayers)))
        self.ncols=int(width/(np.exp2(self.nlayers)))
        if (self.ncols*self.nrows>25):
            raise Exception("输入图像尺寸不符合要求")
        img=cv2.resize(src_img,(width,height))
        cv2.imwrite('new.jpg',img)
        self.img_log=log(img+np.finfo(np.float).eps)
        self.maxnuim=array([np.amax(self.img_log[:,:,t]) for t in range(self.channels)])
        self.OP=np.zeros((self.nrows,self.ncols,self.channels))+self.maxnuim

        # 如果对尺寸没有太多要求,即不要求在5*5之内,对height与width要求不大
        # self.src_height,self.src_width,self.channels=src_img.shape
        # height=self.ComputeNum(self.src_height)
        # width=self.ComputeNum(self.src_width)
        # self.nlayers=self.ComputeLayers(height,width)
        # self.nrows=int(height/(np.exp2(self.nlayers)))
        # self.ncols=int(width/(np.exp2(self.nlayers)))
        # img=cv2.resize(src_img,(width,height))
        # self.img_log=log(img+np.finfo(np.float).eps)
        # self.maxnuim=array([np.amax(self.img_log[:,:,t]) for t in range(self.channels)])
        # self.OP=np.zeros((self.nrows,self.ncols,self.channels))+self.maxnuim



    def retinex_mccann99(self,niterations=4):
        for t in range(self.nlayers+1):
            self.RR=self.ImageDownResolution(int(exp2(self.nlayers-t)))
            # 对图像进行金字塔操作
            self.OPE=np.array([np.pad(self.OP[:,:,t],((1),(1))) for t in range(self.channels)])
            self.RRE=np.array([np.pad(self.RR[:,:,t],((1),(1))) for t in range(self.channels)])
            self.OPE=np.transpose(self.OPE,(1,2,0))
            self.RRE=np.transpose(self.RRE,(1,2,0))
            for iter in range(niterations):
                self.CompareWithNeighbor(-1,0)
                self.CompareWithNeighbor(-1,1)
                self.CompareWithNeighbor(0,1)
                self.CompareWithNeighbor(1,1)
                self.CompareWithNeighbor(1,0)
                self.CompareWithNeighbor(1,-1)
                self.CompareWithNeighbor(0,-1)
                self.CompareWithNeighbor(-1,-1)

            self.NP=self.OPE[1:-1,1:-1]
            # 行扩充
            self.OP=self.NP[[int(p/2) for p in range(self.nrows*2)],:]
            # 列扩充
            self.OP=self.OP[:,[int(p/2) for p in range(self.ncols*2)]]
            self.nrows*=2
            self.ncols*=2
        return cv2.resize(self.recover_image(),(self.src_width,self.src_height))
        
    def CompareWithNeighbor(self,dif_row,dif_col):
        # -1 相当于 self.nrows/self.ncols + 1
        IP=self.OPE[1+dif_row:self.nrows+1+dif_row,1+dif_col:self.ncols+1+dif_col]+self.RRE[1:-1,1:-1]-self.RRE[1+dif_row:self.nrows+1+dif_row,1+dif_col:self.ncols+1+dif_col]
        mask=IP>self.maxnuim
        for t in range(self.channels):
            IP[:,:,t][mask[:,:,t]]=self.maxnuim[t]
        if (dif_col == -1):
            IP[:,0]=self.OPE[1:-1,1]
        if (dif_col == 1): 
            IP[:,-1]=self.OPE[1:-1,-2] 
        if (dif_row == -1):
            IP[0,:]=self.OPE[1,1:-1]
        if (dif_row == 1):
            IP[-1,:]=self.OPE[-2,1:-1]
        self.OPE[1:-1,1:-1]=(self.OPE[1:-1,1:-1]+IP)/2

    def recover_image(self):
        # 版本一
        res=np.exp(self.NP)
        miniuim=np.amin(res)
        maxnuim=np.amax(res)
        res=(res-miniuim)/(maxnuim-miniuim)*255
        return np.array(res,dtype=np.uint8) 

    def ImageDownResolution(self,blocksize):        
        h,w=self.img_log.shape[:2]
        self.RR=np.zeros((int(h/blocksize),int(w/blocksize),self.channels))
        for i in range(0,int(h/blocksize)):
            for j in range(0,int(w/blocksize)):
                for t in range(self.channels):
                    self.RR[i,j,t]=np.mean(self.img_log[i*blocksize:(i+1)*blocksize,j*blocksize:(j+1)*blocksize,t])
        return self.RR
                
    def ComputeLayers(self,nrows,ncols):
        power=np.exp2(round(np.log2(np.gcd(nrows,ncols))))
        while(power>1 and (nrows%power!=0 or ncols%power!=0)):
            power/=2
        return int(np.log2(power))

    def ComputeSize(self,height,width,module_length=5):
        maxRatio=-1
        srcRatio=width/height
        for i in range(1,module_length+1):
            for j in range(1,module_length+1):
                ratio=j/i
                tmp=min(ratio,srcRatio)/max(ratio,srcRatio)
                if tmp>maxRatio:
                    maxRatio=tmp
                    record=[i,j]
        power=max(round(log2(height/record[0])),round(log2(width/record[1])))
        rh,rw=exp2(power)*record[0],exp2(power)*record[1]
        return int(rh),int(rw)

    def ComputeNum(self,num,module_length=5):
        minium=np.Inf
        for i in range(1,module_length+1):
            power=round(log2(num/i))
            tmp=exp2(power)*i
            if abs(tmp-num)<minium:
                minium=abs(tmp-num)
                res=tmp
        return int(res)


    


if __name__=='__main__':
    img=cv2.imread('1.jpg',cv2.IMREAD_COLOR)
    # 测试Retinex_Frankle_Mccan
    # retinex_fm=Retinex_Frankle_Mccan(img)
    # dst=retinex_fm.retinex_frankle_mccan(iterations=4)
    # 测试Retinex_Mccann99
    retinex_fm=Retinex_Mccann99(img)
    dst=retinex_fm.retinex_mccann99()
    cv2.imshow('src',img)
    cv2.imshow('dst',dst)
    cv2.waitKey(0)

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值