从实验结果总结:
tf.py_function()
输入:EagerTensor
返回:Tensor(也可以返回能被 convert_to_tensor 转换的值,例如 numpy 数组)
tf.py_func()
输入:numpy ndarray
返回:numpy ndarray
实验代码
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import numpy as np
# Print the installed TensorFlow version: the experiment below relies on
# 1.x-only APIs (tf.Session, tf.py_func, Dataset.output_types).
print(tf.__version__)
def get_batch_data(lines):
    """Dummy batch parser used as the ``func`` of ``tf.py_function``.

    Under ``tf.py_function`` the argument arrives as an ``EagerTensor``; its
    type is printed so the experiment can show what the Python side sees.
    The batch content is ignored. Two fixed scalar tensors are returned
    whose dtypes match ``Tout=[tf.float32, tf.int32]``.

    NOTE(review): the original comment said the result *must* be a Tensor.
    ``tf.py_function`` appears to pass results through ``convert_to_tensor``,
    so NumPy arrays may also be accepted. Confirm for the TF version in use.
    """
    print(type(lines))
    dummy_label = tf.constant(2, dtype=tf.int32)
    dummy_image = tf.constant(3.0, dtype=tf.float32)
    return dummy_image, dummy_label
def get_batch_data2(lines):
    """Dummy batch parser used as the ``func`` of ``tf.py_func``.

    ``tf.py_func`` hands the argument over as a NumPy ``ndarray``; its type
    is printed so the experiment can show what the Python side sees.
    NumPy arrays are returned, as ``tf.py_func`` expects. The batch content
    is ignored, and the function always yields ``[0., 1.]`` as float32 and
    ``[0, 1]`` as int32, matching ``Tout=[tf.float32, tf.int32]``.
    """
    print(type(lines))
    float_range, int_range = (
        np.arange(2, dtype=kind) for kind in (np.float32, np.int32)
    )
    return float_range, int_range
# TensorFlow 1.x graph-mode experiment: feed batches of text lines through a
# Python function wrapped by either tf.py_function (default) or tf.py_func,
# and print what type the Python side receives.
# Set USE_PY_FUNC = True to run the tf.py_func variant (get_batch_data2)
# instead of editing a commented-out lambda.
USE_PY_FUNC = False

t_file = ['train.txt']  # input text file, read line by line (relative path)
t_size = 6              # shuffle buffer size
t_batch_size = 2        # number of lines per batch handed to the parser


def _wrap_batch(lines):
    """Run the selected Python parser on one batch of text lines."""
    if USE_PY_FUNC:
        # tf.py_func: the Python function receives and returns numpy ndarrays.
        return tf.py_func(func=get_batch_data2, inp=[lines],
                          Tout=[tf.float32, tf.int32])
    # tf.py_function: the Python function receives EagerTensors.
    return tf.py_function(func=get_batch_data, inp=[lines],
                          Tout=[tf.float32, tf.int32])


train_dataset = (
    tf.data.TextLineDataset(t_file)
    .shuffle(t_size)
    .batch(t_batch_size)
    .repeat()  # endless, so the loop below can pull any number of batches
    .map(_wrap_batch)
)
# Python-wrapped ops give no static shape info, hence TensorShape(None) below.
print(train_dataset.output_types)
print(train_dataset.output_shapes)

# Reinitializable iterator built from the dataset's structure (TF 1.x API).
iterator = tf.data.Iterator.from_structure(
    train_dataset.output_types,
    train_dataset.output_shapes)
next_element = iterator.get_next()
train_init_op = iterator.make_initializer(train_dataset)

with tf.Session() as sess:
    sess.run(train_init_op)
    # Each sess.run pulls one batch through the map. The recorded output
    # shows one print(type(lines)) line per iteration.
    for _ in range(6):
        element = sess.run(next_element)
使用tf.py_function()输出
(tf.float32, tf.int32)
(TensorShape(None), TensorShape(None))
<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'tensorflow.python.framework.ops.EagerTensor'>
使用tf.py_func()输出
(tf.float32, tf.int32)
(TensorShape(None), TensorShape(None))
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
Reference:
https://github.com/tensorflow/docs/blob/r1.15/site/en/guide/datasets.md
https://blog.csdn.net/poson/article/details/106601539
https://blog.csdn.net/u012073033/article/details/89209894
https://blog.csdn.net/qq_27825451/article/details/105247211