Keras使用tf.where()时由于维度顺序不对报错
2021-12-10 21:38:30.064007: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:801]
layout failed: Invalid argument: Size of values 0 does not match size of permutation 4 @ fanin shape
inmodel/tf.where/SelectV2-1-TransposeNHWCToNCHW-LayoutOptimizer
例如我希望将model的输出<=0.5的区域置零,这一部分不再参加Loss 的计算
若不将tensor的维度从(N,H,W,C)转换为(N,C,H,W)那么就会出现上面展示的报错信息
epsilon = backend_config.epsilon
outputs_1 = tf.multiply(conv10_1,inputs_mask1)#+epsilon()
outputs_2 = tf.multiply(conv10_2,inputs_mask2)#+epsilon()
temp_mask = outputs_1<=0.5
# temp_mask = tf.transpose(temp_mask,[0,3,1,2])
temp_mask = tf.where(temp_mask,0.0,1.0)
# temp_mask = tf.transpose(temp_mask,[0,2,3,1])
outputs_1 = tf.multiply(outputs_1,temp_mask)+epsilon()
temp_mask = outputs_2<=0.5
# temp_mask = tf.transpose(temp_mask,[0,3,1,2])
temp_mask = tf.where(temp_mask,0.0,1.0)
# temp_mask = tf.transpose(temp_mask,[0,2,3,1])
outputs_2 = tf.multiply(outputs_2,temp_mask)+epsilon()
model = Model(inputs = [inputs,inputs_mask1,inputs_mask2], outputs = [outputs_1,outputs_2])
将代码块中tf.transpose的部分取消注释再运行就不会报错了。