一、keras class_weight和sample_weight的区别
class_weight:对训练集中的每个类别加一个权重,在model.fit()中设置例如class_weight={0:1,1:10}
sample_weight:对每个样本加权,当数据源类型为tf.data.dataset数据集时,使用model.fit函数不支持sample_weight参数,需在处理样本集时处理。
使用class_weight就会改变loss范围,这样可能会导致训练的稳定性。当Optimizer中的step_size与梯度的大小相关时,将会出现问题。而类似Adam等优化器则不受影响。另外,使用class_weight后的模型的loss大小不能和不使用时做比较。
当label有其他数值时,例如使用class_weight={0:1,1:10,2:20},有可能交叉熵loss值出现负值。
二、tf.data.experimental.make_csv_dataset怎样使用sample_weight
可以在dataset=tf.data.experimental.make_csv_dataset()读入数据后,使用如下代码处理,处理后的dataset中的数据为元组(feature,label,sample_weight),model.fit()支持这种3元组tf.data.dataset数据集类型作为输入。
def sample_weight(y, dict_w):
return tf.where(tf.equal(y, 0), dict_w[0], tf.where(tf.equal(y, 1), dict_w[1], dict_w[2]))
dataset = dataset.map(lambda x, y: (x,y,sample_weight(y, {0:1,1:10,2:20})))