【炼丹随记】样本不均衡时,class_weight 的计算

1 语义分割时的样本均衡

1.1 计算方法

语义分割时,如果样本比例失衡,就需要设置 class_weight来平衡损失,那么该如何计算呢?

  • 直观的想到是,先获取图片的每个类别的像素点的个数之间的比例,然后用1去除以。比如:
    class1 : class2 : class3 = 100 : 10 : 1,那么 weight1 : weight2 : weight3 = 1:10:100。但这个比值偏差太大,放到loss中训练并不能得到一个好的结果。

  • OK的操作:
    一个比较不错的计算方式:获取每个类别的像素的总值,数值大的类别应该有偏小的权重(注意这里是偏小,小一点点,不能说是小的离谱)。所以需要一个递减且 当自变量增长很大时因变量依然递减 且递减缓慢的函数,于是使用 1/log(x),在变量 x 非常大时,1/log(x) 符合递减且递减的很缓慢的特征。这样就可以得到合适的权重用于神经网络的训练。
    方法来源于【https://github.com/openseg-group/OCNet.pytorch/issues/14】。

    这个过程我们需要知道的信息:

    • 类别数为 class_num
    • 每个类别的的像素点的个数为 pixel_count:分别统计训练集中所有不同类别的像素个数

    计算的代码为:

    import numpy as np
    
    """========权重的计算============"""
    def get_weight(class_num, pixel_count):
    
        W = 1 / np.log(pixel_count)
    
        #除以np.sum(W),是将权重归一化,让每个类别的权重相加为1。
        #乘以class_num,是为了让权重中的每个类别的权重值接近1,使网络在一个正常的水平上进行训练。
        W = class_num * W/np.sum(W)  
     
        return W
    
    if __name__ == "__main__":
    
        """========测试============"""
        base = 5000
        pixel_count = np.array([100, 10, 1])*base  
        W = get_weight(3, pixel_count)
    
        print(W)  # [0.79925325 0.96934441 1.23140234]
    

    此时,class1 : class2 : class3 = 100 : 10 : 1,它们的权重设置为weight1 : weight2 : weight3 = 0.79925325 : 0.96934441 : 1.23140234。


1.2 完整代码

这里同时统计图片的均值、方差、权重。实现过程需要注意:

  • opencv读取出来的是 BGR,注意神经网络中的均值方差是 BGR or RGB ?
  • 我们准备数据时,如果图片尺寸大小不一,计算权重时,应将图片处理成同尺寸,然后再进行计算。我这里图像的尺寸不同但比例相同,所以使用了简单的resize。resize时要使用最近邻插值,不要给标签带来新的数值
  • 读取label.png时,确保是正确读取单通道图片,并且像素正确


代码如下:

from random import shuffle
import numpy as np
import os
import cv2

def get_weight(class_num, pixel_count):
    W = 1 / np.log(pixel_count)
    W = class_num * W / np.sum(W)
    return W

def get_MeanStdWeight(class_num=12, size=(640,360)):

    image_path = "../datasets/data/train/"  ## 训练输入的color图片
    label_path = "../datasets/label/train/"   ## 训练标签的mask图片
    
    namelist = os.listdir(image_path)
    """========如果提供的是txt文本,保存的训练集中的namelist=============="""
    # file_name = "../datasets/train.txt"
    # with open(file_name,"r") as f:
    #     namelist = f.readlines()
    # namelist = [file[:-1].strip() for file in namelist]    ## 解析出对应的namelist就可以
    """==============================================================="""

    MEAN = []
    STD = []
    pixel_count = np.zeros((class_num,1))

    for i in range(len(namelist)):
        print(i, os.path.join(image_path, namelist[i]))

        image = cv2.imread(os.path.join(image_path, namelist[i]))[:,:,::-1]
        image = cv2.resize(image, size, interpolation=cv2.INTER_NEAREST)
        print(image.shape)

        mean = np.mean(image, axis=(0,1))
        std = np.std(image, axis=(0,1))
        MEAN.append(mean)
        STD.append(std)

        label = cv2.imread(os.path.join(label_path, namelist[i]), 0)
        label = cv2.resize(label, size, cv2.INTER_LINEAR)

        label_uni = np.unique(label)
        for m in label_uni:
            pixel_count[m] += np.sum(label == m)

    MEAN = np.mean(MEAN, axis=0) / 255.0
    STD = np.mean(STD, axis=0) / 255.0

    weight = get_weight(class_num, pixel_count.T)
    print(MEAN)
    print(STD)
    print(weight)

    return MEAN, STD, weight

2 对于目标检测的class_weight的计算

2024.1.3 补充。顺手把目标检测的class_weight的计算 贴上:

import numpy as np
from a0_base import *

def get_weight(class_num, pixel_count):
    W = 1 / np.log(pixel_count)
    W = class_num * W / np.sum(W)
    return W

def get_MeanStdWeight(file_name, class_num):
    with open(file_name,"r") as f:
        namelist = f.readlines()
    namelist = [file.strip() for file in namelist]

    BOX_count = np.zeros((class_num,1))
    for i, name in enumerate(namelist):
        txtpath = name.replace("/images/", "/labels/").replace(".jpg", ".txt")
        print(i, txtpath)

        label = np.loadtxt(txtpath).reshape(-1,5)
        label_uni = np.unique(label[:,0])
        for m in label_uni:
            BOX_count[int(m)] += np.sum(label == m)

    weight = get_weight(class_num, BOX_count.T)
    print(weight)

    return weight

file_name = "../train.txt"
get_MeanStdWeight(file_name, len(CLASSNAME))
  • 3
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值