前言
在构建分类算法的时候,标签通常都要求是one_hot编码,实际上标签可能都是整数,所以我们都需要将整数转成one_hot编码,本篇文章主要介绍如何利用numpy快速将整数转成one_hot编码
。
代码示例
在使用numpy生成one hot编码的时候,需要使用numpy中的一个eye函数
,先简单介绍一下这个函数的功能。
函数
:np.eye(N, M=None, k=0, dtype=float, order=‘C’)
功能说明
:用来返回一个2维的对角数组
参数
:
- N:用来控制输出二维数组的行数
- M:用来控制输出二维数组的列数,如果M为None,则M等于N
- k:主对角线的index,默认是0,如果k为正数,则对角线往上移动,如果k为负数,则对角线往下移动
1. N和M相等的时候
print(np.eye(5))
2. N和M不相等
print(np.eye(5,4))
3. k不为0
print(np.eye(5,5,k=1))
print(np.eye(5,5,k=-1))
生成one hot编码
#设置类别的数量
num_classes = 10
#需要转换的整数
arr = [1,3,4,5,9]
#将整数转为一个10位的one hot编码
print(np.eye(10)[arr])