一、获取维度
1.1 tf.shape()
1.1.1 函数信息
tf.shape(
input, name=None, out_type=tf.dtypes.int32
)
- out_type:
返回的结果是哪种类型的Tensor。
注意:返回的结果是Tensor,在喂入数据前不可知
1.1.2 例子
input = tf.placehold([None, 64, 128], dtype=tf.float)
input = tf.reshape(input, [-1, tf.shape(input)[1], 4, 32])
# 这样input的shape就是 [None, None, 4, 32]了。
scale_value = tf.math.sqrt(tf.shape(input, tf.float32)[1]) # 这样会报错
1.2 get_shape().as_list()
Tensor类型,可以调用该函数得到各个维度信息。
1.2.1 函数信息
返回的结果是List。
1.2.2 代码
Q = tf.placeholder(shape=[None, 32, 32], dtype=tf.float32)
print(Q.get_shape().as_list)
结果是[None, 32, 32]
二、更改维度
2.1 tf.reshape()
2.1.1 函数信息
tf.reshape(
tensor, shape, name=None
)
- shape
A Tensor. Must be one of the following types: int32, int64.
2.1.2 代码
例一:
import numpy as np
import tensorflow as tf
input = tf.placeholder(shape=[None, 32, 32], dtype=tf.float32)
# 这里是没问题的
output1 = tf.reshape(input, tf.shape(input)[-1])
# 这里是有问题的
output2 = tf.reshape(input, tf.shape(input)[0])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print_output1, print_output2 =
sess.run([output1, output2], feed_dict={input: np.random.normal([16, 32, 32]) })
注意:因为tf.shape(input)[0]是None, 所以用来初始化reshape是有问题的。
例二:
x = tf.constant([[1, 2, 3, 4, 5, 6],
[8, 7, 6, 5, 4, 3]]) # [2, 4]
x = tf.reshape(x, [2, 3, 2])
# [[[1, 2],
[3, 4],
[5, 6]],
[[8, 7],
[6, 5],
[4, 3]]]
可以看出,最后一维是dim维。
2.2 tf.transpose()
2.2.1 函数信息
tf.transpose(
a, perm=None, name='transpose', conjugate=False
)
- perm:
如果没指定,等价于perm = [n-1, n-2 ,n-3, n-4]
2.2.2 代码
例一:
x = tf.constant([[1, 2, 3], [4, 5, 6]])
tf.transpose(x, perm=[1, 0])
# [[1, 4]
# [2, 5]
# [3, 6]]
其实transpose就是转置:
因为perm = [1, 0], 原始维度是[2, 3],所以经过transpose后变成[3, 2]。
第一行变成第一列,第二行变成第二列。
例二:
x = tf.constant([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
tf.transpose(x, perm=[0, 2, 1])
# [[[1, 4],
# [2, 5],
# [3, 6]],
# [[7, 10],
# [8, 11],
# [9, 12]]]
转置的是后两维, 所以只看最后两维即可。
先看第一维中的第一个
[[1, 2, 3], [[1, 4],
[4, 5, 6]] 变为 [2, 5],
[3, 6]]
第一维中的第二个同理
[[ 7, 8, 9], [[7, 10],
[10, 11, 12]] 变为 [8, 11],
[9, 12]]
例三:
x = tf.constant([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]]) # [2, 2, 3]
tf.transpose(x, perm=[2, 0, 1])
假设原始x最后一维是multi-head(其实一般最后一维都是dimension)。先忽略,并且降一维度。[2, 2, 3] -> [2, 2]
因为多头,最后一维有3列,所以可以分为3个[2, 2]
[[1, 4], [[2, 5], [[3, 6],
[7, 10]] [8, 11]] [9, 12]]
然后再在第0维进行拼接, 然后就是最后的结果。
[[[1, 4],
[7, 10]],
[[2, 5],
[8, 11]],
[[3, 6],
[9, 12]]]
例四:
x = tf.constant([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]]) # [2, 2, 3]
tf.transpose(x, perm=[1, 0, 2])
最后一个维度没有变:将[1, 2, 3]当作一个整体,降维:
[[1, 4],
[7, 10]]
然后因为是perm = [1, 0], 所以:
[[1, 7],
[4, 10]]
再将之前整合后的整体拆分出来, 这样就得到最后结果了。
[[[1, 2, 3],
[7, 8, 9]],
[[4, 5, 6],
[10, 11, 12]]]
2.3 tf.split()
2.3.1 函数信息
tf.split(
value, num_or_size_splits, axis=0, num=None, name='split'
)
- num_or_size_splits:
切分次数 - axis:
切分的维度