I. concat (concatenation)
concat() joins a list of tensors along a specified dimension.
tf.concat(values, axis, name='concat')
- values: A list of Tensor objects or a single Tensor.
- axis: 0-D int32 Tensor. Dimension along which to concatenate. Must be in the range [-rank(values), rank(values)). As in Python, indexing for axis is 0-based. A non-negative axis in the range [0, rank(values)) refers to the axis-th dimension, and a negative axis refers to the (axis + rank(values))-th dimension.
- axis=0 concatenates along dimension 0
- axis=1 concatenates along dimension 1
- name: A name for the operation (optional).
- Returns: A Tensor resulting from concatenation of the input tensors.
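The axis rule above can be sketched in plain Python by working on shapes only. This is a minimal illustration of the range check and the resulting shape, not TensorFlow's implementation; normalize_axis and concat_shape are hypothetical helper names introduced here.

```python
def normalize_axis(axis, rank):
    """Map an axis in [-rank, rank) to its non-negative equivalent."""
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range [{-rank}, {rank})")
    return axis + rank if axis < 0 else axis


def concat_shape(shapes, axis):
    """Shape obtained by concatenating tensors with the given shapes along axis."""
    rank = len(shapes[0])
    axis = normalize_axis(axis, rank)
    result = list(shapes[0])
    for shape in shapes[1:]:
        # Every dimension except `axis` must match across the inputs.
        if len(shape) != rank or any(
            a != b for i, (a, b) in enumerate(zip(result, shape)) if i != axis
        ):
            raise ValueError(f"incompatible shapes: {shapes}")
        # Only the concatenation dimension grows.
        result[axis] += shape[axis]
    return tuple(result)


print(concat_shape([(2, 3), (2, 3)], 0))   # (4, 3)
print(concat_shape([(2, 3), (2, 3)], 1))   # (2, 6)
print(concat_shape([(2, 3), (2, 3)], -1))  # (2, 6), same as axis=1 for rank 2
```

Note that the rank of the result equals the rank of the inputs; only the size of the chosen dimension changes.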
1. axis=0: concatenate along dimension 0
import tensorflow as tf
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
data1 = tf.concat([t1, t2], 0)
print("data1 = \n", data1)
Output:
data1 =
tf.Tensor(
[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]], shape=(4, 3), dtype=int32)
2. axis=1: concatenate along dimension 1
import tensorflow as tf
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
data2 = tf.concat([t1, t2], 1)
print("data2 = \n", data2)
Output:
data2 =
tf.Tensor(
[[ 1  2  3  7  8  9]
 [ 4  5  6 10 11 12]], shape=(2, 6), dtype=int32)
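To make the two results above concrete, here is a rough pure-Python equivalent for 2-D nested lists. It is only a sketch of the behavior for axis 0 and 1; concat_2d is a hypothetical helper, not part of TensorFlow.

```python
def concat_2d(a, b, axis=0):
    """Concatenate two 2-D nested lists along axis 0 (rows) or axis 1 (columns)."""
    if axis == 0:
        # Append the rows of b below the rows of a.
        return [list(row) for row in a] + [list(row) for row in b]
    if axis == 1:
        # Extend each row of a with the matching row of b.
        if len(a) != len(b):
            raise ValueError("axis=1 requires the same number of rows")
        return [list(ra) + list(rb) for ra, rb in zip(a, b)]
    raise ValueError("this sketch only handles axis 0 or 1")


t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
print(concat_2d(t1, t2, 0))  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
print(concat_2d(t1, t2, 1))  # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
```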
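II. stack (stacking)
stack() packs a list of rank-R tensors into a single rank-(R+1) tensor along a new dimension.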
tf.stack(values, axis=0, name='stack')
- values: A list of Tensor objects with the same shape and the same type.
- axis: An int. The axis to stack along. Defaults to the first dimension. Negative values wrap around, so the valid range is [-(R+1), R+1), where R is the rank of each input tensor.
- name: A name for this operation (optional).
- Returns: A stacked Tensor with the same type as values.
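The valid range for axis follows from the fact that the output has one more dimension than each input. The shape-only sketch below illustrates this under that assumption; stack_shape is a hypothetical helper and not TensorFlow's implementation.

```python
def stack_shape(shape, n, axis=0):
    """Shape obtained by stacking n tensors of identical `shape` along `axis`."""
    rank = len(shape)
    # The output has rank + 1 dimensions, so axis must lie in [-(rank + 1), rank + 1).
    if not -(rank + 1) <= axis < rank + 1:
        raise ValueError(f"axis {axis} is out of range [{-(rank + 1)}, {rank + 1})")
    if axis < 0:
        axis += rank + 1
    # A new dimension of size n is inserted at position `axis`.
    return tuple(shape[:axis]) + (n,) + tuple(shape[axis:])


print(stack_shape((2,), 3, axis=0))   # (3, 2)
print(stack_shape((2,), 3, axis=1))   # (2, 3)
print(stack_shape((2,), 3, axis=-1))  # (2, 3), same as axis=1 here
```

Unlike concat, which keeps the rank and grows one existing dimension, stack always adds a dimension whose size is the number of stacked tensors.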
1. axis=0: stack along dimension 0
import tensorflow as tf
x = tf.constant([1, 4])
y = tf.constant([2, 5])
z = tf.constant([3, 6])
data0 = tf.stack([x, y, z])
print("x = \n", x)
print("-" * 50)
print("y = \n", y)
print("-" * 50)
print("z = \n", z)
print("-" * 200)
print("data0 = \n", data0)
Output:
x =
tf.Tensor([1 4], shape=(2,), dtype=int32)
--------------------------------------------------
y =
tf.Tensor([2 5], shape=(2,), dtype=int32)
--------------------------------------------------
z =
tf.Tensor([3 6], shape=(2,), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
data0 =
tf.Tensor(
[[1 4]
 [2 5]
 [3 6]], shape=(3, 2), dtype=int32)
2. axis=1: stack along dimension 1
import tensorflow as tf
x = tf.constant([1, 4])
y = tf.constant([2, 5])
z = tf.constant([3, 6])
data1 = tf.stack([x, y, z], axis=1)
print("x = \n", x)
print("-" * 50)
print("y = \n", y)
print("-" * 50)
print("z = \n", z)
print("-" * 200)
print("data1 = \n", data1)
Output:
x =
tf.Tensor([1 4], shape=(2,), dtype=int32)
--------------------------------------------------
y =
tf.Tensor([2 5], shape=(2,), dtype=int32)
--------------------------------------------------
z =
tf.Tensor([3 6], shape=(2,), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
data1 =
tf.Tensor(
[[1 2 3]
 [4 5 6]], shape=(2, 3), dtype=int32)
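For 1-D inputs, the two stack results above correspond to using each input as a row (axis=0) or as a column (axis=1). The following pure-Python sketch mirrors that behavior on plain lists; stack_1d is a hypothetical helper written for illustration and handles 1-D inputs only.

```python
def stack_1d(vectors, axis=0):
    """Stack equal-length 1-D lists into a 2-D nested list."""
    length = len(vectors[0])
    if any(len(v) != length for v in vectors):
        raise ValueError("all inputs must have the same length")
    if axis in (0, -2):
        # Each input becomes one row of the result.
        return [list(v) for v in vectors]
    if axis in (1, -1):
        # Each input becomes one column of the result.
        return [list(column) for column in zip(*vectors)]
    raise ValueError("axis must be in [-2, 2) for 1-D inputs")


x, y, z = [1, 4], [2, 5], [3, 6]
print(stack_1d([x, y, z], axis=0))  # [[1, 4], [2, 5], [3, 6]]
print(stack_1d([x, y, z], axis=1))  # [[1, 2, 3], [4, 5, 6]]
```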