采取的模板是5*5规格大小
# 同一张图片里,寻找相似的5*5的batch块补全图像缺失的部分,按像素补全
# 模板匹配采取误差平方和最小ssd算法
import cv2
import numpy as np
def ssd_distance_two_batch(batch1,batch2,pic,flag_to_deal):
# batch1,batch2是一个列表,里面有25个点
# pic是图片的array
# flag是参考的是否需要更新的array
sum = 0.0
count = 0
# batch是如下格式【(x1,y1),(x2,y2)----25个点】
# batch1代表待填补的区域块
# batch2代表待匹配的图像上的其他区域快
for x in range(len(batch1)):
# print('right')
tempx1,tempy1 = batch1[x]
tempx2, tempy2 = batch2[x]
if flag_to_deal[tempx1,tempy1] == 0 :# 表示该区域本来就存在原始的数值,或是更新后的新数值,可以拿来计算最近邻
sum = sum + (pic[tempx1,tempy1] - pic[tempx2,tempy2])**2 #ssd距离衡量
count = count + 1
batch2_middle_x,batch2_middle_y = batch2[len(batch2)//2]
return sum *1.0/count ,pic[batch2_middle_x,batch2_middle_y]
# 返回值是两个batch的距离,以及第二个batch的中心点的数值(可以拿来填充第一个batch)
def get_flag_point(point, flag):
x,y = point # x,和y超过范围了??
return flag[x,y]
def get_importance(little_batch,picture_array, flag_point):
# 在一个小batch里面,求已知点的区域中,计算方差
# little_batch 是5*5的小区域,[(x1,y1),(x2,y2)...]
# flag_point代表先前的flag矩阵,0代表可信值
count_point_has_value = 0
sum_point_has_value = 0.0
for pointx, pointy in little_batch:
flag_number = get_flag_point(point=(pointx,pointy) , flag=flag_point)
print('到这都是i对的')
if flag_number==0 :
# 该点已经有值了,把他加入计算到求importance里面
sum_point_has_value = sum_point_has_value + picture_array[pointx,pointy]
count_point_has_value = count_point_has_value + 1
# 容易得到count_point_value=0 无法当作除数
importance_get = 0.0
if count_point_has_value==0:
# 若是该batch处于内部,不补充该位置,处于边缘的进行补充
importance_get=0.0
if count_point_has_value>0:
# 该batch处于边缘上:计算方差等等操作
avg_value = sum_point_has_value /count_point_has_value
for pointx, pointy in little_batch:
if get_flag_point((pointx,pointy) , flag=flag_point)==0 :
# 该点已经有值了,把他加入计算到求方差里面
importance_get = importance_get + (avg_value - picture_array[pointx,pointy])**2
return importance_get
def find_batch(pic_array,flag):
# 找到哪个batch是亟待修补的
height, width = pic_array.shape
# 计算要修补的batch重要程度
importance = 0
# 中心点坐标是
recordX=0;recordY=0;
for iii in range(0 + 2, height - 2):
for jjjj in range(0 + 2, width - 2):
getFlag = get_flag_point(point=(iii,jjjj) ,flag= flag )
if getFlag == 1:
# 该点处于待修补区域,求该点的待修补紧急程度,原则是待修补的batch里面,已知值的点求平均,
# 然后对每个点做差求绝对值,再求和,其中最大的就是高频区域(类似是方差最大的原则当作高频)
batch_of_point = get_one_batch( pointXY = (iii,jjjj) ) # 得到 i,j 为中心的 5*5 小 batch
if get_importance(batch_of_point, picture_array= pic_array , flag_point = flag)>=importance:
importance = get_importance(batch_of_point, picture_array= pic_array , flag_point = flag)
batch_record = batch_of_point
recordX=iii;recordY=jjjj;
return batch_record ,recordX,recordY# 找到了亟需待补全的batch, 和要记录的坐标x,y
def find_point_to_inpaint(pic , flag):
# 在待修补图像上寻找最需要修补的像素点位置
# 寻找原则是高频信息
point_sets= []
for i in range(height):
for j in range(width):
if flag[i, j] == 1: point_sets.append((i,j))
# point_sets包含所有需要修补的点
for x,y in point_sets:
# 选择一个点,找到周围5*5-1=24个点
point_side = []
# 标注24个点是否是已经填充区域
for i in range(24):
point_side.append((x - 2, y - 2));point_side.append((x - 2, y - 1 ));point_side.append((x - 2, y ));point_side.append((x - 2 , y + 1));point_side.append((x - 2 , y + 2));
point_side.append((x - 1 , y - 2));point_side.append((x - 1 , y - 1 ));point_side.append((x - 1 , y )); point_side.append((x - 1 , y + 1)); point_side.append((x - 1 , y + 2));
point_side.append((x , y - 2));point_side.append((x , y - 1 ));point_side.append((x , y )); point_side.append((x , y + 1));
point_side.append((x + 1 , y - 2));point_side.append((x + 1 , y - 1 ));point_side.append((x + 1 , y )); point_side.append((x + 1 , y + 1)); point_side.append((x + 1 , y + 2));
point_side.append((x + 2 , y - 2));point_side.append((x + 2 , y - 1 ));point_side.append((x + 2 , y )); point_side.append((x + 2 , y + 1)); point_side.append((x + 2 , y + 2));
# 该点是否为已经填充的区域
flag_of_side = []
for dotx,doty in point_side:
if get_flag_point((dotx,doty),flag) == 0: flag_of_side.append( 1 );
else: flag_of_side.append(0);
print(flag_of_side) # flagside表示周围24个点,其中哪个点是本来就存在值的,标为1
def get_one_batch(pointXY):
i,j = pointXY
temp_batch = []
for temp_point in range(1):
tempPoint = (i - 2, j - 2); temp_batch.append(tempPoint);
tempPoint = (i - 2, j - 1); temp_batch.append(tempPoint);
tempPoint = (i - 2, j); temp_batch.append(tempPoint);
tempPoint = (i - 2, j+1); temp_batch.append(tempPoint);
tempPoint = (i - 2, j+2); temp_batch.append(tempPoint);
tempPoint = (i - 1, j - 2);
temp_batch.append(tempPoint);
tempPoint = (i - 1, j - 1);
temp_batch.append(tempPoint);
tempPoint = (i - 1, j);
temp_batch.append(tempPoint);
tempPoint = (i - 1, j + 1);
temp_batch.append(tempPoint);
tempPoint = (i - 1, j + 2);
temp_batch.append(tempPoint);
tempPoint = (i , j - 2);
temp_batch.append(tempPoint);
tempPoint = (i , j - 1);
temp_batch.append(tempPoint);
tempPoint = (i , j);
temp_batch.append(tempPoint);
tempPoint = (i , j + 1);
temp_batch.append(tempPoint);
tempPoint = (i , j + 2);
temp_batch.append(tempPoint);
tempPoint = (i + 1, j - 2);
temp_batch.append(tempPoint);
tempPoint = (i + 1, j - 1);
temp_batch.append(tempPoint);
tempPoint = (i + 1, j);
temp_batch.append(tempPoint);
tempPoint = (i + 1, j + 1);
temp_batch.append(tempPoint);
tempPoint = (i + 1, j + 2);
temp_batch.append(tempPoint);
tempPoint = (i + 2, j - 2);
temp_batch.append(tempPoint);
tempPoint = (i + 2, j - 1);
temp_batch.append(tempPoint);
tempPoint = (i + 2, j);
temp_batch.append(tempPoint);
tempPoint = (i + 2, j + 1);
temp_batch.append(tempPoint);
tempPoint = (i + 2, j + 2);
temp_batch.append(tempPoint);
# temp_batch 是[(x1,y1) ,(x2,y2)----25个点]
return temp_batch
def get_all_batch(pic,flag):
# 一个batch的尺寸是5*5
all_batch = [] # 符合条件的所有batch,条件就是该batch内不存在未知像素点
height,width =pic.shape
for i in range(0+2,height-2):
for j in range(0+2,width-2):
temp_batch = []
flag_is_batch = 1
# 对区域进行判断,有未知点的区域都不能算作可以拿来匹配的区域
for panduan in range(1):
################################
tempPoint = (i-2,j-2);temp_batch.append(tempPoint);
if get_flag_point(tempPoint ,flag=flag)==1 : flag_is_batch=0
tempPoint = (i-2,j-1);temp_batch.append(tempPoint);
if get_flag_point(tempPoint ,flag=flag)==1 : flag_is_batch=0
tempPoint = (i-2,j );temp_batch.append(tempPoint);
if get_flag_point(tempPoint ,flag=flag)==1 : flag_is_batch=0
tempPoint = (i-2,j +1);temp_batch.append(tempPoint);
if get_flag_point(tempPoint ,flag=flag)==1 : flag_is_batch=0
tempPoint = (i-2,j+2 );temp_batch.append(tempPoint);
if get_flag_point(tempPoint ,flag=flag)==1 : flag_is_batch=0
################################
tempPoint = (i - 1, j - 2);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i - 1, j - 1);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i - 1, j);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i - 1, j + 1);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i - 1, j + 2);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
##################
tempPoint = (i, j - 2);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i, j - 1);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i, j );temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i, j + 1);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i, j + 2);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
#########
tempPoint = (i + 1, j - 2);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i + 1, j - 1);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i + 1, j);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i + 1, j + 1);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i + 1, j + 2);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
############
tempPoint = (i + 2, j - 2);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i + 2, j - 1);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i + 2, j);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i + 2, j + 1);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
tempPoint = (i + 2, j + 2);temp_batch.append(tempPoint);
if get_flag_point(tempPoint, flag=flag) == 1: flag_is_batch = 0
if flag_is_batch==1 :
# 该区域不存在未知点,加到allbatch里面
all_batch.append(temp_batch)
# 返回值是列表,每个列表元素是一个[(x1,y1),*****25个点组成的batch]
# 到这都是对的
print('still right to here ')
return all_batch
if __name__ == '__main__':
file = '../fin.png'
mask = '../mask.png'
pic_to_inpainted = cv2.imread(filename=file,flags=0)
# mask是加的遮盖,
mask = cv2.imread(filename=mask,flags=0)
height , width = pic_to_inpainted.shape
# 要修补的图像,增添一个flag数组作为示意图,设置数组中要修补的像素点的位置设置为1
flag = np.zeros(shape=pic_to_inpainted.shape)
for ii in range(height):
for jj in range(width):
if mask[ii,jj]==255 :flag[ii,jj] =1
print(flag)
# np.unique(flag,return_counts=True)
# 统计 1 的个数
# 当待填补的图片仍有像素未被填充,执行循环操作
inpainted_iter_number = 0
pic_after_inpainted_save_dir = 'inpainted//'
while( (np.sum(flag == 1)) >0 ):
inpainted_iter_number = inpainted_iter_number + 1
# 保存修复后的照片的位置
pic_after_inpainted_save_place = pic_after_inpainted_save_dir \
+ 'inpainted'+str(inpainted_iter_number)+'.png'
#test
print('test01 is good ')
print('还有多少像素点待填充: ',np.sum(flag==1)) # 1245个mask点
# 遍历得到现存的所有的batch,可以拿来修补像素的5*5小batch,每个batch里面的像素点都是已知的
allbatch = get_all_batch( pic= pic_to_inpainted , flag = flag)
# 那个点最需要补充,获取该像素所在的batch,以及flag
# 沿着flag寻找最迫切需要修补的像素位置,提取其所在的batch,和坐标x,y
the_batch_to_inpaint, pointx, pointy = find_batch(pic_to_inpainted, flag=flag)
# getpoint
# 遍历所有快,得到最小的距离ssd,用该快的中心值填充该像素点
# 初始化距离
min_ssd_distance = 100000000 # 设置初始值很大的一个值
min_ssd_distance = float("inf")
for batch_to_detect in allbatch:
# batch只是像素的位置数据,((x1,y1),(x2,y2)---25个点)
dis , value = ssd_distance_two_batch( batch1 = the_batch_to_inpaint ,
batch2 = batch_to_detect ,
pic = pic_to_inpainted ,
flag_to_deal = flag )
# 得到中心点的值,得到dis距离值
if min_ssd_distance>dis:
# 找到一个更相似的的batch,更新
the_value_to_use =value
min_ssd_distance =dis
print('min is ',min_ssd_distance)
# 用找到的数值填充该像素的值
pic_to_inpainted[pointx , pointy ] = the_value_to_use
# 更新flag数组,少了一个待填充的像素值 ,将该位置已赋值的flag设置为0
flag[pointx , pointy] = 0
print('该块的最小误差是',min_ssd_distance,'填充的值是',the_value_to_use)
cv2.imwrite(pic_after_inpainted_save_place,pic_to_inpainted)
cv2.imshow('No'+str(inpainted_iter_number),pic_to_inpainted)
cv2.waitKey(1)
cv2.destroyAllWindows()
print(pic_to_inpainted)