在散点图中,一般使用一维的色彩映射显示散点的某个属性。 @Pjer 大佬使用2D的colormap实现了“在同一张2d图里面显示出所有对应点的2种属性”,并指出了“从HSV色彩空间到RGB色彩空间的映射”这一思路。大佬的回答如下:
科研画图都用什么软件?www.zhihu.com本文中,我使用Python复现了这一功能,并完善成一个方便使用的函数colorbar2d,在这里做简要介绍并提供源码。
colorbar2d 可选择输入以下参数:
- list1:(必须)一维数组,将转换为色相信息;
- list2:(必须)一维数组,将转换为明度信息;
- minColor:字符串,表示起始色相,默认minColor = None(注);
- maxColor:字符串,表示终止色相,默认minColor = None(注);
- maxv:0-1之间的浮点数或整数,表示最暗的明度,默认maxv = 1,即纯黑色;
- s:0-1之间的浮点数或整数,表示饱和度,默认s = 1,即饱和度最大;
- step:colormap的绘制方式为矩阵排列的散点叠加而成,该参数用于控制散点的间隔;默认step = 0.05;
colorbar2d 实例化后,可选择返回以下输出:
- colorbar2d.rgb():返回 n×3 数组,n 为输入的一维数组的长度值,即样本数;3 为R、G、B信息;
- colorbar2d.hsv():返回 n×3 数组,即各样本的H、S、V信息(不常用);
- colorbar2d.colorbar(),返回两个一维数组(x、y轴散点坐标)和一个 n×3 数组(散点的RGB信息);上述信息用于绘制二维colorbar;
- 注:颜色可选择红(red,r)、橙(orange,o)、黄(yellow,y)、绿(green,g)、青(cyan,c)、蓝(blue,b)、紫(purple,m);若输入中没有颜色范围指示,colorbar默认绘制从红至紫的所有色相。
colorbar2d 源码如下:
from matplotlib import colors
import numpy as np
class colorbar2d:
def __init__(self, list1, list2, minColor = None, maxColor = None, maxv = 1, s = 1, step = 0.05):
self.list1_max, self.list1_min = max(list1), min(list1)
self.list2_max, self.list2_min = max(list2), min(list2)
self.maxv = maxv
self.maxs = s
# 将所选特征0-1化
if minColor == None and maxColor == None:
self.h = (list1 - self.list1_min) / (self.list1_max - self.list1_min)
self.limit_color = False
else:
color_dic = {'red':0, 'orange':18, 'yellow':30, 'green':56,
'cyan':88, 'blue':112, 'purple':140, 'r':0,
'o':18, 'y':30, 'g':56, 'c':88,
'b':112, 'm':140}
self.mincolor = color_dic.get(minColor)/180.0
self.maxcolor = color_dic.get(maxColor)/180.0
self.h = (list1 - self.list1_min) / (self.list1_max - self.list1_min) * (self.maxcolor - self.mincolor) + self.mincolor
self.limit_color = True
self.v = (list2 - self.list2_min) / (self.list2_max - self.list2_min) * self.maxv + (1- self.maxv)
self.s = [self.maxs for i in range(len(list1))]
self.hsv = np.zeros(shape=(len(list1),3))
self.hsv[:,0] = self.h
self.hsv[:,1] = self.s
self.hsv[:,2] = self.v
def rgb(self):
self.rgb = colors.hsv_to_rgb(self.hsv)
return self.rgb
def hsv(self):
return self.hsv
def colorbar(self):
h = 0.05
xx, yy = np.meshgrid(np.arange(self.list1_min, self.list1_max, h),
np.arange(self.list2_min, self.list2_max, h))
xx, yy = xx.ravel(), yy.ravel()
if self.limit_color == False:
colorbar_h = (xx - self.list1_min) / (self.list1_max - self.list1_min)
colorbar_v = (yy - self.list2_min) / (self.list2_max - self.list2_min)
else:
colorbar_h = (xx - self.list1_min) / (self.list1_max - self.list1_min) * (self.maxcolor - self.mincolor) + self.mincolor
colorbar_v = (yy - self.list2_min) / (self.list2_max - self.list2_min) * self.maxv + (1- self.maxv)
colorbar_s = [self.maxs for i in range(len(xx))]
hsv = np.zeros(shape=(len(xx),3))
hsv[:,0] = colorbar_h
hsv[:,1] = colorbar_s
hsv[:,2] = colorbar_v
rgb = colors.hsv_to_rgb(hsv)
return xx, yy, rgb
以鸢尾花数据集为例绘制题图。该数据集中每个样本(鸢尾花)有四个属性:花萼长度(Sepal Length),花萼宽度(Sepal Width),花瓣长度(Petal Length),花瓣宽度(Petal Width)。数据集共150个样本,包含三种鸢尾花:Setosa(山鸢尾)、Versicolour(杂色鸢尾),Virginica(维吉尼亚鸢尾)。
我们选择Sepal Length、Sepal Width以散点图的x、y轴表示,Petal Length、Petal Width分别以colorbar颜色属性中的色相、明度表示。代码及结果如下:
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
iris = load_iris()
mycolorbar = colorbar2d(iris.data[:,2], iris.data[:,3], 'b', 'r', maxv = 0.8)
mycolorbar_rgb = mycolorbar.rgb()
colorbar_x, colorbar_y, xy_color = mycolorbar.colorbar()
fig = plt.figure(figsize=(8,6))
fig.subplots_adjust(wspace = 0.5, hspace = 0.5)
plt.subplot2grid((1,4),(0,0),colspan=3)
plt.scatter(iris.data[:50,0], iris.data[:50,1],
s = 30, marker = '^', c = mycolorbar_rgb[:50,:],
label = 'Setosa')
plt.scatter(iris.data[50:100,0], iris.data[50:100,1],
s = 30, marker = 'o', c = mycolorbar_rgb[50:100,:],
label = 'Versicolour')
plt.scatter(iris.data[100:,0], iris.data[100:,1],
s = 30, marker = '*', c = mycolorbar_rgb[100:,:],
label = 'Virginica')
plt.xlabel("sepal length")
plt.ylabel("sepal width")
plt.title("sepal length and width scatter")
plt.legend(loc = "upper right")
plt.subplot2grid((1,4),(0,3))
plt.scatter(colorbar_y, colorbar_x, c = xy_color)
plt.xlim(colorbar_y.min(), colorbar_y.max())
plt.ylim(colorbar_x.min(), colorbar_x.max())
plt.xlabel("petal width")
plt.ylabel("petal length")
plt.title('colormap')
plt.show()
在此基础上,可以在colormap中添加核密度估计图用以指示样本属性在colormap上的分布。代码及结果如下:
def get_scatter(x, y):
xmin, xmax = x.min(), x.max()
ymin, ymax = y.min(), y.max()
xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = np.vstack([xx.ravel(), yy.ravel()])
values = np.vstack([x, y])
kernel = st.gaussian_kde(values)
f = np.reshape(kernel(positions).T, xx.shape)
return xx, yy, f
fig = plt.figure(figsize=(8,6))
fig.subplots_adjust(wspace = 0.5, hspace = 0.5)
plt.subplot2grid((1,4),(0,0),colspan=3)
plt.scatter(iris.data[:50,0], iris.data[:50,1],
s = 30, marker = '^', c = mycolorbar_rgb[:50,:],
label = 'Setosa')
plt.scatter(iris.data[50:100,0], iris.data[50:100,1],
s = 30, marker = 'o', c = mycolorbar_rgb[50:100,:],
label = 'Versicolour')
plt.scatter(iris.data[100:,0], iris.data[100:,1],
s = 30, marker = '*', c = mycolorbar_rgb[100:,:],
label = 'Virginica')
plt.xlabel("sepal length")
plt.ylabel("sepal width")
plt.title("sepal length and width scatter")
plt.legend(loc = "upper right")
plt.grid()
plt.subplot2grid((1,4),(0,3))
# 绘制三种鸢尾花样本的分布曲线
xx1, yy1, f1 = get_scatter(iris.data[:50,3], iris.data[:50,2])
plt.contour(xx1, yy1, f1, 3, colors='w', linewidths = 1)
xx2, yy2, f2 = get_scatter(iris.data[50:100,3], iris.data[50:100,2])
plt.contour(xx2, yy2, f2, 3, colors='w', linewidths = 1)
xx3, yy3, f3 = get_scatter(iris.data[100:,3], iris.data[100:,2])
plt.contour(xx3, yy3, f3, 3, colors='w', linewidths = 1)
plt.scatter(colorbar_y, colorbar_x, c = xy_color)
plt.xlim(colorbar_y.min(), colorbar_y.max())
plt.ylim(colorbar_x.min(), colorbar_x.max())
plt.xlabel("petal width")
plt.ylabel("petal length")
plt.title('color bar')
plt.savefig("temp.png",dpi=300,bbox_inches = 'tight')
plt.show()