[matplotlib] 标签分类图

matplotlib

colors 那里做了一下处理,使越大的值对应的颜色越深。除了 viridis 这种渐进色,还有很多其他的,详见官方文档

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.cm as cm
from matplotlib.colors import Normalize

fig = plt.figure(figsize=(6.4*2, 4.8*1.1))

ax2d = fig.add_subplot(1, 2, 1)
colors = dataset_P.max() - dataset_P
norm = Normalize(vmin=colors.min(), vmax=colors.max())
cmap = cm.get_cmap('viridis')
ax2d.scatter(S_tsne[:, 0], S_tsne[:, 1], s=30, c=colors, alpha=0.8,
                cmap=cmap, norm=norm, linewidth=0.5, edgecolor='w')
ax2d.set_title('State distribution(TSNE) via Policy')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')
handles = [mpatches.Patch(color=cmap(norm(dataset_P.max()-i)), label=str(i), alpha=0.8)
            for i in range(9)]
ax2d.legend(handles=handles, loc='upper right', title='Policy')

ax2d = fig.add_subplot(1, 2, 2)
colors = dataset_V.max()-dataset_V
norm = Normalize(vmin=colors.min(), vmax=colors.max())
cmap = cm.get_cmap('viridis')
ax2d.scatter(S_tsne[:, 0], S_tsne[:, 1], s=30, c=colors, alpha=0.8,
                cmap=cmap, norm=norm, linewidth=0.5, edgecolor='w')
ax2d.set_title('State distribution(TSNE) via Value')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')
handles = [mpatches.Patch(color=cmap(norm(dataset_V.max()-i)), label=f'{i:.3f}', alpha=0.8)
            for i in np.linspace(dataset_V.min(), dataset_V.max(), 7)]
ax2d.legend(handles=handles, loc='upper right', title='Value')

fig.tight_layout()
plt.savefig(f'./svp.png')

在这里插入图片描述

seaborn

fig = plt.figure(figsize=(6.4*2, 4.8*2*1.1))

ax2d = fig.add_subplot(2, 2, 1)
colors = dataset_P.max() - dataset_P
norm = Normalize(vmin=colors.min(), vmax=colors.max())
cmap = cm.get_cmap('viridis')
ax2d.scatter(S_tsne[:, 0], S_tsne[:, 1], s=30, c=colors, alpha=0.8,
                cmap=cmap, norm=norm, linewidth=0.5, edgecolor='w')
ax2d.set_title('State distribution(TSNE) by matplotlib')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')
handles = [mpatches.Patch(color=cmap(norm(dataset_P.max()-i)), label=str(i), alpha=0.8)
            for i in range(9)]
ax2d.legend(handles=handles, loc='upper right', title='Policy')

ax2d = fig.add_subplot(2, 2, 2)
df_tsne = pd.DataFrame(
    {'State[0]': S_tsne[:, 0], 'State[1]': S_tsne[:, 1], 'class': dataset_P})
# ax2d.scatter(S_tsne[:, 0], S_tsne[:, 1], s=1)
sns.scatterplot(data=df_tsne, hue='class', x='State[0]', y='State[1]', ax=ax2d, legend='full')
ax2d.set_title('State distribution(TSNE) by seaborn')
ax2d.legend(loc='upper right', title='Policy')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')

ax2d = fig.add_subplot(2, 2, 3)
colors = dataset_V.max()-dataset_V
norm = Normalize(vmin=colors.min(), vmax=colors.max())
cmap = cm.get_cmap('viridis')
ax2d.scatter(S_tsne[:, 0], S_tsne[:, 1], s=30, c=colors, alpha=0.8,
                cmap=cmap, norm=norm, linewidth=0.5, edgecolor='w')
ax2d.set_title('State distribution(TSNE) via Value')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')
handles = [mpatches.Patch(color=cmap(norm(dataset_V.max()-i)), label=f'{i:.3f}', alpha=0.8)
            for i in np.linspace(dataset_V.min(), dataset_V.max(), 7)]
ax2d.legend(handles=handles, loc='upper right', title='Value')

ax2d = fig.add_subplot(2, 2, 4)
df_tsne = pd.DataFrame(
    {'State[0]': S_tsne[:, 0], 'State[1]': S_tsne[:, 1], 'class': dataset_V})
sns.scatterplot(data=df_tsne, hue='class', x='State[0]', y='State[1]', ax=ax2d)
ax2d.set_title('State distribution(TSNE) by seaborn')
ax2d.legend(loc='upper right', title='Value')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')

fig.tight_layout()
plt.savefig('./test.png')

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值