References: https://segmentfault.com/a/1190000008793389
https://www.jianshu.com/p/ad88a0afa98f
#Summary of Common TensorFlow Data Manipulation Functions
##1. Tensor Transformations
###1.1 Data Extraction
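tf.slice(input_, begin, size, name=None)
: Extracts a contiguous sub-region, starting at the index given by begin and spanning size elements along each dimension.
tf.gather(params, indices, validate_indices=None, name=None)
: Extracts the entries at the given set of indices along axis=0; suitable for extracting non-contiguous subsets.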
input = [[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
[4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
[[5, 5, 5]]]
tf.gather(input, [0, 2]) ==> [[[1, 1, 1], [2, 2, 2]],
[[5, 5, 5], [6, 6, 6]]]
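The difference between the two calls can be modeled in plain Python. The sketch below is only an illustration of the indexing semantics on nested lists, not TensorFlow's implementation; the helper names slice_3d and gather_axis0 are hypothetical and exist only for this example.

```python
# Plain-Python model of the indexing semantics (hypothetical helpers,
# not part of TensorFlow).

def slice_3d(data, begin, size):
    """Return the contiguous block that starts at `begin` and spans
    `size` elements along each of the three dimensions."""
    b0, b1, b2 = begin
    s0, s1, s2 = size
    return [
        [row[b2:b2 + s2] for row in plane[b1:b1 + s1]]
        for plane in data[b0:b0 + s0]
    ]

def gather_axis0(data, indices):
    """Pick arbitrary (possibly non-adjacent) entries along the first axis."""
    return [data[i] for i in indices]

data = [[[1, 1, 1], [2, 2, 2]],
        [[3, 3, 3], [4, 4, 4]],
        [[5, 5, 5], [6, 6, 6]]]

print(slice_3d(data, [1, 0, 0], [1, 2, 3]))  # one plane, two rows, three columns
print(gather_axis0(data, [0, 2]))            # planes 0 and 2, skipping plane 1
```

A slice is always described by one start point and one extent per dimension, so it cannot skip entries; gathering takes an explicit index list, so it can.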
tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)
Looks up the elements (rows) of a tensor that correspond to the given indices.
input = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
idx1 = tf.Variable([0, 2, 3], dtype=tf.int32)
idx2 = tf.Variable([[0, 2, 3], [3, 0, 2]], dtype=tf.int32)
out1 = tf.nn.embedding_lookup(input, idx1) ==> [[ 1 1 1], [ 3 3 3], [ 4 4 4]]
out2 = tf.nn.embedding_lookup(input, idx2) ==> [[[ 1 1 1], [ 3 3 3], [ 4 4 4]],
[[ 4 4 4], [ 1 1 1], [ 3 3 3]]]
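With a single params tensor, the lookup behaves like gathering rows along axis 0: every integer id is replaced by the corresponding row, so the result has the shape of ids followed by the shape of one row. A minimal plain-Python sketch of that behavior follows; the function name lookup is hypothetical, and partitioning and max_norm are deliberately ignored.

```python
# Plain-Python model of a row lookup (hypothetical helper, not TensorFlow code).

def lookup(params, ids):
    """Recursively replace every integer id with the row params[id]."""
    if isinstance(ids, int):
        return params[ids]
    return [lookup(params, i) for i in ids]

params = [[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]

print(lookup(params, [0, 2, 3]))               # three rows
print(lookup(params, [[0, 2, 3], [3, 0, 2]]))  # two groups of three rows
```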
###1.2 Reshaping
tf.reshape(tensor, shape, name=None)
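Reshapes the tensor to a new shape; if one dimension is set to -1, its size is inferred automatically from the total number of elements.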
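The inference of the -1 dimension is simple arithmetic: the total element count divided by the product of the known dimensions. The sketch below illustrates that rule in plain Python; infer_shape is a hypothetical helper written for this example, not a TensorFlow function.

```python
from math import prod

def infer_shape(num_elements, shape):
    """Resolve a single -1 entry so that the shape holds num_elements items."""
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    # Product of the explicitly given dimensions (1 if there are none).
    known = prod(d for d in shape if d != -1)
    if -1 not in shape:
        if known != num_elements:
            raise ValueError("shape does not match the element count")
        return list(shape)
    if known == 0 or num_elements % known != 0:
        raise ValueError("cannot infer the -1 dimension")
    return [num_elements // known if d == -1 else d for d in shape]

print(infer_shape(12, [3, -1]))     # 12 elements arranged in 3 rows
print(infer_shape(12, [2, -1, 3]))  # the middle dimension is inferred
```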
###1.3 Tensor Expansion
tf.tile(input, multiples, name=None)
Replicates the tensor along each dimension the number of times given by multiples.
a = [1,2,3,4]
b = tf.reshape(a, [4,1])
c = tf.tile(b, [1,3]) ==> [[1,1,1],[2,2,2],[3,3,3],[4,4,4]]
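Which dimension gets replicated depends on the position in multiples, which is easy to mix up. The following plain-Python sketch models the 2-D case only; tile_2d is a hypothetical helper for illustration, not TensorFlow's implementation.

```python
# Plain-Python model of tiling a 2-D nested list (hypothetical helper).

def tile_2d(matrix, multiples):
    """Repeat the matrix multiples[0] times along rows and
    multiples[1] times along columns."""
    row_reps, col_reps = multiples
    return [list(row) * col_reps for _ in range(row_reps) for row in matrix]

column = [[1], [2], [3], [4]]  # shape [4, 1]
row = [[1, 2, 3, 4]]           # shape [1, 4]

print(tile_2d(column, [1, 3]))  # shape [4, 3]: each row repeats its own value
print(tile_2d(row, [3, 1]))     # shape [3, 4]: the whole row is stacked 3 times
```

Tiling a [4, 1] column with multiples [1, 3] widens each row, while stacking copies of the full sequence requires a [1, 4] input with multiples [3, 1].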