Tensorflow(四) —— Tensor的维度变换
1 维度变化主要方式
- 1、shape ndim
- 2、reshape
- 3、expand_dims squeeze
- 4、transpose
- 5、broadcast_to
2 reshape
a = tf.random.normal([4,28,28,3])
print("a:",a.shape,a.ndim)
# 失去图片的行和列信息,可以理解为每个像素点(pixel)
b = tf.reshape(a,[4,28*28,3])
print("b:",b.shape,b.ndim)
# tensor维度转换时,可以指定其中一个值为-1
c = tf.reshape(a,[4,-1,3])
print("c:",c.shape,c.ndim)
# 失去图片的像素点信息,可以理解为data point(数据点)
d = tf.reshape(a,[4,-1])
print("d:",d.shape,d.ndim)
3 reshape is flexible
"""
reshape可以方便的变换维度,但使用时要严格按照其原本内容(content)来理解其具有的视图(view)意义。
视图是否有意义,和其内容顺序密切相关。
"""
a = tf.random.uniform([4,28,28,3], minval=1, maxval=100)
print("a:",a.shape,a.ndim)
# 变换为数据点以后恢复为原来的图片信息
b = tf.reshape(tf.reshape(a,[4,-1]),[4,28,28,3])
print("b:",b.shape,b.ndim)
# 变换后将图片恢复为上下两部分的像素点
c = tf.reshape(tf.reshape(a,[4,-1]),[4,2,-1,3])
print("c:",c.shape,c.ndim)
# 变换后恢复为像素点
d = tf.reshape(tf.reshape(a,[4,-1]),[4,1,-1,3])
print("d:",d.shape,d.ndim)
# 小案例
pl.gray()
pl.figure(figsize=(5,5))
pl.imshow(image[0])
image1 = tf.reshape(image,[60000,-1])
pl.figure(figsize=(5,5))
image2 = tf.reshape(image1,image.shape)
pl.imshow(image2[0])
4 tf.transpose(转置)
"""
可以理解为轴交换,会改变content的内容,即改变content的顺序
"""
a = tf.ones([4,28,28,3])
print("a:",a.shape,a.ndim)
# 若不传参数 则所有位置转置
b = tf.transpose(a)
print("b:",b.shape,b.ndim)
# 交换图片的行和列信息,虽然维度未变化,但是原来的content已改变
c = tf.transpose(a,[0,2,1,3])
print("c:",c.shape,c.ndim)
# 交换rgb通道和列的信息
d = tf.transpose(a,[0,1,3,2])
print("d:",d.shape,d.ndim)
# 小案例
print("image:",image.shape)
pl.figure(figsize=(5,5))
pl.imshow(image[0])
image1 = tf.transpose(image,[0,2,1])
pl.figure(figsize=(5,5))
pl.imshow(image1[0])
5 pytorch 和 tensorflow数据互通实例
"""
tensor图片数据类型为 [b,h,w,c]
torch 图片数据类型为 [b,c,h,w]
"""
a = tf.zeros_like(a)
print("a:",a.shape)
b = tf.transpose(a,[0,3,1,2])
print("b:",b.shape)
6 expand_dims 增加维度
"""
需要在哪个轴添加一个新轴,则指定axis=多少
"""
a = tf.random.normal([4,28,28,3])
print("a:",a.shape,a.ndim)
# 增加一个task维度
b = tf.expand_dims(a,axis = 0)
print("b:",b.shape,b.ndim)
# 末尾增加一个维度
c = tf.expand_dims(a,axis = -1)
print("c:",c.shape,c.ndim)
# 在任意位置增加一个维度
d = tf.expand_dims(a,axis = 4)
print("d:",d.shape,d.ndim)
7 squeeze 删除某个位数为1的轴
a = tf.ones([1,1,4,28,28,1,3,1])
print("a:",a.shape,a.ndim)
# 不指定轴,则删除所有位数为1的轴
b = tf.squeeze(a)
print("b:",b.shape,b.ndim)
# 指定具体的轴,则删除对应的轴
c = tf.squeeze(a,axis = -3)
print("c:",c.shape,c.ndim)
本文为参考龙龙老师的“深度学习与TensorFlow 2入门实战“课程书写的学习笔记
by CyrusMay 2022 04 06