深度学习中的图像分布是指图像特征在不同类别间的分布统计,图像纹理特征是较难描述的的,借鉴机器学习中对图像的分类方法,图像特征有基于灰度或梯度统计的GLCM、LBP、灰度直方图、HOG、SIFT、感知hash值 等特征,也有基于颜色空间的RGB颜色分布(RGB值分区)、图像分块特征码等。
通过分析可描述的特征在类别间的分布,可以平衡非类别决定特征在各个类别中的平衡,使模型在反向传播中不会取巧,尽可能学习到模型的本质特征。示例:如使用AI区分手机和充电宝,手机的照片都是竖着的,充电宝的照片都是横着的,如果不改善形态轮廓在类别间的分布,AI就会以为横着的就是充电宝,竖着的就是手机,这很容易导致过拟合。
以下代码讲述了根据图像的感知hash值进行归类,查看其轮廓特征的分布状态。
核心思想:将hash值相近的数据当做一类数据,可以分析出各种轮廓数据在不同类别间的分布
1、构建计算hash值的类,并判断hash值是否已经出现过
import cv2,os
import numpy as np
class HashCacl:
def cal_hash_code(self, cur_gray):
s_img = cv2.resize(cur_gray, dsize=(8, 8))
img_mean = cv2.mean(s_img)
return s_img > img_mean[0]
def cal_phash_code(self, cur_gray):
# 缩小至32*32
m_img = cv2.resize(cur_gray, dsize=(32, 32))
# 浮点型用于计算
m_img = np.float32(m_img)
# 离散余弦变换,得到dct系数矩阵
img_dct = cv2.dct(m_img)
img_mean = cv2.mean(img_dct[0:8, 0:8])
# 返回一个8*8bool矩阵
return img_dct[0:8, 0:8] > img_mean[0]
def cal_dhash_code(self, cur_gray):
# dsize=(width, height)
m_img = cv2.resize(cur_gray, dsize=(9, 8))
m_img = np.int8(m_img)
# 得到8*8差值矩阵
m_img_diff = m_img[:, :-1] - m_img[:, 1:]
return m_img_diff > 0
def cal_hamming_distance(self, model_hash_code, search_hash_code):
# 返回不相同的个数
diff = model_hash_code != search_hash_code
return diff.sum()
def hash_in_warehouse(self, Warehouse, search_hash_code):
in_warehouse=False
index=-1
result=dict()
for i in range(len(Warehouse)):
dis=self.cal_hamming_distance(Warehouse[i],search_hash_code)
if dis<8:
in_warehouse=True
index=i
result[dis]=index
result=sorted(result.items())
if in_warehouse==False:
Warehouse.append(search_hash_code)
index=len(Warehouse)
else:
index=result[0][-1]
return Warehouse,in_warehouse,index
2、批量计算目录下文件的hash值,相近的hash值采用相同的编号
传入:目录,hash值库,储存结果的编号
返回:更新后的hash值库,各类hash值的出现的频率
def data_cacl(path_dir,Warehouse=[],type_counter=[]):
Hash=HashCacl()
#Warehouse 储存模板数据的hash
#type_counter 储存相应模板数据所出现的类型
count=0
for (root, dirs, files) in os.walk(path_dir):
if files:
for f in files:
bmp_path = os.path.join(root,f)
if 'hash' not in bmp_path and ('.bmp' in bmp_path):
bmp=cv2.imdecode(np.fromfile(bmp_path,dtype=np.uint8),0)
search_hash_code=Hash.cal_hash_code(bmp)
old_len=len(Warehouse)
Warehouse,in_warehouse,index=Hash.hash_in_warehouse(Warehouse, search_hash_code)
new_bmp_path=bmp_path.replace(f, 'hash%i_%s'%(index,f))
count+=1
if not in_warehouse:
type_counter.append(1)
os.rename(bmp_path, new_bmp_path)
if old_len!=len(Warehouse):
print(count,old_len,new_bmp_path)
else:
type_counter[index]+=1
#os.unlink(jpg_path)
#os.unlink(bmp_path)
os.rename(bmp_path, new_bmp_path)
if old_len!=len(Warehouse):
print(count,old_len,new_bmp_path)
return Warehouse,type_counter
3、调用函数,可视化hash值的分布
mport matplotlib.pyplot as plt
import matplotlib
# 设置matplotlib正常显示中文和负号
matplotlib.rcParams['font.sans-serif']=['SimHei'] # 用黑体显示中文
matplotlib.rcParams['axes.unicode_minus']=False # 正常显示负号
if __name__ == '__main__':
path=r"类别1"
Warehouse=[]
counter1=[]
Warehouse,OK_counter=data_cacl(path,Warehouse,counter1)
path=r"类别2"
counter2=[0]*len(counter1)
Warehouse2,NG_counter=data_cacl(path,Warehouse,counter2)
if True:
plt.figure(figsize=(16, 6), dpi=80)
#
plt.ylim(ymin =0, ymax=100)
plt.xlim(xmin =0, xmax=len(counter1))
plt.xticks(range(0, len(counter1), 50))
plt.bar(range(len(OK_counter)), OK_counter)
plt.show()
plt.figure(figsize=(28, 6), dpi=80)
plt.ylim(ymin =0, ymax=100)
plt.xlim(xmin =0, xmax=len(counter2))
plt.xticks(range(0, len(counter2), 50))
plt.bar(range(len(counter2)), counter2)
plt.show()