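Problem description
When calling tf.keras.Model.fit(), you can pass the class_weight argument to weight the loss function during training, which tells the model to pay more attention to samples from under-represented classes.
Take the flower_photos.tgz dataset as an example. It is imbalanced in the strict sense, since each class has a different number of samples, although broadly speaking it is balanced because the proportions do not differ much.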
Class | Images |
---|---|
daisy | 633 |
dandelion | 898 |
roses | 641 |
sunflowers | 699 |
tulips | 799 |
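To make the effect of class_weight concrete, here is a minimal pure-Python sketch of the idea: each sample's loss is multiplied by the weight of its class before averaging. The helper name weighted_mean_loss and the loss values are hypothetical, and the exact normalization Keras applies internally is not reproduced here.

```python
def weighted_mean_loss(per_sample_loss, labels, class_weight):
    """Scale each sample's loss by the weight of its class, then average."""
    weighted = [
        loss * class_weight[label]
        for loss, label in zip(per_sample_loss, labels)
    ]
    return sum(weighted) / len(weighted)


# Hypothetical batch: three samples of class 0 and one of the rarer class 1.
losses = [0.2, 0.4, 0.6, 1.0]
labels = [0, 0, 0, 1]

unweighted = weighted_mean_loss(losses, labels, {0: 1.0, 1: 1.0})
weighted = weighted_mean_loss(losses, labels, {0: 1.0, 1: 3.0})

# The rare-class sample contributes three times as much in the weighted case.
# In Keras, such a dictionary would be passed as model.fit(..., class_weight={0: 1.0, 1: 3.0}).
print(unweighted, weighted)
```

With a weight of 3.0 on class 1, mistakes on that class contribute more to the averaged loss, so they have more influence on the gradient update.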
Solution
Counting the images in each class with tf.keras.preprocessing.image.ImageDataGenerator.flow_from_directory()
import numpy as np
import tensorflow as tf
directory = 'flower_photos'
datagen = tf.keras.preprocessing.image.ImageDataGenerator()
data = datagen.flow_from_directory(directory)
# data.classes holds one integer class index per image found
unique = np.unique(data.classes, return_counts=True)
# Map each class index to its image count
labels_dict = dict(zip(unique[0], unique[1]))
print(labels_dict)
# Found 3670 images belonging to 5 classes.
# {0: 633, 1: 898, 2: 641, 3: 699, 4: 799}
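Counting the images in each class with tf.keras.preprocessing.image_dataset_from_directory()
(Iterating over the dataset takes a long time because every image is loaded just to read its label. I have not found a fix for this yet.)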
import numpy as np
import tensorflow as tf
directory = 'flower_photos'
dataset = tf.keras.preprocessing.image_dataset_from_directory(directory)
# Each batch yields (images, labels); keep only the labels
classes = np.concatenate([y for x, y in dataset], axis=0)
unique = np.unique(classes, return_counts=True)
# Map each class index to its image count
labels_dict = dict(zip(unique[0], unique[1]))
print(labels_dict)
# Found 3670 files belonging to 5 classes.
# {0: 633, 1: 898, 2: 641, 3: 699, 4: 799}
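One possible way around the slow iteration, which is not from the original post, is to count the files on disk without decoding any image. The sketch below assumes one subdirectory per class, assigns class indices in sorted order of the subdirectory names (assumed to match the alphanumeric ordering Keras uses), and filters by an assumed list of image extensions. The names count_images_per_class and IMAGE_EXTENSIONS are hypothetical.

```python
from pathlib import Path

# Assumed set of image extensions; adjust it to match your data.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}


def count_images_per_class(directory):
    """Return {class_index: image_count}, indexing subdirectories in sorted order."""
    class_dirs = sorted(
        (p for p in Path(directory).iterdir() if p.is_dir()),
        key=lambda p: p.name,
    )
    labels_dict = {}
    for index, class_dir in enumerate(class_dirs):
        # Count image files anywhere below the class directory.
        labels_dict[index] = sum(
            1
            for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return labels_dict
```

Calling count_images_per_class('flower_photos') only lists files, so it should be much faster than iterating a dataset. The counts can differ from what Keras reports if the extension list or the directory layout does not match these assumptions.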
Computing the weight of each class
import math
def get_class_weight(labels_dict):
    """Compute a weight for each class based on its share of the dataset.
    >>> get_class_weight({0: 633, 1: 898, 2: 641, 3: 699, 4: 799})
    {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}
    >>> get_class_weight({0: 5, 1: 78, 2: 2814, 3: 7914})
    {0: 7.366950709511269, 1: 4.619679795255778, 2: 1.034026384271035, 3: 1.0}
    """
    total = sum(labels_dict.values())
    max_num = max(labels_dict.values())
    # mu is the share of the largest class, so mu * total is max_num
    # (up to floating-point rounding)
    mu = 1.0 / (total / max_num)
    class_weight = dict()
    for key, value in labels_dict.items():
        # Natural log of the ratio between the largest class and this class
        score = math.log(mu * total / float(value))
        # Never weight a class below 1.0
        class_weight[key] = score if score > 1.0 else 1.0
    return class_weight
labels_dict = {0: 633, 1: 898, 2: 641, 3: 699, 4: 799}  # balanced dataset
print(get_class_weight(labels_dict))
# {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}
labels_dict = {0: 5, 1: 78, 2: 2814, 3: 7914}  # imbalanced dataset
print(get_class_weight(labels_dict))
# {0: 7.366950709511269, 1: 4.619679795255778, 2: 1.034026384271035, 3: 1.0}
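In get_class_weight, total cancels out: mu * total is just the size of the largest class. Each weight is therefore the natural log of (largest class count / this class count), floored at 1.0. A class gets a weight above 1.0 only when it is more than e (about 2.72) times smaller than the largest class, which is why the flower dataset ends up with all weights at 1.0. The sketch below writes that simplification out; simplified_class_weight is a hypothetical name, and its results may differ from the original function in the last floating-point digits.

```python
import math


def simplified_class_weight(labels_dict):
    """Weight each class by ln(largest count / own count), never below 1.0."""
    max_num = max(labels_dict.values())
    return {
        key: max(1.0, math.log(max_num / value))
        for key, value in labels_dict.items()
    }


balanced = simplified_class_weight({0: 633, 1: 898, 2: 641, 3: 699, 4: 799})
imbalanced = simplified_class_weight({0: 5, 1: 78, 2: 2814, 3: 7914})
print(balanced)
print(imbalanced)
```

The resulting dictionary has the shape that the class_weight argument of tf.keras.Model.fit() expects, mapping class indices to float weights.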