感谢您分享您的代码! 您为我们找出了所有困难的东西。 在使用它时,我注意到一些看起来不太正确的小事情。
[FIX#1]轴运动没有像我期望的那样对齐(即,在上面的示例中,您应该能够在所有图上的任意点绘制一条垂直线和水平线,并且这些线应该穿过相应的点 指向其他地块,但现在它不会发生。
[FIX#2]如果您要绘制的变量数量奇数,则右下角的轴不会拉出正确的xtics或ytics。 它只是将其保留为默认的0..1刻度。
这不是一个解决方法,但是我可以选择显式输入main(),以便为对角线变量i设置默认值xi。
在下面,您将找到解决这两个问题的代码更新版本,否则将保留代码的美感。
import itertools
import numpy as np
import matplotlib.pyplot as plt
def scatterplot_matrix(data, names=[], **kwargs):
"""
Plots a scatterplot matrix of subplots. Each row of "data" is plotted
against other rows, resulting in a nrows by nrows grid of subplots with the
diagonal subplots labeled with "names". Additional keyword arguments are
passed on to matplotlib's "plot" command. Returns the matplotlib figure
object containg the subplot grid.
"""
numvars, numdata = data.shape
fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
fig.subplots_adjust(hspace=0.0, wspace=0.0)
for ax in axes.flat:
# Hide all ticks and labels
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
# Set up ticks only on one side for the "edge" subplots...
if ax.is_first_col():
ax.yaxis.set_ticks_position('left')
if ax.is_last_col():
ax.yaxis.set_ticks_position('right')
if ax.is_first_row():
ax.xaxis.set_ticks_position('top')
if ax.is_last_row():
ax.xaxis.set_ticks_position('bottom')
# Plot the data.
for i, j in zip(*np.triu_indices_from(axes, k=1)):
for x, y in [(i,j), (j,i)]:
# FIX #1: this needed to be changed from ...(data[x], data[y],...)
axes[x,y].plot(data[y], data[x], **kwargs)
# Label the diagonal subplots...
if not names:
names = ['x'+str(i) for i in range(numvars)]
for i, label in enumerate(names):
axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
ha='center', va='center')
# Turn on the proper x or y axes ticks.
for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
axes[j,i].xaxis.set_visible(True)
axes[i,j].yaxis.set_visible(True)
# FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
# correct axes limits, so we pull them from other axes
if numvars%2:
xlimits = axes[0,-1].get_xlim()
ylimits = axes[-1,0].get_ylim()
axes[-1,-1].set_xlim(xlimits)
axes[-1,-1].set_ylim(ylimits)
return fig
if __name__=='__main__':
np.random.seed(1977)
numvars, numdata = 4, 10
data = 10 * np.random.random((numvars, numdata))
fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
linestyle='none', marker='o', color='black', mfc='none')
fig.suptitle('Simple Scatterplot Matrix')
plt.show()
再次感谢您与我们分享。 我已经使用了很多次了! 哦,我重新安排了代码的main()部分,这样,如果将其导入到另一段代码中,则它可以是正式的示例代码,也可以不被调用。