import torch
import torch.nn.functional as F
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# Demo: embed 19 classes, each represented by a 2048-dimensional feature
# vector, into 2-D or 3-D with t-SNE and scatter-plot the result.
# Random features stand in for real per-class features here.
features = torch.rand(19, 2048)

# scikit-learn requires perplexity < n_samples (it raises ValueError
# otherwise); the default perplexity of 30 is too large for 19 samples,
# so clamp it to the largest valid value.
n_samples = features.shape[0]
tsne = TSNE(n_components=3, perplexity=min(30.0, n_samples - 1))
# fit_transform returns the same array that is stored in tsne.embedding_.
embedding = tsne.fit_transform(features.numpy())
print(embedding.shape)

# Plot: one point per sample, annotated with its row index.
x = embedding[:, 0]
y = embedding[:, 1]
fig = plt.figure()
if embedding.shape[1] == 2:
    ax = fig.add_subplot(111)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.scatter(x, y, c='r', s=20, alpha=0.5)
    for i, (xi, yi) in enumerate(zip(x, y)):
        ax.text(xi, yi, str(i))
elif embedding.shape[1] == 3:
    z = embedding[:, 2]
    # plt.gca(projection=...) no longer accepts kwargs (Matplotlib >= 3.6),
    # so create the 3-D axes explicitly on the figure.
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    ax.scatter(x, y, z, c='r', s=20, alpha=0.5)
    for i, (xi, yi, zi) in enumerate(zip(x, y, z)):
        ax.text(xi, yi, zi, str(i))
plt.show()