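When plotting with ax.imshow(), you cannot add a legend as simply as you can with plt.plot(label=xxx). An image is not collected by legend(), so legend() has no handles to show for it.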
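Here is a minimal sketch of the difference. It uses the non-interactive Agg backend so it runs without a display. A labeled line is picked up by get_legend_handles_labels(), while a labeled image is not:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the example runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], label="line")  # Line2D: collected for the legend
ax.imshow(np.eye(3), label="image")    # AxesImage: ignored by the legend
handles, labels = ax.get_legend_handles_labels()
print(labels)  # ['line']
```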
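I won't repeat how to use .imshow() itself. Suppose the map you want to plot, base_map, is the following. Each character is one cell: S = Safe, D = Danger, G = Goal.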
base_map = [
"SSSDD",
"SDSSD",
"DSSDD",
"SSDSD",
"DSSSG"
]
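imshow only plots numbers, so every character has to become an integer first. np.unique returns the distinct characters in sorted order, which gives D = 0, G = 1, S = 2. The legend labels must follow this order. With return_inverse=True, np.unique also returns each cell's code in a single call. The sketch below uses that as an alternative to the explicit loop in the code that follows:

```python
import numpy as np

base_map = ["SSSDD", "SDSSD", "DSSDD", "SSDSD", "DSSSG"]
grid = np.array([list(s) for s in base_map])  # shape (5, 5), one character per cell
# elements: distinct characters, sorted; codes: index of each cell's character in elements
elements, codes = np.unique(grid, return_inverse=True)
codes = codes.reshape(grid.shape)  # restore the 2-D layout (the inverse may come back flattened)
print(elements)  # ['D' 'G' 'S']
print(codes)
```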
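The following code adds a legend. imshow provides no legend handles, so the code builds one proxy mpatches.Patch per cell type. Each patch is colored with the image's own colormap and normalization: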
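import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

legend = True  # set to False to draw the map without a legend

# Turn the list of strings into a 2-D array of single characters
base_map = np.array([list(s) for s in base_map])
# Distinct cell types in sorted order: ['D', 'G', 'S']
elements = np.unique(base_map)
# imshow needs numbers, so replace each character with its index in `elements`
map_to_plot = np.zeros_like(base_map, dtype=float)
for i, el in enumerate(elements):
    map_to_plot[base_map == el] = float(i)

# Hide the ticks and tick labels on both axes
plt.tick_params(axis='both',
                which='both',
                bottom=False,
                left=False,
                labelbottom=False,
                labelleft=False)
im = plt.imshow(map_to_plot)

if legend:
    # Use the values that actually occur in the data (0., 1., 2.) so that
    # im.norm maps them to the same colors imshow used for the cells
    values = np.unique(map_to_plot)
    colors = [im.cmap(im.norm(value)) for value in values]
    # Must follow the sorted order of `elements`: D, G, S
    labels = ['Danger', 'Goal', 'Safe']
    # One proxy patch per cell type
    patches = [
        mpatches.Patch(color=colors[i], label=labels[i])
        for i in range(len(values))]
    # loc=2 ('upper left') with bbox_to_anchor=(1.01, 1) places the legend
    # just outside the top-right corner of the axes
    plt.legend(handles=patches, bbox_to_anchor=(1.01, 1), loc=2,
               borderaxespad=0., frameon=False)
plt.show()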
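The labels list above is hard-coded and silently depends on np.unique's sorted order (D, G, S). A slightly more robust variant builds each label from a character-to-name dict, so labels cannot drift out of order. It is sketched here as a hypothetical helper, imshow_with_legend, which is not a matplotlib API. It assumes the dict covers every character in the map and uses the same proxy-patch technique:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to show the figure interactively
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np


def imshow_with_legend(ax, rows, names):
    """Hypothetical helper: draw a character grid with imshow and add a legend.

    rows  -- list of equal-length strings, one character per cell
    names -- dict mapping each character to its legend label (assumed complete)
    """
    grid = np.array([list(r) for r in rows])
    # Sorted distinct characters and each cell's index into them
    elements, codes = np.unique(grid, return_inverse=True)
    codes = codes.reshape(grid.shape)
    im = ax.imshow(codes)
    ax.set_xticks([])
    ax.set_yticks([])
    # Color each proxy patch exactly as imshow colored the cells with that code
    patches = [
        mpatches.Patch(color=im.cmap(im.norm(i)), label=names[el])
        for i, el in enumerate(elements)
    ]
    ax.legend(handles=patches, bbox_to_anchor=(1.01, 1), loc="upper left",
              borderaxespad=0., frameon=False)
    return im, patches


base_map = ["SSSDD", "SDSSD", "DSSDD", "SSDSD", "DSSSG"]
fig, ax = plt.subplots()
im, patches = imshow_with_legend(ax, base_map, {"S": "Safe", "D": "Danger", "G": "Goal"})
```

Because each label is looked up by character, changing the dict's order or adding a new cell type does not misalign colors and names.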
The resulting figure: