前言
Ragged Tensor(不规则张量) 是一种特殊的多维数据结构,用于处理维度长度不固定的数据(例如变长序列)。与常规张量(所有维度长度相同)不同,Ragged Tensor允许不同维度(如行、列)的元素数量不一致,特别适合处理自然语言、时间序列等变长数据。
应用场景
自然语言处理(NLP):处理变长句子(如分词后的文本:[[“Hello”, “world”], [“Hi”]])。
时间序列分析:不同时间步长的序列(如传感器数据)。
图结构数据:每个节点的邻居数量不同(如社交网络)
示例
基本操作:
import tensorflow as tf
# 创建Ragged Tensor
ragged_tensor = tf.ragged.constant([
[1, 2, 3],
[4],
[5, 6]
])
print(ragged_tensor)
# Output: <tf.RaggedTensor [[1, 2, 3], [4], [5, 6]]>
# 转换为普通张量(需填充统一长度)
padded = ragged_tensor.to_tensor(default_value=0)
print(padded)
# Output: [[1 2 3], [4 0 0], [5 6 0]]
结果如下:
<tf.RaggedTensor [[1, 2, 3], [4], [5, 6]]>
tf.Tensor(
[[1 2 3]
[4 0 0]
[5 6 0]], shape=(3, 3), dtype=int32)
拼接操作
import tensorflow as tf
rt1 = tf.ragged.constant([[1, 2], [3]])
rt2 = tf.ragged.constant([[4], [5, 6]])
# 进行拼接
combined = tf.concat([rt1, rt2], axis=1)
print(combined)
sliced = combined[1:] # 取第二行以后的数据
print(sliced)
结果如下:
<tf.RaggedTensor [[1, 2, 4], [3, 5, 6]]>
<tf.RaggedTensor [[3, 5, 6]]>