matplotlib
在 colors 处做了一步处理(用最大值减去原值),使数值越大的点颜色越深。除了 viridis
这种渐变色图(colormap),还有很多其他可选的色图,详见 Matplotlib 官方文档。
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.cm as cm
from matplotlib.colors import Normalize
# Draw the same 2-D points (S_tsne, presumably a t-SNE embedding of the
# states -- confirm against the code that builds it) in two side-by-side
# panels: the left one coloured by dataset_P, the right one by dataset_V.
fig = plt.figure(figsize=(6.4*2, 4.8*1.1))

# `cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` returns the same colormap on old and new releases alike.
# Both panels use the same colormap, so it is looked up once.
cmap = plt.get_cmap('viridis')

# --- Left panel: coloured by policy -------------------------------------
ax2d = fig.add_subplot(1, 2, 1)
# Flip the data (max - value) so that larger values map to the low, dark end
# of viridis.
colors = dataset_P.max() - dataset_P
norm = Normalize(vmin=colors.min(), vmax=colors.max())
ax2d.scatter(S_tsne[:, 0], S_tsne[:, 1], s=30, c=colors, alpha=0.8,
             cmap=cmap, norm=norm, linewidth=0.5, edgecolor='w')
ax2d.set_title('State distribution(TSNE) via Policy')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')
# A scatter coloured through a colormap has no per-class artists, so the
# legend is assembled by hand from one patch per value, using the same
# flip + normalisation as the points.
# NOTE(review): assumes the policy values are the integers 0..8 -- confirm.
handles = [mpatches.Patch(color=cmap(norm(dataset_P.max()-i)), label=str(i), alpha=0.8)
           for i in range(9)]
ax2d.legend(handles=handles, loc='upper right', title='Policy')

# --- Right panel: coloured by value -------------------------------------
ax2d = fig.add_subplot(1, 2, 2)
colors = dataset_V.max()-dataset_V
norm = Normalize(vmin=colors.min(), vmax=colors.max())
ax2d.scatter(S_tsne[:, 0], S_tsne[:, 1], s=30, c=colors, alpha=0.8,
             cmap=cmap, norm=norm, linewidth=0.5, edgecolor='w')
ax2d.set_title('State distribution(TSNE) via Value')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')
# The legend samples 7 evenly spaced values between the min and max of
# dataset_V and labels each with three decimals.
handles = [mpatches.Patch(color=cmap(norm(dataset_V.max()-i)), label=f'{i:.3f}', alpha=0.8)
           for i in np.linspace(dataset_V.min(), dataset_V.max(), 7)]
ax2d.legend(handles=handles, loc='upper right', title='Value')

fig.tight_layout()
# Save through the figure object so the output does not depend on which
# figure pyplot currently considers "current".
fig.savefig('./svp.png')
seaborn
# 2x2 comparison figure: the left column is drawn with plain matplotlib
# (colormap + hand-built legend), the right column with seaborn's
# scatterplot. The top row is coloured by dataset_P, the bottom row by
# dataset_V.
fig = plt.figure(figsize=(6.4*2, 4.8*2*1.1))

# `cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` returns the same colormap on old and new releases alike.
# Both matplotlib panels use the same colormap, so it is looked up once.
cmap = plt.get_cmap('viridis')

# --- (1) Policy, matplotlib ---------------------------------------------
ax2d = fig.add_subplot(2, 2, 1)
# Flip the data (max - value) so that larger values map to the low, dark end
# of viridis.
colors = dataset_P.max() - dataset_P
norm = Normalize(vmin=colors.min(), vmax=colors.max())
ax2d.scatter(S_tsne[:, 0], S_tsne[:, 1], s=30, c=colors, alpha=0.8,
             cmap=cmap, norm=norm, linewidth=0.5, edgecolor='w')
ax2d.set_title('State distribution(TSNE) by matplotlib')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')
# Hand-built legend: one patch per value, coloured with the same flip +
# normalisation as the points.
# NOTE(review): assumes the policy values are the integers 0..8 -- confirm.
handles = [mpatches.Patch(color=cmap(norm(dataset_P.max()-i)), label=str(i), alpha=0.8)
           for i in range(9)]
ax2d.legend(handles=handles, loc='upper right', title='Policy')

# --- (2) Policy, seaborn ------------------------------------------------
ax2d = fig.add_subplot(2, 2, 2)
df_tsne = pd.DataFrame(
    {'State[0]': S_tsne[:, 0], 'State[1]': S_tsne[:, 1], 'class': dataset_P})
# legend='full' asks seaborn for one legend entry per distinct hue value.
sns.scatterplot(data=df_tsne, hue='class', x='State[0]', y='State[1]', ax=ax2d, legend='full')
ax2d.set_title('State distribution(TSNE) by seaborn')
# Re-issue the legend to set its location and title.
ax2d.legend(loc='upper right', title='Policy')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')

# --- (3) Value, matplotlib ----------------------------------------------
ax2d = fig.add_subplot(2, 2, 3)
colors = dataset_V.max()-dataset_V
norm = Normalize(vmin=colors.min(), vmax=colors.max())
ax2d.scatter(S_tsne[:, 0], S_tsne[:, 1], s=30, c=colors, alpha=0.8,
             cmap=cmap, norm=norm, linewidth=0.5, edgecolor='w')
ax2d.set_title('State distribution(TSNE) via Value')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')
# The legend samples 7 evenly spaced values between the min and max of
# dataset_V and labels each with three decimals.
handles = [mpatches.Patch(color=cmap(norm(dataset_V.max()-i)), label=f'{i:.3f}', alpha=0.8)
           for i in np.linspace(dataset_V.min(), dataset_V.max(), 7)]
ax2d.legend(handles=handles, loc='upper right', title='Value')

# --- (4) Value, seaborn -------------------------------------------------
ax2d = fig.add_subplot(2, 2, 4)
df_tsne = pd.DataFrame(
    {'State[0]': S_tsne[:, 0], 'State[1]': S_tsne[:, 1], 'class': dataset_V})
# No `legend` argument here, so seaborn's default legend mode applies.
sns.scatterplot(data=df_tsne, hue='class', x='State[0]', y='State[1]', ax=ax2d)
ax2d.set_title('State distribution(TSNE) by seaborn')
ax2d.legend(loc='upper right', title='Value')
ax2d.set_xlabel('State[0]')
ax2d.set_ylabel('State[1]')

fig.tight_layout()
# Save through the figure object so the output does not depend on which
# figure pyplot currently considers "current".
fig.savefig('./test.png')