踩坑 -- 程序run卡住,变量必须定义在初始化之前

我的错误代码是这样的:

# 获取批数据
def get_Batch(data, label, batch_size):
    input_queue = tf.train.slice_input_producer([data, label], shuffle=True)
    x_batch, y_batch = tf.train.batch(input_queue, batch_size=batch_size)
    return x_batch,y_batch

# 5、创建会话
with tf.Session() as sess:
    train_loss = []
    # 全局初始化
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # 启动队列线程
    coord = tf.train.Coordinator()
    thread = tf.train.start_queue_runners(sess, coord)  # tensor转换成numpy
    # 分批训练,每次训练batch_num个数据
    for i in range(1, epoch_num+1):
        # 获取批数据
        feature_batch, label_batch = get_Batch(train_data, train_label, batch_num)
        feature_batch, label_batch = sess.run([feature_batch, label_batch])
        # 训练
        _,t_loss = sess.run([trainer_func,loss_func],feed_dict={X:train_data,Y:train_label})
        train_loss.append(t_loss)
        print("epoch:%d,loss:%.4g"%(i,t_loss))
    coord.request_stop()
    coord.join(thread)

程序运行到sess.run([feature_batch, label_batch])就无响应了,原因是,feature_batch, label_batch = get_Batch(train_data, train_label, batch_num)这一行是变量定义,在run之前必须初始化,我把计算图只是定义,只有run的时候,才会执行这个概念搞混了,以为我要每次迭代都获取新数据,必须把它放在for里面,变量定义只要放在全部初始化之前就好。修改后的代码如下:

# 获取批数据
def get_Batch(data, label, batch_size):
    input_queue = tf.train.slice_input_producer([data, label], shuffle=True)
    x_batch, y_batch = tf.train.batch(input_queue, batch_size=batch_size)
    return x_batch,y_batch

# 获取批数据
feature_batch, label_batch = get_Batch(train_data, train_label, batch_num)
    
# 5、创建会话
with tf.Session() as sess:
    train_loss = []
    # 全局初始化
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # 启动队列线程
    coord = tf.train.Coordinator()
    thread = tf.train.start_queue_runners(sess, coord)  # tensor转换成numpy
    # 分批训练,每次训练batch_num个数据
    for i in range(1, epoch_num+1):
        feature_batch, label_batch = sess.run([feature_batch, label_batch])
        # 训练
        _,t_loss = sess.run([trainer_func,loss_func],feed_dict={X:train_data,Y:train_label})
        train_loss.append(t_loss)
        print("epoch:%d,loss:%.4g"%(i,t_loss))
    coord.request_stop()
    coord.join(thread)

记录一下愚蠢的自己~

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值