出现问题的代码,运行时显示索引错误,如下
oneHot=np.identity(3)
for i in range(oneHot.shape[0]):
for j in range(oneHot.shape[1]):
if(oneHot[i,j]==1):
oneHot[i,j]=0.99
else:
oneHot[i,j]=0.01
y_true=oneHot[y_train]
显示索引必须为整数或者布尔型类型的数据,此时我们只要将y_train(浮点型数据)更改为整数类型数据即可
y_true=oneHot[y_train.astype(int)]
这样就可以运行啦