tensorflow1.15之间tf.py_function和tf.py_func的区别

实验总结来看:

tf.py_function()

输入:EagerTensor

返回:必须是Tensor

tf.py_func()

输入:ndarray

返回:必须是ndarray

 

实验代码

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import numpy as np

print(tf.__version__)

def get_batch_data(lines):
    print(type(lines))
    #called by py_function
    #received an EgarTensor
    #and the returned value must be Tensor
    
    image_resized = tf.constant(3.0, dtype=tf.float32)
    label = tf.constant(2, dtype=tf.int32)
    return image_resized, label

def get_batch_data2(lines):
    print(type(lines))
    #called by py_func
    #received an ndarry
    #and the returned value must be ndarry
    image_resized = np.arange(2, dtype=np.float32)
    label = np.arange(2, dtype=np.int32)
    return image_resized, label

#tensorflow 1.x
t_file = ['train.txt']
t_size = 6
t_batch_size = 2

train_dataset = tf.data.TextLineDataset(t_file)
train_dataset = train_dataset.shuffle(t_size)
train_dataset = train_dataset.batch(t_batch_size)
train_dataset = train_dataset.repeat()

train_dataset = train_dataset.map(
    lambda x: tf.py_function(func=get_batch_data, inp=[x], Tout=[tf.float32, tf.int32]))
    #lambda x: tf.py_func(func=get_batch_data2, inp=[x], Tout=[tf.float32, tf.int32]))

print(train_dataset.output_types)
print(train_dataset.output_shapes)

iterator = tf.data.Iterator.from_structure(
    train_dataset.output_types,
    train_dataset.output_shapes)

next_element = iterator.get_next()

train_init_op = iterator.make_initializer(train_dataset)

with tf.Session()  as sess:
      sess.run(train_init_op)
      for _ in range(6):
        element = sess.run(next_element)
        #print(element)
        #print(element.shape)

使用tf.py_function()输出

(tf.float32, tf.int32)
(TensorShape(None), TensorShape(None))
<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'tensorflow.python.framework.ops.EagerTensor'>

使用tf.py_func()输出

(tf.float32, tf.int32)
(TensorShape(None), TensorShape(None))
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>

Reference:

https://github.com/tensorflow/docs/blob/r1.15/site/en/guide/datasets.md

https://blog.csdn.net/poson/article/details/106601539
https://blog.csdn.net/u012073033/article/details/89209894
https://blog.csdn.net/qq_27825451/article/details/105247211

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值