tensorflow.data.Dataset的使用

在使用tensorflow编程时,如何将输入数据传入计算图中是一个需要重点关注的问题,但是tensorflow中提供了库函数,将输入进行了封装,而我们只需要调用函数接口即可。主要的库函数在tensorflow.data.Dataset中。

1.输入数据确定

import tensorflow as tf
x=np.array([[1],[2],[3],[4]])
y=np.array([[1],[2],[3],[4]])
x_placeholder=tf.placeholder(dtype=tf.int32)
y_placeholder=tf.placeholder(dtype=tf.int32)
dataset=tf.data.Dataset.from_tensor_slices((x,y))
def func(x,y):
    return x*1,y
dataset=dataset.map(func)
dataset=dataset.shuffle(2)
dataset=dataset.repeat()
dataset=dataset.batch(3)
iterator=dataset.make_initializable_iterator()
result_x,result_y=iterator.get_next()
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    # sess.run(iterator.initializer,feed_dict={x_placeholder:x,y_placeholder:y})
    sess.run(iterator.initializer)
    for _ in range(1):
        result =sess.run([result_x,result_y])
        print(result)
        result=sess.run([result_x,result_y])
        print(result)

下面将解释每个函数的用法。
1) from_tensor_slices(tensors)
根据输入的tensors创建dataset
2)repeat(count=None)
表示数据集循环遍历的次数,None表示无限循环,count=epoch表示遍历数据集epoch次
3)shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)
将数据集打乱,buffer_size意味着创建一个大小为buffer_size的缓冲区或者队列,每次从dataset中读取buffer_size个数据,并将缓冲区或者队列里的数据打乱再输出
4)batch(batch_size,drop_remainder=False)
从缓冲区中取出batch_size个数据
5)make_initializable_iterator(shared_name=None)
创建迭代器,通过迭代器来获取一个batch的数据

dataset = ...
iterator = dataset.make_initializable_iterator()
# ...
sess.run(iterator.initializer)

这里需要注意的是,在计算图上进行计算时,必须要执行迭代器的初始化!!!
6)get_next(name=None)
通过调用迭代器的get_next函数来得到一个batch的数据
7)map(map_func,num_parallel_calls=None)
map函数实现的功能是对输入数据进行处理,可以自行实现map_func,map_func函数的参数就是dataset的切片,输入参数个数和返回参数个数必须相等

2.输入数据不确定
如果输入数据不确定,则需要使用placehoder占位符。

import tensorflow as tf
x=np.array([[1],[2],[3],[4]])
y=np.array([[1],[2],[3],[4]])
x_placeholder=tf.placeholder(dtype=tf.int32)
y_placeholder=tf.placeholder(dtype=tf.int32)
dataset=tf.data.Dataset.from_tensor_slices((x,y))
def func(x,y):
    return x*1,y
dataset=dataset.map(func)
dataset=dataset.shuffle(2)
dataset=dataset.repeat()
dataset=dataset.batch(3)
iterator=dataset.make_initializable_iterator()
result_x,result_y=iterator.get_next()
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    sess.run(iterator.initializer,feed_dict={x_placeholder:x,y_placeholder:y})
    # sess.run(iterator.initializer)
    for _ in range(1):
        result =sess.run([result_x,result_y])
        print(result)
        result=sess.run([result_x,result_y])
        print(result)

sess.run(iterator.initializer时,就把输入数据传入就可以了

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值