效果图:
参考链接:
class Pairs(object):
def __init__(self, pdata: pd.DataFrame,corr_matrix, p_matrix, figsize: tuple = (30, 30), dpi: int = 300):
pdata = pdata.select_dtypes(include=np.number)
self.figsize = figsize
self.dpi = dpi
self.rawdata = pdata
self.num_rol = self.rawdata.shape[1]
self.data_col_name = self.rawdata.columns.tolist()
self.corr_matrix = corr_matrix
self.p_matrix = p_matrix
def pairs(self):
fig, ax = plt.subplots(ncols=self.num_rol,nrows=self.num_rol,figsize=self.figsize,dpi=self.dpi)
for temp_col in range(self.num_rol):
for temp_row in range(self.num_rol):
if temp_row == temp_col:
# plot 主对角线
sns.histplot(data=self.rawdata, x=self.data_col_name[temp_col], bins=20, ax=ax[temp_row, temp_col])
ax[temp_row, temp_col].set_ylabel('lv%', fontsize=8,rotation=0, labelpad=10)
ax[temp_row, temp_col].set_xlabel('lv%', fontsize=8,rotation=90)
if temp_row > temp_col:
# 下对角线
sns.scatterplot(data=self.rawdata,x=self.data_col_name[temp_col],y=self.data_col_name[temp_row],ax=ax[temp_row, temp_col], s=5)
ax[temp_row, temp_col].set_ylabel(self.data_col_name[temp_row], fontsize=8, rotation=0, labelpad=20)
ax[temp_row, temp_col].set_xlabel(self.data_col_name[temp_col], fontsize=8, rotation=90)
elif temp_row < temp_col:
#上对角线
ax[temp_row, temp_col].set_axis_off()
im = ax[temp_row, temp_col].imshow(np.array([[corr_matrix.iloc[temp_row, temp_col]]]), cmap='coolwarm', vmin=-0.1, vmax=1)
p_value_text = f'{corr_matrix.iloc[temp_row, temp_col]:.2f}'
ax[temp_row, temp_col].text(0.5, 0.5, p_value_text, ha='center', va='center', color='white', fontsize=14, transform=ax[temp_row, temp_col].transAxes)
if p_matrix[temp_row, temp_col] < alpha_001:
ax[temp_row, temp_col].text(0.6, 0.6, '***', ha='left', va='bottom', color='black', fontsize=12, transform=ax[temp_row, temp_col].transAxes)
elif p_matrix[temp_row, temp_col] < alpha_01:
ax[temp_row, temp_col].text(0.6, 0.6, '**', ha='left', va='bottom', color='black', fontsize=12, transform=ax[temp_row, temp_col].transAxes)
elif p_matrix[temp_row, temp_col] < alpha_05:
ax[temp_row, temp_col].text(0.6, 0.6, '*',ha='left', va='bottom', color='black', fontsize=12, transform=ax[temp_row, temp_col].transAxes)
if temp_row != self.num_rol - 1:
ax[temp_row, temp_col].set_xticks([])
ax[temp_row, temp_col].set_xlabel("")
if temp_col != 0:
ax[temp_row, temp_col].set_ylabel("")
ax[temp_row, temp_col].set_yticks([])
for ax_single in ax.flat:
plt.setp(ax_single.get_xticklabels(), fontsize=5)
plt.setp(ax_single.get_yticklabels(), fontsize=5)
cbar_ax = fig.add_axes([0.95, 0.3, 0.03, 0.5]) # [left, bottom, width, height]
fig.colorbar(im, cax=cbar_ax, orientation='vertical')
plt.subplots_adjust(wspace=0, hspace=0)
return fig, ax
p = Pairs(pdata=selected_columns, corr_matrix = corr_matrix,p_matrix = p_matrix,figsize=(14, 14), dpi=200)
fig, ax = p.pairs()
调用函数前,要定义好selected_columns(csv文件的列), corr_matrix(相关系数), p_matrix(显著性)