tensorflow的张量维度和形状shape,以及张量中元素的读取

之前一直弄混张量的维度和shape的关系,认为通过tf.shape()获得的就是维度,现在发现错误,记下来。

tf.shap()用来获取的是张量的各个维度上的元素数目。

1                                       #维度为0的标量
[1, 2, 3]                               #维度为1,包含3个元素
[[1, 2], [3, 4]]                        #维度为2, shape=(2, 2)
[[[1, 2], [3, 4]], [[1, 2], [3, 4]]]    #维度为3, shape=(2, 2, 2)
sess = tf.Session()
a = tf.constant([1])
b = tf.constant([2])
c = tf.concat([a, b], axis=0)
sess.run(c)
d = tf.constant([2, 3, 4])
print("a: ", sess.run(tf.shape(a)), a)
print("b: ", sess.run(tf.shape(b)), b)
print("c: ", sess.run(tf.shape(c)), c)
print("d: ", sess.run(tf.shape(d)), d)
e = d[0]
print("e: ", sess.run(tf.shape(e)), e)  #这里表明是标量
f = tf.reshape(e, [1,])
print("f: ", sess.run(tf.shape(f)), f)
g = tf.concat([f, a], axis=0)
print("g: ", sess.run(g), g)
h = d[:1]
print("h: ", sess.run(tf.shape(h)), h)


output:

a:  [1] Tensor("Const_95:0", shape=(1,), dtype=int32)
b:  [1] Tensor("Const_96:0", shape=(1,), dtype=int32)
c:  [2] Tensor("concat_29:0", shape=(2,), dtype=int32)
d:  [3] Tensor("Const_97:0", shape=(3,), dtype=int32)
e:  [] Tensor("strided_slice_39:0", shape=(), dtype=int32)
f:  [1] Tensor("Reshape_6:0", shape=(1,), dtype=int32)
g:  [2 1] Tensor("concat_30:0", shape=(2,), dtype=int32)
h:  [1] Tensor("strided_slice_40:0", shape=(1,), dtype=int32)

 

注意:e的输出为标量,因为这里只是获得其中一个元素

python类似。

张量切片:

i = tf.slice(d, [0],[2]) #d是一维数据,[0]表示从该一维数据的第0个元素开始切片,[2]表示第一维元素保留2个。
print(sess.run(tf.shape(i)), i)
sess.run(i)

output:

[2] Tensor("Slice_9:0", shape=(2,), dtype=int32)
array([2, 3])
elem_tf = tf.constant([i+1 for i in range(30)], shape=[5, 6], name="elem")
sess = tf.Session()
print(sess.run(elem), type(elem_tf))
elem_np = elem_tf.eval(session=sess)  #看这里eval用法
print("\n", elem_np, type(elem_np))
elem_tf_convert = tf.convert_to_tensor(elem_np)
print("\n", sess.run(elem_tf_convert[0][0]), type(elem_tf_convert))
sess.close()

output:

[[ 1  2  3  4  5  6]
 [ 7  8  9 10 11 12]
 [13 14 15 16 17 18]
 [19 20 21 22 23 24]
 [25 26 27 28 29 30]] <class 'tensorflow.python.framework.ops.Tensor'>

 [[ 1  2  3  4  5  6]
 [ 7  8  9 10 11 12]
 [13 14 15 16 17 18]
 [19 20 21 22 23 24]
 [25 26 27 28 29 30]] <class 'numpy.ndarray'>

 1 <class 'tensorflow.python.framework.ops.Tensor'>

注意对于张量元素的读取可以效仿python,直接使用x[0]这样的方式,如下:

sess = tf.Session()
a = tf.constant([1, 2, 3])
#d = tf.add(a[0], a[1])  #或者
d = a[0] + a[1]
print(sess.run(d))
sess.close()

output:

3

 将标量转换成张量用于计算,如下:

sess = tf.Session()
aa = tf.constant([1, 1])
part1 = aa[0]
print(type(part1))
print(tf.shape(part1))
bb = tf.constant([1])
print(tf.shape(bb))
part1 = tf.reshape(part1, [1])  #注意转换维度,不然不能用,这里part1是标量
cc = tf.concat([part1, bb], axis=0)
sess.run(cc)

output:

<class 'tensorflow.python.framework.ops.Tensor'>
Tensor("Shape_95:0", shape=(0,), dtype=int32)
Tensor("Shape_96:0", shape=(1,), dtype=int32)
array([1, 1])

 

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值