scatter绘制三维图形报错:ValueError: Invalid RGBA argument
代码如下
fig = plt.figure()
ax = plt.axes(projection='3d')
c1=list(np.reshape(Y_train,(50,)))
ax.scatter(X_train[:,2],X_train[:,1], X_train[:,0],c=c1) #报错代码
plt.show()
其中X_train为
报错内容如下:
masked_array(data=[0.267004 +0.j, 0.004874 +0.j, 0.329415 +0.j,
0.50135356+0.j],
mask=False,
fill_value=(1e+20+0j))
<Figure size 432x288 with 1 Axes>
找了其他的博客,了解到scatter的第三个参数需要为数值型,即X_train[:,2]这一维度需要为实数,而这里是复数。
使用np.real()修改报错代码即可:
ax.scatter(X_train[:,2],X_train[:,1], np.real(X_train[:,0]),c=c1)