基本思想:把测试数据输入模型,取出模型提取的特征(即送入分类器之前的输出),再用 t-SNE 将其降到二维并绘图。
1. 先引入包
import torch
from sklearn.manifold import TSNE # 这个是绘图关键
import random
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from torchvision import datasets, transforms
2. 设置随机种子
为保证结果可复现,设置了随机种子
def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global generator and
    the standard-library ``random`` module, then configures cuDNN for
    deterministic execution.

    Args:
        seed: Integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Disable cuDNN autotuning and request deterministic kernels so that
    # repeated runs do not diverge because of algorithm selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


setup_seed(1337)
3. 准备测试数据及模型
# Convert images to tensors, then normalise each of the three channels with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# CIFAR-10 test split; download=False means the files must already be on disk.
# NOTE(review): the root directory is named "MNIST" although the dataset is
# CIFAR-10 -- path kept unchanged, confirm it is the intended location.
testset = datasets.CIFAR10(root='../data/MNIST/', train=False,
                           download=False, transform=transform)
# shuffle=False keeps the sample order fixed across runs.
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=False, num_workers=2)
model_file = ["centralized_net.pkl"]
# The checkpoint path is the first entry of `model_file`; the previous code
# passed an undefined name (`model_name`) and failed with NameError.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases may default to weights_only=True and
# refuse a fully pickled model -- confirm against the installed version.
model = torch.load(model_file[0])
4. 输入测试数据得到特征表示
# Run the whole test set through the model once and collect, for every sample,
# the feature vector, the ground-truth label and the logits.
model.eval()
# Feed inputs on whichever device holds the model's weights instead of
# hard-coding CUDA; when the model is on the GPU this is the same as .cuda().
device = next(model.parameters()).device
feature_chunks, label_chunks, logit_chunks = [], [], []
with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for image_batch, label_batch in testloader:
        image_batch = image_batch.to(device)
        # reshape(-1) flattens (B,) or (B, 1) labels to 1-D. Unlike squeeze()
        # it never yields a 0-d tensor for a batch of one, which torch.cat
        # would reject.
        label_batch = label_batch.to(device).long().reshape(-1)
        # The model is expected to return a (logits, feature) pair.
        logits, feature = model(image_batch)
        feature_chunks.append(feature)
        label_chunks.append(label_batch)
        logit_chunks.append(logits)
# Concatenate once at the end rather than re-copying the growing bank on
# every iteration.
feature_bank = torch.cat(feature_chunks)
label_bank = torch.cat(label_chunks)
logits_bank = torch.cat(logit_chunks)
5. 绘图
针对 feature_bank 和 label_bank 进行绘图
# Move the collected tensors to host memory as NumPy arrays for
# scikit-learn and matplotlib.
feature_bank = feature_bank.cpu().numpy()
label_bank = label_bank.cpu().numpy()
# Highest softmax probability per sample (p) and the index it occurs at (pseu).
# Neither is used by the plot below; they are kept in case later code reads them.
p, pseu = torch.max(torch.softmax(logits_bank, dim=-1), dim=-1)
prob_bank = p.cpu().numpy()
tsne = TSNE(n_components=2)
# Reduce the features to two dimensions.
# Assumes feature_bank is (num_samples, feature_dim) -- TODO confirm.
output = tsne.fit_transform(feature_bank)

# Colour the points by ground-truth class: one scatter call per class, so each
# class takes the next colour of matplotlib's default cycle. The former
# `cmap=` argument was dropped because scatter ignores it (with a warning)
# when no `c` is given.
for i in range(10):
    index = (label_bank == i)
    plt.scatter(output[index, 0], output[index, 1], s=5, label=str(i))
plt.legend()
plt.show()