Error when training a self-defined triplet network with Keras
Traceback (most recent call last):
  File "train_similarity.py", line 52, in <module>
    main()
  File "train_similarity.py", line 48, in main
    **train_config)
  File "/data/wwjiang/project/captcha/general_baseline/similarity/src/network/frontend.py", line 228, in train
    max_queue_size=8)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1418, in fit_generator
    initial_epoch=initial_epoch)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/engine/training_generator.py", line 251, in fit_generator
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/callbacks.py", line 79, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "/data/wwjiang/project/captcha/general_baseline/similarity/src/network/frontend.py", line 57, in on_epoch_end
    metrics=["accuracy"])
  File "/data/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 342, in compile
    sample_weight, mask)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/engine/training_utils.py", line 404, in weighted
    score_array = fn(y_true, y_pred)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/losses.py", line 73, in sparse_categorical_crossentropy
    return K.sparse_categorical_crossentropy(y_true, y_pred)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 3347, in sparse_categorical_crossentropy
    logits = tf.reshape(output, [-1, int(output_shape[-1])])
TypeError: __int__ returned non-int (type NoneType)
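The message itself comes from Python's built-in int(). Assuming, as in TensorFlow 1.x, that a static dimension's __int__ simply returns its stored value (None when the size is unknown at graph-construction time), the failure can be reproduced without TensorFlow at all. FakeDimension and last_dim_as_int below are hypothetical stand-ins for illustration, not TensorFlow APIs:

```python
class FakeDimension:
    """Hypothetical stand-in for a TF 1.x static dimension (not the real class)."""

    def __init__(self, value):
        self._value = value  # None means "size unknown when the graph is built"

    def __int__(self):
        # Assumed TF 1.x behaviour: hand back the raw value, even when it is None.
        return self._value


def last_dim_as_int(static_shape):
    # Mirrors the failing Keras backend expression: int(output_shape[-1])
    return int(static_shape[-1])


print(last_dim_as_int([FakeDimension(None), FakeDimension(10)]))  # 10
try:
    last_dim_as_int([FakeDimension(None), FakeDimension(None)])
except TypeError as err:
    print(err)  # __int__ returned non-int (type NoneType)
```

A known last dimension converts cleanly; an unknown one makes int() receive None from __int__, which is exactly the TypeError at the bottom of the traceback.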
After some Googling and reading the source code, I found the cause:
In Keras 1.12 (TensorFlow backend), the relevant code is at lines 3345 and 3347 of /keras/backend/tensorflow_backend.py:
3345: output_shape = output.get_shape()
3347: logits = tf.reshape(output, [-1, int(output_shape[-1])])
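Here output_shape is the static shape, and its last dimension is unknown (None) at this point, so int() fails. The actual size is only available from the dynamic shape at run time, so line 3347 should be changed to the following (as in TensorFlow 1.14):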
logits = tf.reshape(output, [-1, tf.shape(output)[-1]])
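A minimal sketch contrasting the original and patched lines, assuming TensorFlow 2.x with eager execution rather than the TF 1.x / standalone Keras setup above. A tf.function whose input_signature leaves every dimension as None plays the role of the Keras output tensor; flatten_logits_static and flatten_logits_dynamic are hypothetical names:

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x. Every dimension is left unknown (None), standing in
# for a Keras output whose static last dimension could not be inferred.
SPEC = [tf.TensorSpec(shape=[None, None, None], dtype=tf.float32)]


@tf.function(input_signature=SPEC, autograph=False)
def flatten_logits_static(output):
    # Original backend logic: reads the *static* last dimension.
    # get_shape()[-1] is None while tracing, so int() raises TypeError.
    return tf.reshape(output, [-1, int(output.get_shape()[-1])])


@tf.function(input_signature=SPEC, autograph=False)
def flatten_logits_dynamic(output):
    # Patched logic: tf.shape() supplies the last dimension when the graph runs.
    return tf.reshape(output, [-1, tf.shape(output)[-1]])


x = tf.zeros([2, 3, 4])
print(tuple(flatten_logits_dynamic(x).shape))  # (6, 4)
try:
    flatten_logits_static(x)
except TypeError as err:
    print("static version failed:", err)
```

The static version fails while the function is being traced, before any data flows through it, which matches the traceback: the error is raised inside compile(), not during a training step.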
P.S.
This problem has been fixed in TensorFlow 1.14.
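x.get_shape()
returns the static shape (fixed when the graph is built; unknown dimensions are None)
tf.shape(x)
returns the dynamic shape (an int32 tensor whose values are known only at run time)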
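A small sketch of the difference, again assuming TensorFlow 2.x; inspect_shapes and the seen dict are illustrative names. The static shape is recorded once, while the function is traced, whereas tf.shape yields the real sizes on every call:

```python
import tensorflow as tf

seen = {}  # hypothetical holder for the static shape captured at trace time


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 5], dtype=tf.float32)],
             autograph=False)
def inspect_shapes(x):
    # Static shape: known while tracing; the batch dimension is None.
    seen["static"] = x.get_shape().as_list()
    # Dynamic shape: an int32 tensor evaluated each time the function runs.
    return tf.shape(x)


dynamic = inspect_shapes(tf.zeros([3, 5]))
print(seen["static"])             # [None, 5]
print(dynamic.numpy().tolist())   # [3, 5]
```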