机器学习中,经常会用到one-hot编码。pandas中已经提供了这一函数。
但是这里有一个神坑,得到的one-hot编码数据类型是uint8,进行数值计算时会溢出!!!
import pandas as pd
import numpy as np
a = [1, 2, 3, 1]
one_hot = pd.get_dummies(a)
print(one_hot.dtypes)
print(one_hot)
print(-one_hot)
1 uint8
2 uint8
3 uint8
dtype: object
1 2 3
0 1 0 0
1 0 1 0
2 0 0 1
3 1 0 0
1 2 3
0 255 0 0
1 0 255 0
2 0 0 255
3 255 0 0
正确的做法是,将其转换成浮点:
one_hot = one_hot.astype('float')
print(-one_hot)
1 2 3
0 -1.0 -0.0 -0.0
1 -0.0 -1.0 -0.0
2 -0.0 -0.0 -1.0
3 -1.0 -0.0 -0.0