最终效果为:
"""
红色条RGB范围:(0, 0, 0)~(1.0, 0, 0)
绿色条RGB范围:(0, 0, 0)~(0, 1.0, 0)
蓝色条RGB范围:(0, 0, 0)~(0, 0, 1.0)
"""
# https://zhuanlan.zhihu.com/p/260467222
def draw3D(x, y, z, x_label, y_label, z_label, picture_name: str, map="tab10",
           *, xlim=(-8, 15), ylim=(-8, 10)):
    """Draw a coloured 3-D wireframe of ``z`` over the grid ``(x, y)`` and save it.

    Each grid segment is coloured by the height of its starting point on a
    black-to-cyan ramp (larger ``z`` -> brighter). The first grid point
    (row-major order) holding the minimum ``z`` is marked with a small red dot.
    The figure is saved as PNG at 600 dpi to the path returned by
    ``get_log_name(picture_name, "png", "./SVR_grid/map")`` and then shown.

    :param x: 2-D array of x coordinates; must have the same shape as ``z``.
    :param y: 2-D array of y coordinates; must have the same shape as ``z``.
    :param z: 2-D array of heights / scores.
    :param x_label: label of the x axis.
    :param y_label: label of the y axis.
    :param z_label: label of the z axis.
    :param picture_name: base name passed to ``get_log_name`` to build the output path.
    :param map: colormap name; currently unused, kept for backward compatibility.
    :param xlim: (keyword-only) x-axis limits; defaults to the previously hard-coded range.
    :param ylim: (keyword-only) y-axis limits; defaults to the previously hard-coded range.
    :return: None
    :raises ValueError: if ``z`` is empty.
    """
    import matplotlib as mpl
    # Importing Axes3D registers the '3d' projection on old Matplotlib (< 3.2).
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    import numpy as np
    import matplotlib.pyplot as plt

    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)

    # Legend font size.
    mpl.rcParams['legend.fontsize'] = 10
    # figsize is a Figure property; savefig() does not accept it.
    fig = plt.figure(figsize=(12, 9))
    # fig.gca(projection=...) was removed in Matplotlib 3.6.
    ax = fig.add_subplot(projection='3d')

    min_score = z.min()
    max_score = z.max()

    # Normalise z to [0, 1]; a flat surface would otherwise divide by zero.
    span = max_score - min_score
    if span == 0:
        span = 1.0
    # The +0.05 / 1.1 offset and the 0.6 exponent keep the brightness in
    # roughly [0.16, 0.97], away from pure black, while stretching low values.
    level = (((z - min_score) / span + 0.05) / 1.1) ** 0.6
    # RGB = (0, level, level): black-to-cyan ramp, brighter for larger z.
    colors = np.stack([np.zeros_like(level), level, level], axis=-1)

    n_rows, n_cols = z.shape
    # Segments along each row: (i, j) -> (i, j + 1).
    for i in range(n_rows):
        for j in range(n_cols - 1):
            ax.plot(x[i, j:j + 2], y[i, j:j + 2], z[i, j:j + 2], color=colors[i, j])
    # Segments along each column: (i, j) -> (i + 1, j).
    for i in range(n_rows - 1):
        for j in range(n_cols):
            ax.plot(x[i:i + 2, j], y[i:i + 2, j], z[i:i + 2, j], color=colors[i, j])

    # Mark the first grid point (row-major) that holds the minimum z.
    i_min, j_min = np.unravel_index(np.argmin(z), z.shape)
    ax.scatter(x[i_min, j_min], y[i_min, j_min], min_score, color="red", s=4)

    ax.set_zlim(min_score, max_score)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    label_font = {'size': 15, 'color': 'black'}
    ax.set_zlabel(z_label, fontdict=label_font)
    ax.set_ylabel(y_label, fontdict=label_font)
    ax.set_xlabel(x_label, fontdict=label_font)
    # NOTE(review): get_log_name is defined elsewhere in the project; presumably
    # it builds the output file path from (name, extension, directory) — confirm.
    plt.savefig(get_log_name(picture_name, "png", "./SVR_grid/map"), dpi=600)
    plt.show()