import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from umap import UMAP
# Load the CIFAR-10 dataset from OpenML (the original comment said MNIST, but
# this is CIFAR-10).
# NOTE(review): each row is presumably a flattened 32x32x3 image — confirm.
cifar10 = fetch_openml('CIFAR_10', version=1, cache=True)
X = cifar10.data
# OpenML delivers the class labels as strings; convert them to ints so they can
# drive the colormap.
y = cifar10.target.astype(int)

# UMAP on the raw high-dimensional pixel vectors is slow, and distances there
# are dominated by noise (curse of dimensionality). So compress the data with
# PCA first. Set N_PCA_COMPONENTS = None to run UMAP directly on the raw pixels.
N_PCA_COMPONENTS = 50
if N_PCA_COMPONENTS is None:
    X_reduced = X
else:
    X_reduced = PCA(n_components=N_PCA_COMPONENTS).fit_transform(X)

# Embed the (optionally PCA-reduced) data into 2-D with UMAP.
# Named `reducer` rather than `umap` to avoid shadowing the package name.
reducer = UMAP(n_components=2)
X_umap = reducer.fit_transform(X_reduced)

# Plot the UMAP output, coloring each point by its class label.
scatter = plt.scatter(X_umap[:, 0], X_umap[:, 1], c=y, cmap='tab10', s=1)
plt.colorbar(scatter, ticks=range(10), label='class label')
plt.title('UMAP Visualization of CIFAR10')
plt.show()