python数据恢复开发_python – 恢复使用迭代器的Tensorflow模型

我有一个使用迭代器训练我的网络的模型;遵循Google现在推荐的新数据集API管道模型.

我读了tfrecord文件,将数据提供给网络,训练得很好,一切顺利,我在训练结束时保存了我的模型,所以我可以在以后运行推理.代码的简化版本如下:

""" Training and saving """

training_dataset = tf.contrib.data.TFRecordDataset(training_record)

training_dataset = training_dataset.map(ds._path_records_parser)

training_dataset = training_dataset.batch(BATCH_SIZE)

with tf.name_scope("iterators"):

training_iterator = Iterator.from_structure(training_dataset.output_types,training_dataset.output_shapes)

next_training_element = training_iterator.get_next()

training_init_op = training_iterator.make_initializer(training_dataset)

def train(num_epochs):

# compute for the number of epochs

for e in range(1,num_epochs+1):

session.run(training_init_op) #initializing iterator here

while True:

try:

images,labels = session.run(next_training_element)

session.run(optimizer,feed_dict={x: images,y_true: labels})

except tf.errors.OutOfRangeError:

saver_name = './saved_models/ucf-model'

print("Finished Training Epoch {}".format(e))

break

""" Restoring """

# restoring the saved model and its variables

session = tf.Session()

saver = tf.train.import_meta_graph(r'saved_models\ucf-model.meta')

saver.restore(session,tf.train.latest_checkpoint('.\saved_models'))

graph = tf.get_default_graph()

# restoring relevant tensors/ops

accuracy = graph.get_tensor_by_name("accuracy/Mean:0") #the tensor that when evaluated returns the mean accuracy of the batch

testing_iterator = graph.get_operation_by_name("iterators/Iterator") #my iterator used in testing.

next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext") #the GetNext operator for my iterator

# loading my testing set tfrecords

testing_dataset = tf.contrib.data.TFRecordDataset(testing_record_path)

testing_dataset = testing_dataset.map(ds._path_records_parser,num_threads=4,output_buffer_size=BATCH_SIZE*20)

testing_dataset = testing_dataset.batch(BATCH_SIZE)

testing_init_op = testing_iterator.make_initializer(testing_dataset) #to initialize the dataset

with tf.Session() as session:

session.run(testing_init_op)

while True:

try:

images,labels = session.run(next_testing_element)

accuracy = session.run(accuracy,feed_dict={x: test_images,y_true: test_labels}) #error here,x,y_true not defined

except tf.errors.OutOfRangeError:

break

我的问题主要是我恢复模型.如何将测试数据提供给网络?

>当我使用testing_iterator = graph.get_operation_by_name(“iterators / Iterator”),next_testing_element = graph.get_operation_by_name(“iterators / IteratorGetNext”)恢复我的Iterator时,出现以下错误:

GetNext()失败,因为迭代器尚未初始化.确保在获取下一个元素之前已为此迭代器运行初始化程序操作.

>所以我尝试使用以下方法初始化我的数据集:testing_init_op = testing_iterator.make_initializer(testing_dataset)).我收到此错误:AttributeError:’Operation’对象没有属性’make_initializer’

另一个问题是,由于正在使用迭代器,因此不需要在training_model中使用占位符,因为迭代器直接将数据提供给图形.但是这样,当我将数据提供给“准确度”操作时,如何在第3行到最后一行恢复我的feed_dict键?

编辑:如果有人可以建议在迭代器和网络输入之间添加占位符的方法,那么我可以尝试通过评估“准确性”张量来运行图形,同时将数据提供给占位符并完全忽略迭代器.

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值