使用tf.data.Dataset和迭代器获取数据,训练总是提前结束。

 

 

这个问题困扰了我一个星期终于解决了。主要原因在于sess.run()的运用。

注意!!!sess.run()的下面两种方式:

sess.run([a, b])

和 

sess.run(a)
sess.run(b)

不一样!!!!尤其是在b与a有前后连接的时候,当他们有连接的时候,第一种方式是同时运行,也就是在同一步里运行,而后一种方式则是运行完a后,当a处在下一个状态时再运行b。而一旦b是由a得到的,那么在运行b的时候就是又运行了a的下一个状态。

在我的代码中,我用了可初始化迭代器和他的get_next() 方法来获取dataset管道中的数据,再将获取的数据投入模型进行训练,训练完之后想要print损失值,于是我单独sess.run(loss)。在这之前训练网络的时候肯定要对训练操作进行sess.run,这就相当于采用了上面第二种方式,导致的结果就是在一个step中,我多次sess.run()了迭代器,迭代器当然会跳跃着读取训练数据,导致训练步骤提前终止。

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值