学习TensrFlow 2 的随笔(四)Dataset.map()的使用

上一篇初步说了Dataset中的一些问题,这里还要记录一下Dataset.map()中的一些特别容易出问题的东西。学习TensrFlow 2 的随笔(三)tf.data.Dataset

  • 1、怎么在映射函数中进行tf.Tensor的类型转换?
    在映射函数里,一般都要求进行TensorFlow的操作。此时如果想将Tensor的类型进行转换,比如想将bool类型转为float。这样就可以直接将掩膜转为可以计算的Tensor.
    mask=tf.image.convert_image_dtype(mask,tf.float32)
    一般进行数据类型的转换,这个函数可以进行。但是它只可以进行以下类型间的转换:bfloat16, half, float, double, uint8, int8, uint16, int16, int32, uint32, uint64, int64, complex64, complex128。可以发现没有bool,所以不可以。
    mask=tf.cast(mask,tf.float32)
    这个函数才可以实现转换,因为它有bool类型。可将将True转为1,False转为0输出。
  • 2、映射函数中就是需要进行python的逻辑计算该怎么办?
    答案:使用tf.py_function()函数
    比如想将图片进行任意角度旋转,这个在tf.image没法进行。那就采用python函数来进行。例子:
    首先定义一个图像旋转函数,这个函数是纯python的,跟TensorFlow运算没有任何关系。
import scipy.ndimage as ndimage
def random_rotate_image(image):
  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
  return image

然后,利用tf.py_function()把这个函数包起来,放在一个映射函数里。特别注意的是使用这个函数时要明确表示出返回的形状shapes和类型types。原因和上一篇里说的一样,tf.Graph需要边缘。一定要对返回进行shape定义,就是set_shape(tf.TensorShape([None,None,3]))这一步一定要有。

def tf_random_rotate_image(image, label):
  im_shape = image.shape
  [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])
  image.set_shape(im_shape)
  return image, label

最后将映射函数放入到map中:

rot_ds = images_ds.map(tf_random_rotate_image)
for image, label in rot_ds.take(2):
  show(image, label)

参考:
1.tf.data
2.tf.py_function

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值