下三角散点图和上三角热图拼接

效果图:

参考链接:

python版本的Pairs图_pair图-CSDN博客

class Pairs(object):
    def __init__(self, pdata: pd.DataFrame,corr_matrix, p_matrix, figsize: tuple = (30, 30), dpi: int = 300):
        pdata = pdata.select_dtypes(include=np.number)
        self.figsize = figsize
        self.dpi = dpi
        self.rawdata = pdata
        self.num_rol = self.rawdata.shape[1]
        self.data_col_name = self.rawdata.columns.tolist()
        self.corr_matrix = corr_matrix
        self.p_matrix = p_matrix
    def pairs(self):
        fig, ax = plt.subplots(ncols=self.num_rol,nrows=self.num_rol,figsize=self.figsize,dpi=self.dpi)
        for temp_col in range(self.num_rol):
            for temp_row in range(self.num_rol):
                if temp_row == temp_col:
                    # plot 主对角线
                    sns.histplot(data=self.rawdata, x=self.data_col_name[temp_col], bins=20, ax=ax[temp_row, temp_col])
        
                    ax[temp_row, temp_col].set_ylabel('lv%', fontsize=8,rotation=0, labelpad=10)
                    ax[temp_row, temp_col].set_xlabel('lv%', fontsize=8,rotation=90)
                if temp_row > temp_col:
                    # 下对角线
                    sns.scatterplot(data=self.rawdata,x=self.data_col_name[temp_col],y=self.data_col_name[temp_row],ax=ax[temp_row, temp_col], s=5)
                   
                    ax[temp_row, temp_col].set_ylabel(self.data_col_name[temp_row], fontsize=8, rotation=0, labelpad=20)
                    ax[temp_row, temp_col].set_xlabel(self.data_col_name[temp_col], fontsize=8, rotation=90)
                elif temp_row < temp_col:
                    #上对角线
                    ax[temp_row, temp_col].set_axis_off()
                    im = ax[temp_row, temp_col].imshow(np.array([[corr_matrix.iloc[temp_row, temp_col]]]), cmap='coolwarm', vmin=-0.1, vmax=1)
                    p_value_text = f'{corr_matrix.iloc[temp_row, temp_col]:.2f}'
                    ax[temp_row, temp_col].text(0.5, 0.5, p_value_text, ha='center', va='center', color='white', fontsize=14, transform=ax[temp_row, temp_col].transAxes)
                    if p_matrix[temp_row, temp_col] < alpha_001:
                        ax[temp_row, temp_col].text(0.6, 0.6, '***', ha='left', va='bottom', color='black', fontsize=12, transform=ax[temp_row, temp_col].transAxes)
                    elif p_matrix[temp_row, temp_col] < alpha_01:
                        ax[temp_row, temp_col].text(0.6, 0.6, '**', ha='left', va='bottom', color='black', fontsize=12, transform=ax[temp_row, temp_col].transAxes)
                    elif p_matrix[temp_row, temp_col] < alpha_05:
                        ax[temp_row, temp_col].text(0.6, 0.6, '*',ha='left', va='bottom', color='black', fontsize=12, transform=ax[temp_row, temp_col].transAxes)
                if temp_row != self.num_rol - 1:
                    ax[temp_row, temp_col].set_xticks([])
                    ax[temp_row, temp_col].set_xlabel("")

                if temp_col != 0:
                    ax[temp_row, temp_col].set_ylabel("")
                    ax[temp_row, temp_col].set_yticks([])
        for ax_single in ax.flat:
            plt.setp(ax_single.get_xticklabels(), fontsize=5)
            plt.setp(ax_single.get_yticklabels(), fontsize=5)
            
        cbar_ax = fig.add_axes([0.95, 0.3, 0.03, 0.5])  # [left, bottom, width, height]
        fig.colorbar(im, cax=cbar_ax, orientation='vertical')
        plt.subplots_adjust(wspace=0, hspace=0)
        return fig, ax

 
p = Pairs(pdata=selected_columns,  corr_matrix = corr_matrix,p_matrix = p_matrix,figsize=(14, 14), dpi=200)
fig, ax = p.pairs()

调用函数前,要定义好selected_columns(csv文件的列),  corr_matrix(相关系数), p_matrix(显著性)

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值