Python计算数据集不同类别的权重class_weight

37 篇文章 1 订阅

问题描述

调用 tf.keras.Model.fit() 时可以带入参数 class_weight,对训练时的损失函数进行加权,用于告诉模型更需要关注某些代表性不足的类别的样本

如数据集 flower_photos.tgz

数据集是不平衡的,即每个类的样本数不同(广义来说是平衡的,比例偏差不大)

分类图片数
daisy633
dandelion898
roses641
sunflowers699
tulips799




解决方案

统计不同分类的图片数量 tf.keras.preprocessing.image.ImageDataGenerator.flow_from_directory()

import numpy as np
import tensorflow as tf

directory = 'flower_photos'
datagen = tf.keras.preprocessing.image.ImageDataGenerator()
data = datagen.flow_from_directory(directory)
unique = np.unique(data.classes, return_counts=True)
labels_dict = dict(zip(unique[0], unique[1]))
print(labels_dict)
# Found 3670 images belonging to 5 classes.
# {0: 633, 1: 898, 2: 641, 3: 699, 4: 799}

统计不同分类的图片数量 tf.keras.preprocessing.image_dataset_from_directory() (遍历 dataset 时耗时很久,暂时找不到解决方案)

import numpy as np
import tensorflow as tf

directory = 'flower_photos'
dataset = tf.keras.preprocessing.image_dataset_from_directory(directory)
classes = np.concatenate([y for x, y in dataset], axis=0)
unique = np.unique(classes, return_counts=True)
labels_dict = dict(zip(unique[0], unique[1]))
print(labels_dict)
# Found 3670 files belonging to 5 classes.
# {0: 633, 1: 898, 2: 641, 3: 699, 4: 799}

计算不同类别权重

import math


def get_class_weight(labels_dict):
    """计算数据集不同类别的占比权重

    >>> get_class_weight({0: 633, 1: 898, 2: 641, 3: 699, 4: 799})
    {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}
    >>> get_class_weight({0: 5, 1: 78, 2: 2814, 3: 7914})
    {0: 7.366950709511269, 1: 4.619679795255778, 2: 1.034026384271035, 3: 1.0}
    """
    total = sum(labels_dict.values())
    max_num = max(labels_dict.values())
    mu = 1.0 / (total / max_num)
    class_weight = dict()
    for key, value in labels_dict.items():
        score = math.log(mu * total / float(value))
        class_weight[key] = score if score > 1.0 else 1.0
    return class_weight


labels_dict = {0: 633, 1: 898, 2: 641, 3: 699, 4: 799}  # 平衡数据集
print(get_class_weight(labels_dict))
# {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}
labels_dict = {0: 5, 1: 78, 2: 2814, 3: 7914}  # 不平衡数据集
print(get_class_weight(labels_dict))
# {0: 1.0, 1: 3.749820767859636, 2: 1.0, 3: 3.749820767859636, 4: 1.0, 5: 2.5931008483842453, 6: 1.0, 7: 2.5931008483842453}




参考文献

  1. How to set class weights for imbalanced classes in Keras?
  2. Get unique values in a list of numpy arrays
  3. NumPy: Count the frequency of unique values in numpy array
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

XerCis

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值