深度学习 9 图像分布查看

该博客探讨了深度学习中图像特征分布的重要性,特别是如何通过分析图像的感知哈希值来平衡不同类别间的特征分布。文中提到,如果图像特征分布不平衡,可能导致模型过拟合。提供的代码实现了一个计算和分类图像感知哈希值的系统,用于检查轮廓特征在类别间的分布,并通过批量处理目录下文件的hash值,确保相近hash值的图像被赋予相同编号,从而调整类别平衡。最后,通过可视化展示不同类别中hash值的分布情况。
摘要由CSDN通过智能技术生成

         深度学习中的图像分布是指图像特征在不同类别间的分布统计,图像纹理特征是较难描述的的,借鉴机器学习中对图像的分类方法,图像特征有基于灰度或梯度统计的GLCM、LBP、灰度直方图、HOG、SIFT、感知hash值 等特征,也有基于颜色空间的RGB颜色分布(RGB值分区)、图像分块特征码等。

        通过分析可描述的特征在类别间的分布,可以平衡非类别决定特征在各个类别中的平衡,使模型在反向传播中不会取巧,尽可能学习到模型的本质特征。示例:如使用AI区分手机和充电宝,手机的照片都是竖着的,充电宝的照片都是横着的,如果不改善形态轮廓在类别间的分布,AI就会以为横着的就是充电宝,竖着的就是手机,这很容易导致过拟合。

        以下代码讲述了根据图像的感知hash值进行归类,查看其轮廓特征的分布状态。

核心思想:将hash值相近的数据当做一类数据,可以分析出各种轮廓数据在不同类别间的分布

1、构建计算hash值的类,并判断hash值是否已经出现过


import cv2,os
import numpy as np
 
class HashCacl:
    def cal_hash_code(self, cur_gray):
        s_img = cv2.resize(cur_gray, dsize=(8, 8))
        img_mean = cv2.mean(s_img)
        return s_img > img_mean[0]
    def cal_phash_code(self, cur_gray):
        # 缩小至32*32
        m_img = cv2.resize(cur_gray, dsize=(32, 32))
        # 浮点型用于计算
        m_img = np.float32(m_img)
        # 离散余弦变换,得到dct系数矩阵
        img_dct = cv2.dct(m_img)
        img_mean = cv2.mean(img_dct[0:8, 0:8])
        # 返回一个8*8bool矩阵
        return img_dct[0:8, 0:8] > img_mean[0]
 
    def cal_dhash_code(self, cur_gray):
        # dsize=(width, height)
        m_img = cv2.resize(cur_gray, dsize=(9, 8))
        m_img = np.int8(m_img)
        # 得到8*8差值矩阵
        m_img_diff = m_img[:, :-1] - m_img[:, 1:]
        return m_img_diff > 0
 
    def cal_hamming_distance(self, model_hash_code, search_hash_code):
        # 返回不相同的个数
        diff = model_hash_code != search_hash_code
        return diff.sum()
    def hash_in_warehouse(self, Warehouse, search_hash_code):
        in_warehouse=False
        index=-1
        result=dict()
        for i in range(len(Warehouse)):
            dis=self.cal_hamming_distance(Warehouse[i],search_hash_code)
            if dis<8:
               in_warehouse=True
               index=i
               result[dis]=index
        result=sorted(result.items())
        if in_warehouse==False:
            Warehouse.append(search_hash_code)
            index=len(Warehouse)
        else:
            index=result[0][-1]
        return Warehouse,in_warehouse,index

2、批量计算目录下文件的hash值,相近的hash值采用相同的编号

传入:目录,hash值库,储存结果的编号

返回:更新后的hash值库,各类hash值的出现的频率

def data_cacl(path_dir,Warehouse=[],type_counter=[]):
    Hash=HashCacl()
    #Warehouse    储存模板数据的hash
    #type_counter 储存相应模板数据所出现的类型
    count=0
    for (root, dirs, files) in os.walk(path_dir):
        if files:
            for f in files:
                bmp_path = os.path.join(root,f)
                if 'hash' not in bmp_path and ('.bmp' in bmp_path):
                    bmp=cv2.imdecode(np.fromfile(bmp_path,dtype=np.uint8),0)
                    search_hash_code=Hash.cal_hash_code(bmp)
                    old_len=len(Warehouse)
                    Warehouse,in_warehouse,index=Hash.hash_in_warehouse(Warehouse, search_hash_code)
                    new_bmp_path=bmp_path.replace(f, 'hash%i_%s'%(index,f))
                    count+=1
                    if not in_warehouse:
                        type_counter.append(1)
                        os.rename(bmp_path, new_bmp_path)
                        if old_len!=len(Warehouse):
                            print(count,old_len,new_bmp_path)
                    else:
                        type_counter[index]+=1
                        #os.unlink(jpg_path)
                        #os.unlink(bmp_path)
                        os.rename(bmp_path, new_bmp_path)
                        if old_len!=len(Warehouse):
                            print(count,old_len,new_bmp_path)
    return Warehouse,type_counter

3、调用函数,可视化hash值的分布

mport matplotlib.pyplot as plt
import matplotlib
# 设置matplotlib正常显示中文和负号
matplotlib.rcParams['font.sans-serif']=['SimHei']   # 用黑体显示中文
matplotlib.rcParams['axes.unicode_minus']=False     # 正常显示负号
if __name__ == '__main__':
    path=r"类别1"
    Warehouse=[]
    counter1=[]
    Warehouse,OK_counter=data_cacl(path,Warehouse,counter1)
    path=r"类别2"
    counter2=[0]*len(counter1)
    Warehouse2,NG_counter=data_cacl(path,Warehouse,counter2)
if True:   
    plt.figure(figsize=(16, 6), dpi=80)
    #
    plt.ylim(ymin =0, ymax=100)
    plt.xlim(xmin =0, xmax=len(counter1))
    plt.xticks(range(0, len(counter1), 50))
    plt.bar(range(len(OK_counter)), OK_counter)  
    plt.show()
    
    plt.figure(figsize=(28, 6), dpi=80)
    plt.ylim(ymin =0, ymax=100)
    plt.xlim(xmin =0, xmax=len(counter2))
    plt.xticks(range(0, len(counter2), 50))
    plt.bar(range(len(counter2)), counter2)  
    plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

万里鹏程转瞬至

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值