U-Net defect detection: the pixel-level sample imbalance problem
The U-Net architecture is shown in the figure above.
Alternatively, design your network following the paper Automatic Metallic Surface Defect Detection and Recognition with Convolutional Neural Networks.
Current model input: input_image 1 x 512 x 512 x 1, input_label 1 x 512 x 512
Current model output: prediction 1 x 512 x 512 x 2
In general, a captured image of the metallic surface has more background pixels than defective pixels.
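With balanced samples, positive:negative ≈ 1 : 3. Defect detection often faces sample imbalance, where positive:negative != 1 : 3 and defect pixels are far outnumbered by background pixels.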
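To see how far a given label mask is from that ratio, the pixel counts can be tallied directly. The sketch below is a minimal illustration in plain Python rather than TensorFlow; `class_ratio` and the toy 4 x 4 mask are hypothetical, and label 1 is assumed to mark defect pixels.

```python
def class_ratio(mask):
    """Return (positive_count, negative_count) for a 2-D mask of 0/1 labels."""
    positive = sum(1 for row in mask for value in row if value == 1)
    total = sum(len(row) for row in mask)
    return positive, total - positive

# Toy 4 x 4 label mask (hypothetical data): one defect pixel, fifteen background pixels.
toy_mask = [
    [0, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]

positive, negative = class_ratio(toy_mask)
if positive > 0:
    print("positive:negative = 1 : %.1f" % (negative / positive))
else:
    print("no positive pixels in this mask")
```

Running the same count over a real 512 x 512 label map gives a quick check of how imbalanced the data actually is before choosing a loss.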
Two ways to address the sample imbalance problem:
solution-1 focal loss
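import tensorflow as tf

def focal_loss(labels, logits, gamma=2, alpha=0.25):
    # labels: integer class ids; logits: per-class predictions.
    # The clipping and log below only make sense if `logits` already holds
    # softmax probabilities, despite the parameter name.
    # NOTE: alpha is kept in the signature but is not applied below.
    labels = tf.one_hot(labels, depth=2)
    # avoid log(0)
    epsilon = 1.e-3
    probs = tf.clip_by_value(logits, epsilon, 1. - epsilon)
    # cross-entropy
    cross_entropy = -labels * tf.math.log(probs)
    # focal loss: cross-entropy scaled by the modulating factor (1 - p)^gamma
    focal = tf.pow(1 - probs, gamma) * cross_entropy
    # sum over classes, then average over pixels
    loss = tf.reduce_mean(tf.reduce_sum(focal, axis=-1))
    return loss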
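The effect of the modulating factor (1 - p)^gamma is easiest to see on single pixels. The following is a plain-Python sketch, not a replacement for the TensorFlow function: `pixel_focal_loss` is a hypothetical helper, `p_true` is assumed to be the predicted probability of the pixel's true class, and the optional `alpha_t` argument shows one way a per-class balancing weight could be applied.

```python
import math

def pixel_focal_loss(p_true, gamma=2.0, alpha_t=1.0, epsilon=1e-3):
    """Focal loss for one pixel, given the predicted probability of its true class."""
    # Clip to avoid log(0), mirroring the epsilon used in the post's code.
    p = min(max(p_true, epsilon), 1.0 - epsilon)
    cross_entropy = -math.log(p)
    # alpha_t is an assumed per-class weight; 1.0 leaves the loss unweighted.
    return alpha_t * (1.0 - p) ** gamma * cross_entropy

easy = pixel_focal_loss(0.9)  # well-classified pixel
hard = pixel_focal_loss(0.1)  # badly misclassified pixel

print("easy pixel: CE=%.4f focal=%.4f" % (-math.log(0.9), easy))
print("hard pixel: CE=%.4f focal=%.4f" % (-math.log(0.1), hard))
```

With gamma = 2, the easy pixel's cross-entropy is scaled by (1 - 0.9)^2 = 0.01 while the hard pixel's is scaled by (1 - 0.1)^2 = 0.81, so the many easy background pixels contribute far less to the total loss.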
solution-2 re-weight the imbalanced classes
This follows the approach of the paper Automatic Metallic Surface Defect Detection and Recognition with Convolutional Neural Networks.
import tensorflow as tf

def imbalance_cross_entropy(labels, logits):
    # labels: integer class ids (1 x 512 x 512); logits: per-class predictions
    # (1 x 512 x 512 x 2), assumed to be softmax probabilities already.
    labels_one_hot = tf.one_hot(labels, depth=2)
    # avoid log(0)
    epsilon = 1.e-3
    probs = tf.clip_by_value(logits, epsilon, 1. - epsilon)
    # per-pixel cross-entropy
    cross_entropy = -labels_one_hot * tf.math.log(probs)
    cross_entropy_sum = tf.reduce_sum(cross_entropy, axis=-1)
    # imbalance weight: 0.1 for pixels labelled 1, 0.9 for pixels labelled 0
    # (built from the runtime shape of labels instead of a hard-coded 1 x 512 x 512)
    weight01 = tf.fill(tf.shape(labels), 0.1)
    weight09 = tf.fill(tf.shape(labels), 0.9)
    labels_imbalance = tf.where(tf.equal(labels, 1), weight01, weight09)
    # weighted cross-entropy, averaged over pixels
    loss = cross_entropy_sum * labels_imbalance
    imbalance_cross_entropy_loss = tf.reduce_mean(loss)
    return imbalance_cross_entropy_loss
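The re-weight parameters used here are 0.1 and 0.9: pixels labelled 1 are weighted 0.1 and pixels labelled 0 are weighted 0.9. The larger weight should fall on the rarer class, so swap the two values if label 1 marks defect pixels in your data.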
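How the two weights change each class's share of the loss can be checked by hand. The sketch below is plain Python with made-up numbers: `weighted_cross_entropy` is a hypothetical helper, each sample is a (label, probability of the true class) pair, and the weights are passed as a dict keyed by label so that either assignment can be tried.

```python
import math

def weighted_cross_entropy(samples, class_weight, epsilon=1e-3):
    """Mean of class_weight[label] * -log(p_true) over (label, p_true) pairs."""
    total = 0.0
    for label, p_true in samples:
        # Clip to avoid log(0), mirroring the epsilon used in the post's code.
        p = min(max(p_true, epsilon), 1.0 - epsilon)
        total += class_weight[label] * -math.log(p)
    return total / len(samples)

# Hypothetical toy batch: nine pixels of label 0 and one pixel of label 1,
# all predicted with probability 0.6 for their true class.
samples = [(0, 0.6)] * 9 + [(1, 0.6)]

unweighted = weighted_cross_entropy(samples, {0: 1.0, 1: 1.0})
as_in_text = weighted_cross_entropy(samples, {0: 0.9, 1: 0.1})
swapped = weighted_cross_entropy(samples, {0: 0.1, 1: 0.9})
print(unweighted, as_in_text, swapped)
```

In this toy batch, with weights {0: 0.1, 1: 0.9} the single label-1 pixel contributes as much to the loss as the nine label-0 pixels combined (0.9 versus 9 x 0.1), which is the balancing effect that re-weighting aims for.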
Original article: https://blog.csdn.net/Frank_Zer/article/details/112363825