在运行unet网络的时候报出如下错误:
ValueError: Tensor conversion requested dtype int32 for Tensor with dtype int64: <tf.Tensor 'loss/activation_1_loss/lossFunc/ArgMax:0' shape=(?, ?) dtype=int64>
TypeError: Input 'y' of 'Equal' Op has type int64 that does not match type int32 of argument 'x'.
原因:根据错误提示,原因是类型不匹配导致的。
检查代码,锁定到classSelectors = K.argmax(true, axis=axis)这一行代码,查看类型:
classSelectors = K.argmax(true, axis=axis)
print(classSelectors.dtype)
所以发现错误就在这一行代码。
做如下修改:转换数据类型
classSelectors = tf.cast(classSelectors, tf.int32)
再次查看类型:
print(type(classSelectors))
print(classSelectors.dtype)
可以看出修改完成,再次运行代码,错误成功解决!
公众号:机器学习实战python