tensorflow学习1-查找元素

问题描述:

使用tensorflow会话机制,有一个存储index顺序的张量positions,一个存储元素类型的列表types,如果positions中的index对应的types中的元素是VIew就返回1,是Node就返回2。


解决方案:

import tensorflow as tf
# 如果positions中的index对应的types中的元素是VIew就返回1,是Node就返回2
# 假设 positions 是一个表示索引序列的张量
positions = tf.constant([0, 2, 1, 3])

types = ['View', 'Node', 'Node', 'View']

# 假设 shapes 是一个表示形状信息的张量,每一行表示一个长宽的二元组
shapes = tf.constant([[10, 20], [30, 40], [50, 60], [70, 80]])

# 获取 positions 中对应的 types
selected_types = tf.map_fn(lambda idx: tf.gather(types, idx), positions, dtype=tf.string)

# 自定义判断函数,根据 types 的值返回不同的结果
def custom_value(selected_type):
    is_view = tf.equal(selected_type, 'View')
    is_node = tf.equal(selected_type, 'Node')
    return tf.cond(is_view, lambda: 1, lambda: tf.cond(is_node, lambda: 2, lambda: 0))

# 对每个索引进行条件判断
result_values = tf.map_fn(lambda idx: custom_value(tf.gather(types, idx)), positions, dtype=tf.int32)

# 打印结果
with tf.Session() as sess:
    result = sess.run(result_values)
    print("positions 对应的 types 判断结果:", result)



结果:

positions 对应的 types 判断结果: [1 2 2 1]

代码讲解:

1、定义输入数据:
positions 是一个张量,包含了索引序列,代表了你想要进行判断的对象。
types 是一个列表,存储了与 positions 中索引对应的类型信息。
2、构建形状信息张量:
shapes 是一个张量,每一行表示一个长宽的二元组,这里使用张量形式更符合 TensorFlow 的运算。
这样做可以更好地与 TensorFlow 的张量运算方式匹配,使得后续的计算更加便捷。
3、获取类型信息:
tf.map_fn(lambda idx: tf.gather(types, idx), positions, dtype=tf.string) 用于根据 positions 中的索引获取对应的类型信息,并返回一个张量 selected_types,其中每个元素是 types 中对应索引的类型。
4、自定义判断函数:
custom_value 是一个自定义的判断函数,根据类型信息判断是 ‘View’ 还是 ‘Node’,并返回不同的值。在这里,我们使用了 TensorFlow 的条件运算 tf.cond 来实现根据不同条件返回不同的值。
5、应用自定义函数:
tf.map_fn(lambda idx: custom_value(tf.gather(types, idx), shapes), positions, dtype=tf.int32) 用于对 positions 中的每个索引应用自定义函数 custom_value,得到对应的结果值,并返回一个结果张量 result_values。/6、打印结果:
sess.run(result_values) 在 TensorFlow 会话中运行计算图,得到 result_values 的值,即每个索引对应的类型判断结果。
最后打印出这些结果,展示了每个索引对应的类型判断结果。


关键代码:

selected_types = tf.map_fn(lambda idx: tf.gather(types, idx), positions, dtype=tf.string)

tf.map_fn是 TensorFlow 中的一个函数,用于在张量的每个元素上执行相同的操作,并返回一个新的张量。它的作用类似于 Python 中的 map 函数,但是可以应用于 TensorFlow 的计算图中,支持并行处理。

具体来说,tf.map_fn 的语法如下:

tf.map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True, swap_memory=False, infer_shape=True, name=None)
  • fn 是一个函数,用于对 elems 中的每个元素执行操作。
  • elems 是一个张量,是要在其中的每个元素上应用 fn 函数的输入。
  • dtype 是输出张量的数据类型。
  • 其他参数用于控制并行计算等。

在上述代码中,tf.map_fn 用于将 positions 中的每个索引作为参数传递给一个匿名函数,然后通过 tf.gather 函数获取对应索引的 types 中的元素。这样就实现了对每个索引的类型信息进行提取,并返回一个包含这些类型信息的张量 selected_types。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ppdd·~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值