def _label_colors(labels, n_samples):
    """Return one hex color per sample; samples that share a label share a random color."""
    # labels may be longer than z_run (original note: "because of weird batch_size");
    # keep only the first n_samples so colors line up with the latent vectors.
    labels = np.asarray(labels)[:n_samples]
    # Index the palette by each label's position in np.unique rather than by the label
    # value itself, so labels need not be the contiguous integers 0..k-1.
    unique_labels, label_index = np.unique(labels, return_inverse=True)
    palette = ['#%06X' % randint(0, 0xFFFFFF) for _ in unique_labels]
    colors = [palette[i] for i in label_index.ravel()]
    if len(colors) != n_samples:
        raise ValueError('Got %d labels for %d latent vectors' % (len(colors), n_samples))
    return colors


def _embed_2d(z_run):
    """Return (svd_2d, tsne_2d): 2-D truncated-SVD and t-SNE embeddings of the rows of z_run."""
    n_samples = z_run.shape[0]
    # Only the first two components are plotted, so fit just two. TruncatedSVD rejects
    # n_components larger than the feature count, so this also admits narrower inputs.
    z_run_pca = TruncatedSVD(n_components=2).fit_transform(z_run)
    # Recent scikit-learn rejects perplexity >= n_samples; clamp it for small runs.
    perplexity = max(1, min(80, n_samples - 1))
    try:
        tsne = TSNE(perplexity=perplexity, min_grad_norm=1E-12, max_iter=3000)
    except TypeError:
        # `n_iter` was renamed `max_iter` in scikit-learn 1.5; older releases only know `n_iter`.
        tsne = TSNE(perplexity=perplexity, min_grad_norm=1E-12, n_iter=3000)
    z_run_tsne = tsne.fit_transform(z_run)
    return z_run_pca, z_run_tsne


def _plot_plotly(points, colors, title):
    """Render one 2-D scatter of `points` via plotly.offline.iplot."""
    trace = Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode='markers',
        marker=dict(color=colors),
    )
    layout = Layout(title=title, showlegend=False)
    # A plain list replaces the deprecated plotly `Data` wrapper.
    plotly.offline.iplot(Figure(data=[trace], layout=layout))


def _plot_matplotlib(points, colors, title, save_path=None):
    """Draw one 2-D scatter on a fresh figure; save it to `save_path` if given, else show it."""
    # A fresh figure per plot keeps the second scatter from drawing over the first.
    fig = plt.figure()
    plt.scatter(points[:, 0], points[:, 1], c=colors, marker='*', linewidths=0)
    plt.title(title)
    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path)
        plt.close(fig)


def plot_clustering(z_run, labels, engine='plotly', download=False, folder_name='clustering'):
    """
    Given latent variables for all timeseries and the cluster labels (e.g. output of k-means),
    project the latent vectors to 2-D with truncated SVD (titled "PCA" in the plots) and t-SNE,
    and plot the points colored by label.

    :param z_run: Latent vectors for all input tensors; array-like with `.shape`, one row per sample.
    :param labels: Cluster labels for all input tensors. Entries beyond `z_run.shape[0]` are
        ignored; labels need not be contiguous integers.
    :param engine: 'plotly' or 'matplotlib'. Any other value prints a message and plots nothing.
    :param download: If true, save the plots as `pca.png` / `tsne.png` in `folder_name`
        (matplotlib only; plotly plots cannot be downloaded).
    :param folder_name: Download folder to dump plots; created (with parents) if missing.
    :return: None
    :raises ValueError: if there are fewer labels than latent vectors.
    """
    if engine == 'plotly' and download:
        print("Can't download plotly plots")
        return
    if engine not in ('plotly', 'matplotlib'):
        print("Unknown engine %r; expected 'plotly' or 'matplotlib'" % (engine,))
        return

    colors = _label_colors(labels, z_run.shape[0])
    z_run_pca, z_run_tsne = _embed_2d(z_run)
    plots = (
        ('PCA on z_run', z_run_pca, 'pca.png'),
        ('tSNE on z_run', z_run_tsne, 'tsne.png'),
    )

    if engine == 'plotly':
        for title, points, _ in plots:
            _plot_plotly(points, colors, title)
        return

    if download:
        os.makedirs(folder_name, exist_ok=True)
    for title, points, file_name in plots:
        save_path = os.path.join(folder_name, file_name) if download else None
        _plot_matplotlib(points, colors, title, save_path)
对数据用 TruncatedSVD 和 t-SNE 降维后，根据真实标签可视化
最新推荐文章于 2024-04-23 23:38:26 发布