原始柱状图
import matplotlib.pyplot as plt
# Minimal bar chart: one bar per value, drawn at x = 0, 1, 2, ...
num_list = [1.5, 0.6, 7.8, 6]
plt.bar(x=range(len(num_list)), height=num_list)
plt.show()
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.pyplot import MultipleLocator
# Results nested as: model -> metric ('rmse' / 'mae') -> hidden-unit size
# (string key) -> value. Only 'gdc' has an entry for the '128' size.
# Key insertion order is preserved on purpose (models are iterated in order).
taxinyc = {
    'gdc': {
        'rmse': {'16': 0.636, '32': 0.622, '64': 0.617, '128': 0.608},
        'mae': {'16': 0.159, '32': 0.153, '64': 0.152, '128': 0.151},
    },
    'scl': {
        'rmse': {'16': 0.604, '32': 0.608, '64': 0.607},
        'mae': {'16': 0.160, '32': 0.151, '64': 0.154},
    },
    'dcl': {
        'rmse': {'16': 0.612, '32': 0.608, '64': 0.610},
        'mae': {'16': 0.148, '32': 0.151, '64': 0.158},
    },
}
# Results nested as: model -> metric ('rmse' / 'mae') -> hidden-unit size
# (string key) -> value. Only 'gdc' has an entry for the '128' size.
# NOTE(review): the line-chart section of this note holds the same numbers
# under the name `taxicd` (no 'n').
taxincd = {
    'gdc': {
        'rmse': {'16': 0.322, '32': 0.322, '64': 0.321, '128': 0.320},
        'mae': {'16': 0.117, '32': 0.116, '64': 0.115, '128': 0.119},
    },
    'scl': {
        'rmse': {'16': 0.324, '32': 0.322, '64': 0.321},
        'mae': {'16': 0.128, '32': 0.116, '64': 0.118},
    },
    'dcl': {
        'rmse': {'16': 0.321, '32': 0.322, '64': 0.396},
        'mae': {'16': 0.116, '32': 0.116, '64': 0.222},
    },
}
def plot_data(data, colors):
    """Draw, save and show one bar chart per (model, metric) pair of ``data``.

    Args:
        data: nested mapping ``{model: {metric: {hidden_dim_label: value}}}``.
            The metric ``'rmse'`` uses a 0.1 y tick step and a default upper
            y limit of 0.7; every other metric uses a 0.03 step and 0.18.
            The limit is raised to the next tick when a value exceeds it.
        colors: bar colour(s), passed straight to ``Axes.bar``.

    Side effects:
        Writes ``result/<model>_<metric>.png`` (creating the ``result``
        directory if it is missing), then shows and closes each figure.
    """
    import math
    import os

    bar_width = 0.6  # ticks sit at left edge + bar_width / 2, i.e. mid-bar
    bar_step = 0.7   # distance between the left edges of neighbouring bars
    os.makedirs('result', exist_ok=True)  # savefig does not create directories

    for model, metrics in data.items():  # e.g. gdc, scl, dcl
        for metric, scores in metrics.items():  # e.g. rmse, mae
            hidden_dims = list(scores.keys())
            values = list(scores.values())
            # Derive bar positions and figure width from the number of bars
            # rather than assuming only the first model has four hidden sizes.
            xx = [k * bar_step for k in range(len(values))]
            x_center = [x + bar_width / 2 for x in xx]
            fig_width = max(3, len(values))

            if metric == 'rmse':
                y_top, y_step = 0.7, 0.1
            else:
                y_top, y_step = 0.18, 0.03
            # Raise the fixed upper limit (to the next tick) when a value would
            # otherwise be clipped, e.g. an MAE of 0.222 against a 0.18 limit.
            if values:
                y_top = max(y_top, math.ceil(max(values) / y_step) * y_step)

            # rc_context keeps this rcParams change local to the figure instead
            # of leaking it into every later plot.
            with plt.rc_context({'axes.unicode_minus': False}):
                fig, ax = plt.subplots(figsize=(fig_width, 4))
                ax.bar(xx, values, width=bar_width, align='edge', color=colors)
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
                ax.yaxis.set_major_locator(MultipleLocator(y_step))
                ax.set_ylim(0., y_top)
                ax.set_ylabel(metric.upper())
                ax.set_xticks(x_center)
                ax.set_xticklabels(hidden_dims)
                # Leave room for the y label on screen; bbox_inches='tight'
                # with pad_inches=0 trims the blank border in the saved file.
                fig.subplots_adjust(left=0.2)
                fig.savefig('result/{}_{}.png'.format(model, metric),
                            bbox_inches='tight', pad_inches=0)
                plt.show()
            plt.close(fig)
if __name__ == '__main__':
    # Four bar colours, applied to the bars of each chart in order.
    colors = ['#35478C', '#4E7AC7', '#2FB2F0', '#ADD5F7']
    plot_data(data=taxinyc, colors=colors)
效果图:
折线图版本
import numpy as np
import matplotlib.pyplot as plt
# Results nested as: model -> metric ('rmse' / 'mae') -> hidden-unit size
# (string key) -> value. Only 'gdc' has an entry for the '128' size.
taxinyc = {
    'gdc': {
        'rmse': {'16': 0.636, '32': 0.622, '64': 0.617, '128': 0.608},
        'mae': {'16': 0.159, '32': 0.153, '64': 0.152, '128': 0.151},
    },
    'scl': {
        'rmse': {'16': 0.604, '32': 0.608, '64': 0.607},
        'mae': {'16': 0.160, '32': 0.151, '64': 0.154},
    },
    'dcl': {
        'rmse': {'16': 0.612, '32': 0.608, '64': 0.610},
        'mae': {'16': 0.148, '32': 0.151, '64': 0.158},
    },
}
# Results nested as: model -> metric ('rmse' / 'mae') -> hidden-unit size
# (string key) -> value. Only 'gdc' has an entry for the '128' size.
# Same numbers as `taxincd` in the bar-chart section (spelled differently).
taxicd = {
    'gdc': {
        'rmse': {'16': 0.322, '32': 0.322, '64': 0.321, '128': 0.320},
        'mae': {'16': 0.117, '32': 0.116, '64': 0.115, '128': 0.119},
    },
    'scl': {
        'rmse': {'16': 0.324, '32': 0.322, '64': 0.321},
        'mae': {'16': 0.128, '32': 0.116, '64': 0.118},
    },
    'dcl': {
        'rmse': {'16': 0.321, '32': 0.322, '64': 0.396},
        'mae': {'16': 0.116, '32': 0.116, '64': 0.222},
    },
}
# NOTE(review): t1, t2, mae, rmse, x1 and x2 are not used by the plot below
# (x positions are derived from the data). `mae` holds the taxinyc GDC MAE
# values and, despite its name, `rmse` holds the taxinyc DCL *MAE* values.
t1=['16','32','64','128']
t2= [16,32,64]
mae = [0.159,0.153,0.152,0.151]
rmse = [0.148,0.151,0.158]
x1 = [0.5,1.0,1.5,2.0]
x2 = [0.5,1.0,1.5]
data = taxicd
key2 = 'mae'  # metric drawn in every panel; the data also holds 'rmse'
figure, ax = plt.subplots(1, 3, sharey=True, figsize=(12, 4))
index = ['(a)', '(b)', '(c)']
for i, (idx, key1) in enumerate(zip(index, data.keys())):  # gdc, scl, dcl
    hidden_dims = list(data[key1][key2].keys())
    values = list(data[key1][key2].values())
    # One point every 0.5 x-units (0.5, 1.0, ...), however many hidden sizes
    # the model has, instead of assuming only the first panel has four.
    xs = [0.5 * (k + 1) for k in range(len(values))]
    title = idx + ' hidden units of ' + key1.upper()
    ax[i].plot(xs, values, 'ro--')
    if i == 0:
        # The y axis is shared, so only the leftmost panel gets a label.
        ax[i].set_ylabel(key2.upper(), fontdict={'fontsize': 12})
    ax[i].set_xticks(xs)
    ax[i].set_xticklabels(hidden_dims, fontdict={'fontsize': 12})
    ax[i].set_title(title, fontdict={'fontsize': 12})
plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
plt.savefig('./taxicd_line_graph_hidden_units.png', bbox_inches='tight', pad_inches=0.1)
# Figure.show() does not run a GUI event loop, so in a plain script the window
# can vanish immediately; pyplot.show() blocks until the window is closed.
plt.show()
参考:
画柱状图
控制柱子之间的间距
matplotlib plt.bar函数
matplotlib plt.xticks函数