看Char-RNN的代码时遇到这两个函数,记录一下备忘。
先来看一下tf.split()函数,作用是分割tensor,将分割后的tensor放入一个list。
split(
value, # 输入的tensor
num_or_size_splits, # 如果是个整数n,就将输入的tensor分为n个子tensor。如果是个list T,就将输入的tensor分为len(T)个子tensor。
axis=0, # 默认为0,表示在哪个维度进行分割。
num=None,
name='split'
)
举个栗子。
import tensorflow as tf
a = np.array([[1,2,3],
[4,5,6]])
b = tf.split(a, 3, 1)
c = tf.split(a, [1, 2], 1)
with tf.Session() as sess:
print (sess.run(b))
print (sess.run©)
输出:
[array([[1],
[4]]), array([[2],
[5]]), array([[3],
[6]])]
[array([[1],
[4]]), array([[2, 3],
[5, 6]])]
我们来看一下上面的结果。可以看到输出的 b 是将 a 在第二个维度也就是“列”上将数组平均分为3份,而 c 则是将 a 在“列”维度上将 a 分成两份,每一份的长度对应list里的数值,此处为[1, 2],注意如果num_or_size_splits为一个数,则要分割的那个维度的大小k一定要能被num_or_size_splits整除,上例k=3,num_or_size_splits=3,可以整除,如果num_or_size_splits换为2,则会报错。同理如果num_or_size_splits是一个list,则list里的所有值之和应该等于要分割的那个维度的大小k,上例中1 + 2 = k。
上例中数组a为二维,在高维时同理,这里举一个三维的栗子,深度学习中处理三维的情况比较多。
import tensorflow as tf
a = np.reshape(range(24),(4,2,3))
b = tf.split(a,2,1)
d = tf.split(a,3,2)
with tf.Session() as sess:
print (sess.run(b))
print (sess.run(d))
输出:
[array([[[ 0, 1, 2]],
[[ 6, 7, 8]],
[[12, 13, 14]],
[[18, 19, 20]]]), array([[[ 3, 4, 5]],
[[ 9, 10, 11]],
[[15, 16, 17]],
[[21, 22, 23]]])]
[array([[[ 0],
[ 3]],
[[ 6],
[ 9]],
[[12],
[15]],
[[18],
[21]]]), array([[[ 1],
[ 4]],
[[ 7],
[10]],
[[13],
[16]],
[[19],
[22]]]), array([[[ 2],
[ 5]],
[[ 8],
[11]],
[[14],
[17]],
[[20],
[23]]])]</code></pre></div><p>再来看一下 tf.squeeze()函数,作用是去掉维数为1的维度。</p><div class="highlight"><pre><code class="language-text">tf.squeeze
squeeze(
input, # 输入的tensor
axis=None, # 默认为None,去掉维数为1的维度,也可以指定,则去掉指定维度
name=None,
squeeze_dims=None
)
继续用上例举栗。
import tensorflow as tf
a = np.reshape(range(24),(4,2,3))
b = tf.split(a,2,1)
c = tf.squeeze(b[0])
print (c.shape)
输出:
(4, 3)
经过分割后的b是一个list,包含两个数组元素,每个元素的shape都是4 * 1 * 3,将第一个数组元素经过tf.squeeze()函数处理,可以看到维度变为4 * 3,去掉了维数为1的维度。
参考: