tf.split 与 tf.squeeze 用法

看Char-RNN的代码时遇到这两个函数,记录一下备忘。


先来看一下tf.split()函数,作用是分割tensor,将分割后的tensor放入一个list。

split(
    value,                 # 输入的tensor
    num_or_size_splits,    # 如果是个整数n,就将输入的tensor分为n个子tensor。如果是个list T,就将输入的tensor分为len(T)个子tensor。 
    axis=0,                #  默认为0,表示在哪个维度进行分割。
    num=None,
    name='split'
)

举个栗子。

import tensorflow as tf

a = np.array([[1,2,3],
[4,5,6]])
b = tf.split(a, 3, 1)
c = tf.split(a, [1, 2], 1)
with tf.Session() as sess:
print (sess.run(b))
print (sess.run©)
输出:
[array([[1],
[4]]), array([[2],
[5]]), array([[3],
[6]])]
[array([[1],
[4]]), array([[2, 3],
[5, 6]])]

我们来看一下上面的结果。可以看到输出的 b 是将 a 在第二个维度也就是“列”上将数组平均分为3份,而 c 则是将 a 在“列”维度上将 a 分成两份,每一份的长度对应list里的数值,此处为[1, 2],注意如果num_or_size_splits为一个数,则要分割的那个维度的大小k一定要能被num_or_size_splits整除,上例k=3,num_or_size_splits=3,可以整除,如果num_or_size_splits换为2,则会报错。同理如果num_or_size_splits是一个list,则list里的所有值之和应该等于要分割的那个维度的大小k,上例中1 + 2 = k。

上例中数组a为二维,在高维时同理,这里举一个三维的栗子,深度学习中处理三维的情况比较多。

import tensorflow as tf

a = np.reshape(range(24),(4,2,3))
b = tf.split(a,2,1)
d = tf.split(a,3,2)

with tf.Session() as sess:
print (sess.run(b))
print (sess.run(d))
输出:
[array([[[ 0, 1, 2]],

   [[ 6,  7,  8]],

   [[12, 13, 14]],

   [[18, 19, 20]]]), array([[[ 3,  4,  5]],

   [[ 9, 10, 11]],

   [[15, 16, 17]],

   [[21, 22, 23]]])]

[array([[[ 0],
[ 3]],

   [[ 6],
    [ 9]],

   [[12],
    [15]],

   [[18],
    [21]]]), array([[[ 1],
    [ 4]],

   [[ 7],
    [10]],

   [[13],
    [16]],

   [[19],
    [22]]]), array([[[ 2],
    [ 5]],

   [[ 8],
    [11]],

   [[14],
    [17]],

   [[20],
    [23]]])]</code></pre></div><p>再来看一下 tf.squeeze()函数,作用是去掉维数为1的维度。</p><div class="highlight"><pre><code class="language-text">tf.squeeze

squeeze(
input, # 输入的tensor
axis=None, # 默认为None,去掉维数为1的维度,也可以指定,则去掉指定维度
name=None,
squeeze_dims=None
)

继续用上例举栗。

import tensorflow as tf

a = np.reshape(range(24),(4,2,3))
b = tf.split(a,2,1)
c = tf.squeeze(b[0])
print (c.shape)
输出:
(4, 3)

经过分割后的b是一个list,包含两个数组元素,每个元素的shape都是4 * 1 * 3,将第一个数组元素经过tf.squeeze()函数处理,可以看到维度变为4 * 3,去掉了维数为1的维度。


参考:

【tensorflow 学习】tf.split()和tf.squeeze()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值