最近阅读了Retinex相关的文献,发现网上的实现大多是原作者的Matlab版本,于是我用Python将其复现了一遍。如有错误,还请指教。
这里建议大家直接看论文,网上的博客大多都是对原文的翻译,还有很多彼此之间的无脑搬运,很影响阅读体验。相应的论文以及文件已经上传至我的资源里,感兴趣的可以免费下载。
import cv2
import numpy as np
from numpy import array,log2,exp2,log,exp,ceil,mean
# Retinex_Frankle_Mccan: first-generation version
class Retinex_Frankle_Mccan:
    """Frankle-McCann Retinex computed in the log domain.

    Working buffers (all shaped like the input image):
        RR: log of the input image.
        OP: result of the previous comparison; starts at the per-channel
            maximum of RR.
        IP: intermediate result of the current comparison.
        NP: result of the current comparison (log domain).
    """

    def __init__(self, img):
        """Build the log-domain buffers.

        Args:
            img: colour image array of shape (height, width, channels).
        """
        self.height, self.width, self.channels = img.shape
        # np.float was removed in NumPy 1.24; use the explicit 64-bit type.
        # eps keeps log() finite for zero-valued pixels.
        eps = np.finfo(np.float64).eps
        self.RR = np.log(np.asarray(img, dtype=np.float64) + eps)
        # One maximum per channel, shape (channels,).
        self.maxnuim = np.amax(self.RR, axis=(0, 1))
        self.OP = np.zeros_like(self.RR) + self.maxnuim
        self.IP = self.OP.copy()
        self.NP = np.zeros_like(self.OP)

    def retinex_frankle_mccan(self, iterations=4):
        """Run the comparisons over a halving, sign-alternating shift.

        Args:
            iterations: number of horizontal + vertical comparison pairs
                performed at each shift distance.

        Returns:
            uint8 array with the same shape as the input image.
        """
        shortest = min(self.height, self.width)
        shift = int(np.exp2(int(round(float(np.log2(shortest)))) - 1))
        while abs(shift) >= 1:
            for _ in range(iterations):
                # horizontal comparison
                self.compare_with(0, shift)
                # vertical comparison
                self.compare_with(shift, 0)
            # Halve the distance and flip the direction.
            shift = int(-shift / 2)
        # Map the log-domain result back to an 8-bit image.
        return self.recover_image()

    def compare_with(self, s_row, s_col):
        """Compare every pixel with the one (s_row, s_col) away and update OP.

        Performs the ratio-product, reset and average steps for all channels.

        Args:
            s_row: row offset of the comparison pixel (may be negative).
            s_col: column offset of the comparison pixel (may be negative).
        """
        rows, cols = self.RR.shape[:2]
        # Explicit end indices: a trailing ``:-1`` would silently drop the
        # last row and column from every comparison.
        if s_row + s_col > 0:
            dst = (slice(s_row, rows), slice(s_col, cols))
            src = (slice(0, rows - s_row), slice(0, cols - s_col))
        else:
            dst = (slice(0, rows + s_row), slice(0, cols + s_col))
            src = (slice(-s_row, rows), slice(-s_col, cols))

        # Pixels without a partner at this offset keep their previous value,
        # so start from OP instead of whatever IP held last time.
        self.IP = self.OP.copy()
        # ratio and product
        self.IP[dst] = self.OP[src] + self.RR[dst] - self.RR[src]
        # reset: nothing may exceed the per-channel maximum
        np.minimum(self.IP, self.maxnuim, out=self.IP)
        # average
        self.NP = (self.IP + self.OP) / 2
        self.OP = self.NP

    def recover_image(self):
        """Exponentiate NP and stretch it linearly to the 0..255 range.

        Returns:
            uint8 array shaped like NP. When every value of exp(NP) is
            identical there is nothing to stretch, so the values are rounded
            and clipped to 0..255 instead of dividing by zero.
        """
        res = np.exp(self.NP)
        lowest = np.amin(res)
        highest = np.amax(res)
        if highest == lowest:
            return np.rint(np.clip(res, 0, 255)).astype(np.uint8)
        res = (res - lowest) / (highest - lowest) * 255
        return res.astype(np.uint8)
class Retinex_Mccann99():
    """McCann99 Retinex: neighbour comparisons on a coarse-to-fine pyramid.

    The input is resized so that both sides are a small base (1..5) times a
    power of two. Each pyramid level averages the log image over square
    blocks, compares every pixel with its eight neighbours, and passes the
    result, doubled in size, to the next finer level.
    """

    # Order in which the eight neighbours are visited in one iteration,
    # as (row offset, column offset).
    _NEIGHBOR_ORDER = (
        (-1, 0), (-1, 1), (0, 1), (1, 1),
        (1, 0), (1, -1), (0, -1), (-1, -1),
    )

    def __init__(self, src_img):
        """Resize the image to a pyramid-friendly size and take its log.

        Args:
            src_img: colour image array of shape (height, width, channels).

        Raises:
            Exception: if the coarsest pyramid level has more than 25 pixels.
        """
        self.src_height, self.src_width, self.channels = src_img.shape
        height, width = self.ComputeSize(self.src_height, self.src_width)
        self.nlayers = self.ComputeLayers(height, width)
        # Size of the coarsest level.
        self.nrows = height // 2 ** self.nlayers
        self.ncols = width // 2 ** self.nlayers
        if self.ncols * self.nrows > 25:
            raise Exception("输入图像尺寸不符合要求")
        img = cv2.resize(src_img, (width, height))
        # NOTE(review): writes the resized input to 'new.jpg' in the working
        # directory on every construction; looks like a debugging aid --
        # confirm before removing.
        cv2.imwrite('new.jpg', img)
        # np.float was removed in NumPy 1.24; use the explicit 64-bit type.
        # eps keeps log() finite for zero-valued pixels.
        eps = np.finfo(np.float64).eps
        self.img_log = np.log(np.asarray(img, dtype=np.float64) + eps)
        # One maximum per channel, shape (channels,).
        self.maxnuim = np.amax(self.img_log, axis=(0, 1))
        self.OP = np.zeros((self.nrows, self.ncols, self.channels)) + self.maxnuim
        # Alternative sizing when the coarsest level need not fit in 5x5:
        # choose height and width independently with ComputeNum(), derive
        # nlayers/nrows/ncols from them as above, and skip the size check.

    def retinex_mccann99(self, niterations=4):
        """Run the pyramid from the coarsest level to full resolution.

        Args:
            niterations: number of full eight-neighbour sweeps per level.

        Returns:
            uint8 image resized back to the original width and height.
        """
        # Work on a local estimate so self.OP / self.nrows / self.ncols keep
        # describing the coarsest level and the method can be called again.
        estimate = self.OP
        border = ((1, 1), (1, 1), (0, 0))
        for layer in range(self.nlayers + 1):
            self.RR = self.ImageDownResolution(2 ** (self.nlayers - layer))
            # Surround both buffers with a one-pixel zero border so every
            # pixel has eight addressable neighbours.
            self.OPE = np.pad(estimate, border)
            self.RRE = np.pad(self.RR, border)
            for _ in range(niterations):
                for dif_row, dif_col in self._NEIGHBOR_ORDER:
                    self.CompareWithNeighbor(dif_row, dif_col)
            self.NP = self.OPE[1:-1, 1:-1]
            # Duplicate every row and column to seed the next finer level.
            estimate = np.repeat(np.repeat(self.NP, 2, axis=0), 2, axis=1)
        return cv2.resize(self.recover_image(), (self.src_width, self.src_height))

    def CompareWithNeighbor(self, dif_row, dif_col):
        """Update the interior of OPE from the neighbour at the given offset.

        Args:
            dif_row: row offset of the neighbour (-1, 0 or 1).
            dif_col: column offset of the neighbour (-1, 0 or 1).
        """
        rows = self.OPE.shape[0] - 2
        cols = self.OPE.shape[1] - 2
        shifted = (
            slice(1 + dif_row, rows + 1 + dif_row),
            slice(1 + dif_col, cols + 1 + dif_col),
        )
        # ratio and product
        IP = self.OPE[shifted] + self.RRE[1:-1, 1:-1] - self.RRE[shifted]
        # reset: nothing may exceed the per-channel maximum
        IP = np.minimum(IP, self.maxnuim)
        # Edge pixels whose neighbour lies in the padding keep their current
        # value, so the average below leaves them unchanged.
        if dif_col == -1:
            IP[:, 0] = self.OPE[1:-1, 1]
        if dif_col == 1:
            IP[:, -1] = self.OPE[1:-1, -2]
        if dif_row == -1:
            IP[0, :] = self.OPE[1, 1:-1]
        if dif_row == 1:
            IP[-1, :] = self.OPE[-2, 1:-1]
        # average
        self.OPE[1:-1, 1:-1] = (self.OPE[1:-1, 1:-1] + IP) / 2

    def recover_image(self):
        """Exponentiate NP and stretch it linearly to the 0..255 range.

        Returns:
            uint8 array shaped like NP. When every value of exp(NP) is
            identical there is nothing to stretch, so the values are rounded
            and clipped to 0..255 instead of dividing by zero.
        """
        res = np.exp(self.NP)
        lowest = np.amin(res)
        highest = np.amax(res)
        if highest == lowest:
            return np.rint(np.clip(res, 0, 255)).astype(np.uint8)
        res = (res - lowest) / (highest - lowest) * 255
        return res.astype(np.uint8)

    def ImageDownResolution(self, blocksize):
        """Average img_log over non-overlapping blocksize x blocksize tiles.

        Args:
            blocksize: side length of each tile, in pixels.

        Returns:
            Array of shape (h // blocksize, w // blocksize, channels); it is
            also stored in self.RR.
        """
        h, w = self.img_log.shape[:2]
        out_h = h // blocksize
        out_w = w // blocksize
        # Drop any remainder rows/columns that do not fill a whole tile.
        cropped = self.img_log[:out_h * blocksize, :out_w * blocksize]
        tiles = cropped.reshape(out_h, blocksize, out_w, blocksize, self.channels)
        self.RR = tiles.mean(axis=(1, 3))
        return self.RR

    def ComputeLayers(self, nrows, ncols):
        """Return k such that 2**k is the largest power of two dividing both.

        Args:
            nrows: image height in pixels.
            ncols: image width in pixels.

        Returns:
            Non-negative int; 0 when the two sizes share no factor of two
            (or when both are 0).
        """
        common = int(np.gcd(nrows, ncols))
        if common == 0:
            return 0
        # common & -common isolates the lowest set bit, i.e. the largest
        # power of two that divides common.
        return (common & -common).bit_length() - 1

    def ComputeSize(self, height, width, module_length=5):
        """Pick a (height, width) of the form (i * 2**p, j * 2**p).

        i and j range over 1..module_length and are chosen so that j / i is
        as close as possible to width / height; p is the larger of the two
        rounded log2 scale factors.

        Args:
            height: original image height.
            width: original image width.
            module_length: largest base factor allowed for i and j.

        Returns:
            Tuple (new_height, new_width) of ints.
        """
        src_ratio = width / height
        best_score = -1
        for i in range(1, module_length + 1):
            for j in range(1, module_length + 1):
                ratio = j / i
                # 1.0 means a perfect aspect-ratio match.
                score = min(ratio, src_ratio) / max(ratio, src_ratio)
                if score > best_score:
                    best_score = score
                    record = [i, j]
        power = max(
            int(round(float(np.log2(height / record[0])))),
            int(round(float(np.log2(width / record[1])))),
        )
        scale = 2.0 ** power
        return int(scale * record[0]), int(scale * record[1])

    def ComputeNum(self, num, module_length=5):
        """Return the value i * 2**p (i in 1..module_length) closest to num.

        Args:
            num: target length in pixels.
            module_length: largest base factor allowed.

        Returns:
            The closest candidate as an int; ties keep the smallest i.
        """
        # np.Inf was removed in NumPy 2.0; np.inf is the supported spelling.
        smallest_gap = np.inf
        for i in range(1, module_length + 1):
            power = int(round(float(np.log2(num / i))))
            candidate = 2.0 ** power * i
            gap = abs(candidate - num)
            if gap < smallest_gap:
                smallest_gap = gap
                res = candidate
        return int(res)
if __name__ == '__main__':
    # Demo: enhance '1.jpg' and show the source and result side by side.
    img = cv2.imread('1.jpg', cv2.IMREAD_COLOR)
    # cv2.imread signals a missing or unreadable file by returning None
    # rather than raising, so fail here with a clear message.
    if img is None:
        raise FileNotFoundError("could not read image '1.jpg'")
    # To try the Frankle-McCann variant instead:
    #   dst = Retinex_Frankle_Mccan(img).retinex_frankle_mccan(iterations=4)
    retinex = Retinex_Mccann99(img)
    dst = retinex.retinex_mccann99()
    cv2.imshow('src', img)
    cv2.imshow('dst', dst)
    # Block until a key is pressed, then close the preview windows.
    cv2.waitKey(0)
    cv2.destroyAllWindows()