有些数据集给的标签是字符串形式,比如wisdm,在放进网络之前,需要转为数字型的编码
这可以通过pd.Categorical(a).codes实现
如
import numpy as np
import pandas as pd
a = ["standing", "sitting", "jogging", "walking", "upstairs", "downstairs", "standing"]
num_label = pd.Categorical(a).codes
print(num_label)
输出结果是
[3 2 1 5 4 0 3]
顺便提一下如何转为独热编码以及独热编码如何转为普通编码
转独热
a_one_hot = np.asarray(pd.get_dummies(a), dtype=np.int8)
print(a_one_hot)
输出
[[0 0 0 1 0 0]
[0 0 1 0 0 0]
[0 1 0 0 0 0]
[0 0 0 0 0 1]
[0 0 0 0 1 0]
[1 0 0 0 0 0]
[0 0 0 1 0 0]]
pytorch中scatter() 和 scatter_()也可以用来转独热
PyTorch笔记之 scatter() 函数 - 那少年和狗 - 博客园
独热转普通编码
print(np.array([np.argmax(i)for i in a_one_hot]))
# 或者
print(np.argmax(a_one_hot, axis=-1))
输出
[3 2 1 5 4 0 3]
[3 2 1 5 4 0 3]