tensorflow做交叉验证遇到InvalidArgumentError

本文探讨了在TensorFlow中遇到的一个常见错误——在默认图上重复定义占位符导致的InvalidArgumentError。作者通过引入独立的图来解决该问题,确保每次训练都在全新的环境中进行,避免了变量重复定义的问题。
部署运行你感兴趣的模型镜像

原代码的逻辑是train函数构造图,并训练。val_train函数只负责切分训练集。跑代码之后遇到

InvalidArgumentError: You must feed a value for placeholder tensor '*' with dtype float

后来发现是因为每次train函数都是在default_graph上修改,所以两次train的调用,使得sess重复使用了其内部的变量,并且之前定义的placeholder也没有被feed进值。解决方法是使得每次train函数内部都在其新建的Graph中修改构造图。代码如下:

with tf.Graph().as_default():

详情请参考tf.Graph


补充一些我个人的理解:对于我需要交叉验证的问题而言,其实我想要的是每个训练都是在各自单独的图上进行的。Session是进行资源调度分配的模块;Session可以调用Graph,然后按照Graph的路线和输入的数据进行相应的数据流动和更改;Graph里面定义的各种Variable和Operator;如果Graph没有显示定义(如with tf.Graph().as_default(): 就是指在接下来的部分使用一个新的Graph),那都是在session上的默认Graph上操作,即各种定义Variable, Placeholder等操作。那我的情况是多次交叉验证,每次训练都是在默认图上做的操作(构建模型),结果定义了好多个Placeholder等,就遇到了这个问题。解决方式就是文中提到的方式了。
接触tensorflow时间不长,多多指教啦。

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

上述代码出现--------------------------------------------------------------------------- InvalidArgumentError Traceback (most recent call last) Cell In[8], line 7 5 # 训练模型 6 print("开始训练Dreamer模型...") ----> 7 trained_models = train_dreamer(X_red, X_blue, y_red, y_blue) 9 # 使用最后SEQ_LENGTH期数据预测下一期 10 last_red_seq = X_red[-1] Cell In[5], line 116, in train_dreamer(X_red, X_blue, y_red, y_blue) 111 true_indices = tf.cast(tf.where(red_labels > 0.5)[:, 1], DTYPE_INT) 112 # 计算命中数(明确数据类型) 113 hit_count = tf.reduce_sum( 114 tf.cast( 115 tf.reduce_any( --> 116 tf.equal( 117 tf.expand_dims(top6_pred, -1), 118 tf.expand_dims(true_indices, 1) 119 ), axis=-1 120 ), 121 DTYPE_FLOAT 122 ) 123 ) 125 batch_size = tf.shape(red_labels)[0] 126 total_samples += batch_size File ~\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\util\traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs) 151 except Exception as e: 152 filtered_tb = _process_traceback_frames(e.__traceback__) --> 153 raise e.with_traceback(filtered_tb) from None 154 finally: 155 del filtered_tb File ~\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\framework\ops.py:5883, in raise_from_not_ok_status(e, name) 5881 def raise_from_not_ok_status(e, name) -> NoReturn: 5882 e.message += (" name: " + str(name if name is not None else "")) -> 5883 raise core._status_to_exception(e) from None InvalidArgumentError: {{function_node __wrapped__Equal_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [64,6,1] vs. [384,1] [Op:Equal] name: ,请继续完善代码
最新发布
10-09
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值