机器学习中的维度合并,切分转换,不再傻傻分不清。(numpy 和 tensorflow 中 split,concat 等维度切分合并函数用法详解)

对于机器学习来说,需要在各种维度之间操作,转换,然而对于初学者来说,往往让人难感到很疑惑,之前有大佬的博客以某个矩阵为输入数据,详细演示了各种维度变换,但是缺乏一定的直观性,一堆数字看起来很烦,本篇博客以几张图片为例来讲述numpy,tensorflow等API关于维度转换的函数,转换效果清晰可见。

准备数据

需要准备四张图片,jpg格式就好,转换为224 × 224 的格式大小,图片尺寸不满足也没关系,下面有代码可以将图片转换为对应大小。在项目文件夹目录下建立一个 images 目录,复制四张jpg 图像加入此目录。在images 上级目录下新建一个Python文件。输入以下代码:

import os
from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
images=os.listdir("images")
image_set=[]
for img in images:
    image=Image.open(os.path.join("images",img))
    image=image.resize((224,224),Image.ANTIALIAS)
    image=np.asarray(image)
    image_set.append(image)
image_set=np.asarray(image_set)
print(image_set.shape)
print(type(image_set))

我们得到的输出如下:

(4, 224, 224, 3)
<class 'numpy.ndarray'>

得到一个 numpy ndarray,维度大小分别是 (4,224,224,3),很好理解,这里的4代表四站图片,两个224 分别代表图片的长和宽,3代表通道数目,实际上,图像有RGB三个通道组成。

我们要对这个 shape 为 (4,224,224,3) 的数组来进行维度转换,通过看图片展示的效果来理解 各种维度转换api 的用法。

numpy.vstack() , numpy.hstack() 以及 numpy.dstack()

这三个函数的用法比较简单,先hstack 的效果图,其它博客解释到 hstack 是在水平方向上拼接维度,输入一下代码,结果如下:

import matplotlib.pyplot as plt
hstack_image=np.hstack(image_set)
plt.imshow(hstack_image)
print(

我们可以看到输出如下的图片和信息:
在这里插入图片描述
直观的理解就是 hstack 函数会讲将4张图片按照行进行拼接,首先是第一张图像的第一行,然后是第二张图像的第一行,第三张图像的第一行,第四张图像的第一行,依次进行拼接,每一行的元素个数是之前的四倍(即列数是之前的四倍),依次拼接每一行,所以行数不会改变,可以看到拼接后的shape 为 (224, 224 × 4,3) ,即将原来的四张图像在一行显示。

知道了hstack之后,就很容易知道 vstack了,这个相当于按照列进行拼接,拼接之后的图片效果如下:

import matplotlib.pyplot as plt
vstack_image=np.vstack(image_set)
plt.imshow(vstack_image)
print(vstack_image.shape)

在这里插入图片描述

至于 dstack ,文档上说是按照深度进行的拼接,即 将通道进行拼接,拼接之后通道数目变为原来的四倍,因为通道拼接之后无法展示,所以下面展示拼接之后的维度。
图片展示第一张图片,即只取出前三个通道 【0-2】。后面【3-5】为第二张图片,【6-8】为第三张图片,【9-11】为第四张图片,可以自行验证。

import matplotlib.pyplot as plt
dstack_image=np.dstack(image_set)
print(dstack_image.shape)
plt.imshow(dstack_image[:,:,:3])

输出如下:
在这里插入图片描述

numpy.concatenate() 函数的用法

前面提到的hstack ,vstack , dstack 可以看做是这个函数的特例。 通过传入axis 参数,即可控制在哪一个维度上进行拼接。

如果 axis 的参数为0 ,效果和 vstack 一致:

import matplotlib.pyplot as plt
dim_0=np.concatenate(image_set,axis=0)
print(dim_0.shape)
plt.imshow(dim_0)

在这里插入图片描述
这里有一种一般的理解,对于拼接来说,一定是降维的过程,有一个维度被取消了,取消啥意思?就是原来的多个被拼接为一个了,以这个 (4,224,224,3) 的四张图片的numpy 数组来说意思就是拼接四张图片为一张图片,我们可以通过观察下面的代码来猜测这个 axis 参数是什么意思:

import matplotlib.pyplot as plt
dim_0=np.concatenate(image_set,axis=0)
dim_1=np.concatenate(image_set,axis=1)
dim_2=np.concatenate(image_set,axis=2)
print(dim_0.shape)
print(dim_1.shape)
print(dim_2.shape)

输出结果如下:

(896, 224, 3)
(224, 896, 3)
(224, 224, 12)

我们可以看到,当axis 为 0的时候,最终结果的第1个维度 是原来的4倍,axis 为1的时候,最终结果的第2个维度为原来的4倍,当axis=2的时候,最终结果的第三个维度为原来的4倍。 所以这样的话 axis 的作用就不言而喻了。

那我们考虑一下,如果想要 hstack 的效果,那么图片的大小应该是 (224,4*224,3) 这样的话行高为一张图片的高度,宽度为四张图片拼接起来的宽度,效果如下:

import matplotlib.pyplot as plt
dim_0=np.concatenate(image_set,axis=0)
dim_1=np.concatenate(image_set,axis=1)
dim_2=np.concatenate(image_set,axis=2)
plt.imshow(dim_1)

在这里插入图片描述

numpy.split()

讲完了维度的合并,现在要来看一下拆分了,同样这个函数有三个参数:

  • array: 要拆分的数组
  • indices_or_sections 要拆分为多少份或者要拆分的起始节点列表。
  • axis 要按照哪一个维度进行拆分

之前我看过别人的博客,甚至有一些大V的博客,都说axis 是按照行划分,按照列划分,这样解释让人摸不着头脑,最好的理解方式就是动手实验。

首先,试验一下axis=0 的情况:
我们先平均分为两份,看这两份分别是什么东西:

import matplotlib.pyplot as plt
split_0,split_1=np.split(image_set,2,axis=0)
print(split_0.shape)
print(split_1.shape)

输出结果:
(2, 224, 224, 3)
(2, 224, 224, 3)

可以看出,当axis 为 0的时候,实际上就是将原来的一个集合(具有四张图片)划分为两个集合(每一个集合有两张图片),那么易得当axis 为 0 的时候就是对 第一维度的划分 原来是(4,224,224,3) 划分为两个(2,224,224,3)。

当axis=1 的时候。依旧划分为两份,代码如下:

import matplotlib.pyplot as plt
split_0,split_1=np.split(image_set,2,axis=1)
print(split_0.shape)
print(split_1.shape)

输出结果
(4, 112, 224, 3)
(4, 112, 224, 3)

可以看到,这里将第二维度拆分为了两个,相当于将每张图片横着切成两半,一个存储着所有图片的上半部分,一个存放着所有图片的下半部分。不信,打印出来给你看:(打印上半部分)

import matplotlib.pyplot as plt
split_0,split_1=np.split(image_set,2,axis=1)
plt.imshow(np.hstack(split_0))

在这里插入图片描述
打印下半部分:

import matplotlib.pyplot as plt
split_0,split_1=np.split(image_set,2,axis=1)
plt.imshow(np.hstack(split_1))

在这里插入图片描述

ok,讲到这里,大家对 split 函数的axis 的作用就很清楚了,当axis =2 的时候,实际上就是将每一张图片竖着切开,所有图片的左半边在一个集合,所有图片的右半边在另外一个集合,打印出来效果如下:

import matplotlib.pyplot as plt
split_0,split_1=np.split(image_set,2,axis=2)
plt.imshow(np.hstack(split_0))

在这里插入图片描述
右半边的图不再放出,节省篇幅~~~

ok,axis 就讲到这里了,下面说说 indices_or_sections,上面的演示中,一直将它设置为 2 ,就是平均分为两份,是几就平均分成几份,但是特别注意的是,这里平均分成几份数要是整份啊,必须除得开才行,否则会报错,这就需要使用另外一个函数了:array_split(),这个函数唯一的区别就是如果出现除不开的情况,就做不均等划分,例如 将4 划分为 3份数,那么划分之后为 【2,1,1】,将5划分为3份,就是【2,2,1】,将 6划分为 4份就是 【2,2,1,1】 这个很好理解。

indices_or_sections除了是一个数字代表平均划分为几份之外,还可以是一个数组,里面的元素必须递增,这个数组表示0为起点,一共几份为终点,在包括起点终点的情况下,元素的间距(我上面说的就不是人话,我自己都听不懂)
举个例子:
例如 我想将 4 张图片划分为 【1,2,1】 ,要怎么搞?均分不可以,array_split() 也不行,此时就可以 设置 indices_or_sections 为:[1,3],请看下面示例代码:

a,b,c=np.split(image_set,[1,3],axis=0)
print(a.shape)
print(b.shape)
print(c.shape)

输出结果:
(1, 224, 224, 3)
(2, 224, 224, 3)
(1, 224, 224, 3)

看到例子,应该就清楚了~

tensorflow 中相关维度变换的函数

tensorflow 中相关的函数和numpy 函数区别不大,下面是一些演示。
为了方便,建立placeholder来进行演示,通过查看placeholder 的维度变换体会函数的用法。

import tensorflow as tf
x=tf.placeholder(tf.float32,shape=(4,224,224,3))
print(x.get_shape())
1. tf.split() 函数

拆分函数,用法请见代码:

import tensorflow as tf
x=tf.placeholder(tf.float32,shape=(4,224,224,3))
one,two=tf.split(x,2,axis=0)
print("one: ",one.get_shape())
print("two: ",two.get_shape())


# 特别注意,这里的第二个参数可以是一个数字,代表平均分成多少份数,这个和numpy 的split没有什么不同
# 但是如果是列表的情况,里面元素的个数就是分成多份,每个元素的值代表对应子集所占的的份额大小,看代码,这不难理解。

three,four,five=tf.split(x,[1,2,1],axis=0)
print("three: ",three.get_shape())
print("four: ",four.get_shape())
print("five: ",five.get_shape())

# 对于 axis 的理解,和 numpy 的 axis 没有什么不同。
seven,eight=tf.split(x,2,axis=1)
print("seven: ",seven.get_shape())
print("egith: ",eight.get_shape())

输出:
one:  (2, 224, 224, 3)
two:  (2, 224, 224, 3)
three:  (1, 224, 224, 3)
four:  (2, 224, 224, 3)
five:  (1, 224, 224, 3)
seven:  (4, 112, 224, 3)
egith:  (4, 112, 224, 3)
2.tf.concat() 函数

合并函数,用法请见代码:

x1=tf.placeholder(tf.float32,shape=(4,224,224,1))
x2=tf.placeholder(tf.float32,shape=(4,224,224,1))
x3=tf.placeholder(tf.float32,shape=(4,224,224,1))

res_0=tf.concat([x1,x2,x3],axis=0)
res_1=tf.concat([x1,x2,x3],axis=1)
res_2=tf.concat([x1,x2,x3],axis=2)
res_3=tf.concat([x1,x2,x3],axis=3)

print("res0:" ,res_0.get_shape())
print("res1:" ,res_1.get_shape())
print("res2:" ,res_2.get_shape())
print("res3:" ,res_3.get_shape())

输出:
res0: (12, 224, 224, 1)
res1: (4, 672, 224, 1)
res2: (4, 224, 672, 1)
res3: (4, 224, 224, 3)

不解释了,大家都懂。


  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值