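# Paper: "On Finding Gray Pixels".
# The method mainly relies on illuminant invariance; the authors build on it
# with further analysis and transformations.
# 论文：On Finding Gray Pixels
# 主要利用光照不变性，作者在此基础上进行了一些分析和变换。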
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
from numpy import Inf
from scipy import ndimage
from mpmath import eps
import scipy.io as scio
def deriv_gauss(img, sigma):
Gaussoff = 0.000001
pw = np.array(range(50)).astype(np.int32)
ssq = sigma ** 2
width = np.sum(np.exp(-(pw * pw) / (2 * ssq)) > Gaussoff) - 1
xx = np.arange(2 * width + 1) - width
yy = np.arange(2 * width + 1) - width
x, y = np.meshgrid(xx, yy)
dg2d = -x * np.exp(-(x * x + y * y) / (2 * ssq)) / (np.pi * ssq)
dg2d = np.array(dg2d)
ax = ndimage.convolve(img, dg2d, mode='nearest')
ay = ndimage.convolve(img, dg2d.T, mode='nearest')
mag = np.sqrt(ax * ax + ay * ay)
return mag
def normr(data):
    """Return ``data`` with every row rescaled to unit Euclidean (L2) length.

    Rows that are entirely zero become NaN (0 / 0), as plain division yields.
    """
    squared = np.square(data)
    lengths = np.sqrt(squared.sum(axis=1))
    return data / lengths.reshape(-1, 1)
def normc(data):
    """Return ``data`` with every column rescaled to unit Euclidean (L2) length.

    Columns that are entirely zero become NaN (0 / 0).
    """
    col_norms = np.sqrt(np.sum(data * data, 0))
    # col_norms has shape (ncols,); broadcast it as a (1, ncols) row so each
    # column is divided by its own norm (a (ncols, 1) column would divide
    # rows instead, and fail for non-square input).
    return data / col_norms.reshape(1, -1)
def cal_angle(light, ref):
    """Return the angle, in degrees, between illuminant vectors.

    Both inputs are reshaped to ``(-1, 3)`` and compared row by row. A single
    RGB triple or an ``(N, 3)`` stack may therefore be passed. The result is
    invariant to the scale of either vector.

    Returns:
        1-D float array with one angle (degrees) per row.
    """
    light = np.reshape(light, (-1, 3))
    ref = np.reshape(ref, (-1, 3))
    dot = np.sum(light * ref, axis=1)
    norms = np.sqrt(np.sum(np.power(light, 2), 1)) * np.sqrt(np.sum(np.power(ref, 2), 1))
    # Rounding can push the cosine slightly outside [-1, 1] for (anti)parallel
    # vectors, which would make arccos return NaN instead of 0 / 180 degrees.
    cos_angle = np.clip(dot / norms, -1.0, 1.0)
    return np.arccos(cos_angle) * 180 / np.pi
def find_gray(im, mask=None):
    """Estimate the scene illuminant of an RGB image from its grayest pixels.

    Steps:
      1. Exclude saturated, very dark, zero-valued, low-contrast and
         numerically invalid pixels.
      2. Box-filter each channel.
      3. Score every pixel by the gradient magnitude of the log-chromaticities
         ``log(R / (R+G+B))`` and ``log(B / (R+G+B))``. Lower means grayer.
      4. Keep about 1% of the pixels with the lowest smoothed score.
      5. Return the normalized sum of their colors.

    Args:
        im: ``(h, w, 3)`` image with channels in R, G, B order. The fixed
            saturation (0.95) and darkness (0.0315) thresholds assume
            intensities in ``[0, 1]``. Converted to float64.
        mask: Optional ``(h, w)`` array; non-zero entries mark extra pixels to
            exclude. It is copied, never modified. ``None`` (or an empty
            sequence such as ``[]``) means no extra exclusions. Saturated and
            dark pixels are excluded in every case.

    Returns:
        ``(1, 3)`` array: the unit-norm illuminant estimate (R, G, B).
    """
    im = np.asarray(im, dtype=np.float64)
    h, w = im.shape[:2]
    if mask is None or np.size(mask) == 0:
        mask = np.zeros((h, w), dtype=bool)
    else:
        # copy: never mutate the caller's mask
        mask = np.array(mask, dtype=bool)

    # mask saturated and very dark pixels
    high_thr = 0.95
    low_thr = 0.0315
    mask |= np.max(im, axis=2) >= high_thr
    mask |= np.sum(im, axis=2) <= low_thr
    # number of gray pixels to select: 1% of the image, but at least one
    nums_gray = max(h * w // 100, 1)
    print('nums_gray : ', nums_gray, np.sum(mask))

    # de-noise every channel with a siz x siz box filter
    siz = 7
    box = np.ones((siz, siz)) / (siz * siz)
    r, g, b = (ndimage.convolve(im[..., c], box, mode='nearest') for c in range(3))
    # a zero channel makes the log-chromaticity undefined: exclude the pixel,
    # then replace zeros with machine epsilon so np.log stays finite
    mask |= (r == 0) | (g == 0) | (b == 0)
    tiny = np.finfo(np.float64).eps
    for channel in (r, g, b):
        channel[channel == 0] = tiny
    normrgb = r + g + b

    # mask low-contrast pixels: flat in all three channels.
    # np.logical_and / np.logical_or accept only two operands (a third
    # positional argument is the `out` buffer), so combine with `&` / `|`.
    sigma = 0.5
    delta_thr = 1e-4
    dr = deriv_gauss(r, sigma)
    dg = deriv_gauss(g, sigma)
    db = deriv_gauss(b, sigma)
    mask |= (dr < delta_thr) & (dg < delta_thr) & (db < delta_thr)
    print(' mask nums:', np.sum(mask))

    # grayness score: gradient magnitude of the log-chromaticities
    log_r = np.log(r) - np.log(normrgb)
    log_b = np.log(b) - np.log(normrgb)
    delta_log_r = deriv_gauss(log_r, sigma)
    delta_log_b = deriv_gauss(log_b, sigma)
    mask |= ~np.isfinite(delta_log_r) | ~np.isfinite(delta_log_b)
    grayness = np.hypot(delta_log_r, delta_log_b)
    # excluded pixels get the worst finite score so they are not selected,
    # then the score map is smoothed with the same box filter
    worst = np.max(grayness, initial=0.0, where=np.isfinite(grayness))
    grayness[mask] = worst
    grayness = ndimage.convolve(grayness, box, mode='nearest')

    # select the nums_gray grayest pixels (ties at the threshold are kept)
    grayness_thr = np.partition(grayness.ravel(), nums_gray - 1)[nums_gray - 1]
    mask_sel = grayness <= grayness_thr
    ill_est = im[mask_sel].sum(axis=0).reshape(-1, 3)
    return ill_est / np.sqrt(np.sum(ill_est * ill_est, 1)).reshape(-1, 1)
if __name__ == "__main__":
    # Evaluate find_gray on a directory of images.
    # Usage: python <this script> [DATA_DIR]
    # Every NNN.png in DATA_DIR needs a matching NNN.txt holding the
    # ground-truth illuminant (read with np.loadtxt).
    import sys

    default_dir = r'G:\ffcc-master_20201108\ffcc-master\data\shi_gehler\preprocessed\GehlerShi'
    data_dir = sys.argv[1] if len(sys.argv) > 1 else default_dir

    # single picture
    filename = os.path.join(data_dir, '000002.png')
    imbgr = cv2.imread(filename)
    if imbgr is None:
        # cv2.imread returns None instead of raising on a missing/bad file
        sys.exit('cannot read image: ' + filename)
    im = imbgr[:, :, ::-1] / 255  # BGR -> RGB, scaled by 1/255
    ill_gt = np.loadtxt(os.path.splitext(filename)[0] + '.txt')
    # To use a .mat example instead (exampleimg.mat from the
    # Mean-shifted-Gray-Pixel repository):
    # filename = 'G:\github\code\Mean-shifted-Gray-Pixel-master\Mean-shifted-Gray-Pixel-master\exampleimg.mat'
    # data = scio.loadmat(filename)
    # im = data['input_im']
    # mask = data['mask']
    # gt = data['gt']
    ill_est = find_gray(im)
    angle = cal_angle(ill_est, ill_gt)
    print(ill_est, ill_gt, angle)

    # whole directory (sorted so the saved angles have a stable order)
    angles = []
    for file in sorted(os.listdir(data_dir)):
        if not file.endswith('.png'):
            continue
        filename = os.path.join(data_dir, file)
        imbgr = cv2.imread(filename)
        if imbgr is None:
            print('skipping unreadable image:', filename)
            continue
        im = imbgr[:, :, ::-1] / 255
        gt = np.loadtxt(os.path.splitext(filename)[0] + '.txt')
        ill = find_gray(im)
        angle = cal_angle(ill, gt)
        print(ill, gt, angle)
        angles.append(angle)
    angles = np.array(angles)
    np.savetxt('wb_find_gray.txt', angles, delimiter=' ', fmt='%.7f')
    print(angles, np.mean(angles))