Matplotlib设置网格线之major和minor

Matplotlib设置网格线之major和minor

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
data_dict = {
    0: 7.088511943817139,
    1: 1.941696047782898,
    2: 4.185360431671143,
    3: 3.784451723098755,
    4: 5.676402568817139,
    5: 6.908511161804199,
    6: 0.0035664942115545273,
    7: 0.00400047842413187,
    8: 0.00456645293161273,
    9: 0.003707129508256912,
    10: 20.32890510559082,
    11: 0.0040948097594082355
}
encoder_dict = {k: data_dict[k] for k in range(6)}
decoder_dict = {k: data_dict[k] for k in range(6, 12)}
data1 = np.array(list(encoder_dict.values())).reshape(1, -1)
data2 = np.array(list(decoder_dict.values())).reshape(1, -1)

# Create a figure and axis objects
fig, (ax1, ax2) = plt.subplots(2, 1, sharey=True)

# Plot the first heatmap
im1 = ax1.imshow(data1, cmap='Oranges', interpolation='nearest')
# im1.set_gid(True)
ax1.set_yticks([])
ax1.set_xticks([0, 1, 2,3,4,5])
ax1.xaxis.set_minor_locator(MultipleLocator(0.5))
ax1.set_ylabel('Encoder')
ax1.set_xlabel('Layer')

# Plot the second heatmap
im2 = ax2.imshow(data2, cmap='Oranges', interpolation='nearest')
# im2.set_gid(True)
ax2.set_yticks([])
ax2.set_xticks([0, 1, 2,3,4,5])
ax2.xaxis.set_minor_locator(MultipleLocator(0.5))
ax2.set_xlabel('Layer')
ax2.set_ylabel('Decoder')

# Add gridlines to both subplots
for ax in [ax1, ax2]:
    ax.grid(True, which='minor', axis='both', linestyle='-', color='gray', alpha=0.5)
# 添加颜色条,颜色条的位置位于底部横着放
cbar = fig.colorbar(im2, ax=[ax1, ax2], orientation='horizontal', pad=0.2)

# 显示图形

plt.show()
# 保存图形为pdf
fig.savefig('norm_visualization.pdf', format='pdf', dpi=300)
  • 8
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值