关于keras训练出现:TypeError: __int__ returned non-int (type NoneType)

采用keras训练自己定义的triplet时出现报错

Traceback (most recent call last):
  File "train_similarity.py", line 52, in <module>
    main()
  File "train_similarity.py", line 48, in main
    **train_config)
  File "/data/wwjiang/project/captcha/general_baseline/similarity/src/network/frontend.py", line 228, in train
    max_queue_size=8)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1418, in fit_generator
    initial_epoch=initial_epoch)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/engine/training_generator.py", line 251, in fit_generator
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/callbacks.py", line 79, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "/data/wwjiang/project/captcha/general_baseline/similarity/src/network/frontend.py", line 57, in on_epoch_end
    metrics=["accuracy"])
  File "/data/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 342, in compile
    sample_weight, mask)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/engine/training_utils.py", line 404, in weighted
    score_array = fn(y_true, y_pred)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/losses.py", line 73, in sparse_categorical_crossentropy
    return K.sparse_categorical_crossentropy(y_true, y_pred)
  File "/data/anaconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 3347, in sparse_categorical_crossentropy
    logits = tf.reshape(output, [-1, int(output_shape[-1])])
TypeError: __int__ returned non-int (type NoneType)

通过一通google,查源码发现问题所在:
keras 1.12版本中/keras/backend/tensorflow_backend.py文件的3347行

3345: output_shape = output.get_shape()
3347: logits = tf.reshape(output, [-1, int(output_shape[-1])])

而由于动态维度相对静态维度发生改变,应将3347行改为(tenforflow1.14):

logits = tf.reshape(output, [-1, tf.shape(output)[-1]])

ps:这个问题在tensorflow1.14中已修正。
tf.get_shape()获取静态维度
tf.shape 获取动态维度

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值