def plot_embedding(data, label, train_index, title):
    """Draw a 2-D embedding as coloured text labels and show the figure.

    The embedding is min-max normalised column-wise to [0, 1].
    Each row ``i`` is then drawn at ``(data[i, 0], data[i, 1])`` as the text
    ``str(label[i])``. Rows whose index appears in ``train_index`` are coloured
    from the ``Set3`` colormap; all other rows use ``Set1``.

    Args:
        data: 2-D array-like; only the first two columns are plotted, but
            every column is normalised.
        label: per-row labels, indexable by row. ``label[i] / 10.`` selects
            the colormap entry, so labels are presumably small non-negative
            integers (TODO confirm against callers).
        train_index: iterable of integer row indices to draw with ``Set3``.
        title: title string for the figure.
    """
    data = np.asarray(data, dtype=float)
    x_min, x_max = np.min(data, 0), np.max(data, 0)
    span = x_max - x_min
    # A constant column would divide by zero and produce NaN coordinates;
    # use a span of 1 there so that column maps to 0 instead.
    span[span == 0] = 1.0
    data = (data - x_min) / span

    # Build the membership set once: O(1) lookups instead of rescanning
    # train_index for every row. int() also normalises numpy/torch scalars.
    train_set = {int(j) for j in train_index}

    plt.figure()
    plt.subplot(111)
    for i in range(data.shape[0]):
        cmap = plt.cm.Set3 if i in train_set else plt.cm.Set1
        plt.text(data[i, 0], data[i, 1], str(label[i]),
                 color=cmap(label[i] / 10.),
                 fontdict={'weight': 'bold', 'size': 9})
    plt.xticks([])
    plt.yticks([])
    plt.title(title)
    plt.show()
if __name__ == '__main__':
    # Tiny demo: embed three 3-feature points with t-SNE and plot them,
    # drawing row 1 with the "train" colormap.
    data = torch.tensor([[0.1, 0.1, 0.3],
                         [0.5, 0.6, 0.5],
                         [0.6, 0.7, 0.8]])
    data = data.numpy()
    label = torch.tensor([0, 0, 1])
    label = label.numpy()
    train_index = [1]
    n_samples = data.shape[0]
    n_features = data.shape[1]
    print(data, label, n_samples, n_features)
    # scikit-learn requires perplexity < n_samples. The default (30) raises
    # ValueError on this 3-sample demo, so clamp it to the dataset size.
    perplexity = min(30.0, n_samples - 1)
    # plot_embedding draws only the first two coordinates, so ask t-SNE
    # for a 2-D embedding directly instead of discarding a third axis.
    tsne = TSNE(n_components=2, init='pca', random_state=0,
                perplexity=perplexity)
    result = tsne.fit_transform(data)
    plot_embedding(result, label, train_index, 't-SNE embedding of the data')
# Note: the colours and font settings used in plot_embedding can be adjusted as needed.