plt.bar柱状图减小柱子之间的间隔

原始柱状图

    import matplotlib.pyplot as plt
    num_list = [1.5, 0.6, 7.8, 6]
    plt.bar(range(len(num_list)), num_list)
    plt.show()

在这里插入图片描述

import matplotlib.pyplot as plt
import numpy as np

from matplotlib.pyplot import MultipleLocator

taxinyc = {
    'gdc':{
        'rmse': {
            '16': 0.636,
            '32': 0.622,
            '64': 0.617,
            '128': 0.608
        },
        'mae': {
            '16': 0.159,
            '32': 0.153,
            '64': 0.152,
            '128': 0.151
        }
    },
    'scl':{
        'rmse': {
            '16': 0.604,
            '32': 0.608,
            '64': 0.607
        },
        'mae': {
            '16': 0.160,
            '32': 0.151,
            '64': 0.154
        }
    },
    'dcl':{
        'rmse': {
            '16': 0.612,
            '32': 0.608,
            '64': 0.610
        },
        'mae': {
            '16': 0.148,
            '32': 0.151,
            '64': 0.158
        }
    }
}

taxincd = {
    'gdc':{
        'rmse': {
            '16': 0.322,
            '32': 0.322,
            '64': 0.321,
            '128': 0.320
        },
        'mae': {
            '16': 0.117,
            '32': 0.116,
            '64': 0.115,
            '128': 0.119
        }
    },
    'scl':{
        'rmse': {
            '16': 0.324,
            '32': 0.322,
            '64': 0.321
        },
        'mae': {
            '16': 0.128,
            '32': 0.116,
            '64': 0.118
        }
    },
    'dcl':{
        'rmse': {
            '16': 0.321,
            '32': 0.322,
            '64': 0.396
        },
        'mae': {
            '16': 0.116,
            '32': 0.116,
            '64': 0.222
        }
    }
}


def plot_data(data,colors):

    for i,key1 in enumerate(data.keys()): #[gdc,scl,dcl]
        if i == 0:
            params = {
                'figure.figsize': '4, 4',
                'axes.unicode_minus':False
            }
            xx = [0., 0.7, 1.4, 2.1]
        else:
            params = {
                'figure.figsize': '3, 4',
                'axes.unicode_minus': False
            }
            xx = [0., 0.7, 1.4]

        plt.rcParams.update(params)

        for key2 in data[key1].keys(): #[rmse,mae']
            fig, ax = plt.subplots()
            hidden_dims = list(data[key1][key2].keys())
            values = list(data[key1][key2].values())

            x_center = [index + 0.3 for index in xx]
            plt.bar(xx, values, width=0.6, align='edge',color=colors)

            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)



            if key2 == 'rmse':
                yrange = 0.7
                ax.yaxis.set_major_locator(MultipleLocator(0.1))
            else:
                yrange = 0.18
                ax.yaxis.set_major_locator(MultipleLocator(0.03))

            plt.ylim(0.,yrange)  # y轴取值范围
            plt.ylabel(str.upper(key2))
            plt.xticks(x_center, hidden_dims)  # 这儿的0.3是配合宽度0.6来的,是他的一半,目的是让刻度线在柱子的中间
            # plt.xlabel("特征", labelpad=8.5)

            plt.axis('on')  # 增加这行关闭坐标轴显示,但仍有空白区域
            # plt.margins(0.1) 图内距离坐标轴0.1
            plt.subplots_adjust(left=0.2) #整个图距离画布左边距0.2,防止ylabel消失
            # 关键在于bbox_inches = 'tight',pad_inches = 0,去掉空白区域



            plt.savefig('result/{}_{}.png'.format(key1,key2), bbox_inches='tight', pad_inches=0)

            plt.show()
            plt.close(fig)

if __name__ == '__main__':
    colors = ['#35478C','#4E7AC7','#2FB2F0','#ADD5F7']
    plot_data(taxinyc,colors)

    # num_list = [1.5, 0.6, 7.8, 6]
    # plt.bar(range(len(num_list)), num_list)
    # plt.show()

效果图:
在这里插入图片描述
折线图版本

import numpy as np
import matplotlib.pyplot as plt

taxinyc = {
    'gdc':{
        'rmse': {
            '16': 0.636,
            '32': 0.622,
            '64': 0.617,
            '128': 0.608
        },
        'mae': {
            '16': 0.159,
            '32': 0.153,
            '64': 0.152,
            '128': 0.151
        }
    },
    'scl':{
        'rmse': {
            '16': 0.604,
            '32': 0.608,
            '64': 0.607
        },
        'mae': {
            '16': 0.160,
            '32': 0.151,
            '64': 0.154
        }
    },
    'dcl':{
        'rmse': {
            '16': 0.612,
            '32': 0.608,
            '64': 0.610
        },
        'mae': {
            '16': 0.148,
            '32': 0.151,
            '64': 0.158
        }
    }
}

taxicd = {
    'gdc':{
        'rmse': {
            '16': 0.322,
            '32': 0.322,
            '64': 0.321,
            '128': 0.320
        },
        'mae': {
            '16': 0.117,
            '32': 0.116,
            '64': 0.115,
            '128': 0.119
        }
    },
    'scl':{
        'rmse': {
            '16': 0.324,
            '32': 0.322,
            '64': 0.321
        },
        'mae': {
            '16': 0.128,
            '32': 0.116,
            '64': 0.118
        }
    },
    'dcl':{
        'rmse': {
            '16': 0.321,
            '32': 0.322,
            '64': 0.396
        },
        'mae': {
            '16': 0.116,
            '32': 0.116,
            '64': 0.222
        }
    }
}

t1=['16','32','64','128']
t2= [16,32,64]

mae = [0.159,0.153,0.152,0.151]
rmse = [0.148,0.151,0.158]

x1 = [0.5,1.0,1.5,2.0]
x2 = [0.5,1.0,1.5]


data = taxicd

figure,ax=plt.subplots(1,3,sharey=True,figsize=(12,4))


index = ['(a)','(b)','(c)']
for i,(idx,key1) in enumerate(zip(index,data.keys())):  # [gdc,scl,dcl]
    for key2 in ['mae']:  # [rmse,mae']

        hidden_dims = list(data[key1][key2].keys())
        values = list(data[key1][key2].values())
        title = idx + 'hidden units of ' + str.upper(key1)

        if i == 0:
            ax[i].plot(x1, values, 'ro--')
            ax[i].set_ylabel(str.upper(key2),fontdict={'fontsize':12})
            ax[i].set_xticks(x1)
            ax[i].set_xticklabels(hidden_dims,fontdict={'fontsize':12})
        else:
            ax[i].plot(x2, values, 'ro--')
            ax[i].set_xticks(x2)
            ax[i].set_xticklabels(hidden_dims,fontdict={'fontsize':12})
        ax[i].set_title(title,fontdict={'fontsize':12})

plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
plt.savefig('./taxicd_line_graph_hidden_units.png', bbox_inches='tight', pad_inches=0.1)
figure.show()

在这里插入图片描述

参考:
画柱状图
控制柱子之间的间距
matplotlib plt.bar函数
matplotlib plt.xtricks函数

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值