【tensorflow】tf.sparse_split用法——使用tf.sparse_split拆分sparse_tensor

我们在实际tensorflow应用中,如果遇到保存稀疏矩阵的时候,会选择Sparse_tensor,这样可以节省大量的空间。
但是如果想要拆分稀疏矩阵的时候,直观的思路是:先将spare_tensor转为dense_tensor,然后拆分,然后再转成spare_tensor,这个过程中耗时不说,专程dense实际上就违背了我们节省空间的初衷。

  • 正确的解决方式是:

def sparse_split(keyword_required=KeywordRequired(),
                 sp_input=None,
                 num_split=None,
                 axis=None,
                 name=None,
                 split_dim=None):
  """Split a `SparseTensor` into `num_split` tensors along `axis`.

  If the `sp_input.dense_shape[axis]` is not an integer multiple of `num_split`
  each slice starting from 0:`shape[axis] % num_split` gets extra one
  dimension. For example, if `axis = 1` and `num_split = 2` and the
  input is:

      input_tensor = shape = [2, 7]
      [    a   d e  ]
      [b c          ]

  Graphically the output tensors are:

      output_tensor[0] =
      [    a ]
      [b c   ]

      output_tensor[1] =
      [ d e  ]
      [      ]

  Args:
    keyword_required: Python 2 standin for * (temporary for argument reorder)
    sp_input: The `SparseTensor` to split.
    num_split: A Python integer. The number of ways to split.
    axis: A 0-D `int32` `Tensor`. The dimension along which to split.
    name: A name for the operation (optional).
    split_dim: Deprecated old name for axis.

  Returns:
    `num_split` `SparseTensor` objects resulting from splitting `value`.

  Raises:
    TypeError: If `sp_input` is not a `SparseTensor`.
    ValueError: If the deprecated `split_dim` and `axis` are both non None.
  """
  • 举例说明用法:
import tensorflow as tf

a = tf.SparseTensor(indices=[[0,0],[1,1]],values=[1,2],dense_shape=(2,2))

b,c = tf.sparse_split(sp_input=a,num_split=2,axis=1)

with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(b))

输出是:

SparseTensorValue(indices=array([[0, 0],
       [1, 1]]), values=array([1, 2], dtype=int32), dense_shape=array([2, 2]))
SparseTensorValue(indices=array([[0, 0]]), values=array([1], dtype=int32), dense_shape=array([2, 1]))
SparseTensorValue(indices=array([[1, 0]]), values=array([2], dtype=int32), dense_shape=array([2, 1]))
  • 需要注意的是:
    上面是使用python3的版本,如果使用python2,必须传入keyword_required参数,否则会报错:

      Keyword arguments are required for this function
    

python2的调用方法为:

from tensorflow.python.ops.sparse_ops import KeywordRequired
import tensorflow as tf

a = tf.SparseTensor(indices=[[0,0],[1,1]],values=[1,2],dense_shape=(2,2))

b,c = tf.sparse_split(keyword_required=KeywordRequired(),sp_input=a,num_split=2,axis=1)

with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(b))
    print(sess.run(c))

这样就能解决报错的问题。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值