# Plot nine random 8x8 "weight" matrices in a 3x3 grid with ONE shared
# colorbar, save the figure to tjn.png and display it.
fig, ax = plt.subplots(3, 3)
ax = ax.flatten()  # (3, 3) array of Axes -> 1-D so the grid can be walked in one loop

np.random.seed(0)  # fixed seed so the saved figure is reproducible

# np.random.random draws from [0, 1). Pin every panel to that same colour
# range: imshow otherwise autoscales each image to its own min/max, and a
# single colorbar (built from the last image) would misdescribe the others.
vmin, vmax = 0.0, 1.0

for axis in ax:
    weight = np.random.random([8, 8])
    im = axis.imshow(weight, vmin=vmin, vmax=vmax)

# One colorbar attached to all nine axes (space is stolen from all of them).
# fraction: share of the original axes used for the colorbar, i.e. its size;
# pad: gap between the subplots and the colorbar, as a fraction of the axes.
fig.colorbar(im, ax=list(ax), fraction=0.03, pad=0.05)
fig.savefig('tjn.png', bbox_inches='tight')
plt.show()
`fraction=0.03` 用于调节 colorbar 的大小（colorbar 占原坐标轴的比例，默认 0.15）；`pad=0.05` 用于调节 colorbar 与子图之间的间距。