Official documentation for the t-SNE parameters:
https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
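The sketch below shows how the main parameters on that page fit together. It runs scikit-learn's `TSNE` on random data. `run_tsne` is a hypothetical helper, and the random 64-dimensional vectors only stand in for real network features. Note that `perplexity` must be smaller than the number of samples.

```python
import numpy as np
from sklearn.manifold import TSNE


def run_tsne(features, perplexity=30.0, metric="euclidean", seed=0):
    """Hypothetical helper: project (n_samples, n_features) features to 2-D with t-SNE."""
    tsne = TSNE(
        n_components=2,         # 2-D output for plotting
        perplexity=perplexity,  # roughly the effective number of neighbours; must be < n_samples
        init="pca",             # PCA initialisation is usually more stable than random
        metric=metric,          # distance measured between the input feature vectors
        random_state=seed,      # fixed seed so the layout is reproducible
    )
    return tsne.fit_transform(features)


# Random stand-in for real features: 200 samples, 64 dimensions
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 64)).astype(np.float32)
Y = run_tsne(X)
print(Y.shape)  # (200, 2)
```

Each row of `Y` is the 2-D position of one sample and can be passed directly to a scatter plot coloured by class label.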
10 classes:
Cosine distance, 100 classes:
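In scikit-learn, cosine distance is selected with `metric="cosine"`. The following minimal sketch assumes 100 classes. `tsne_cosine` is a hypothetical helper. The features are synthetic, built from 100 random class centres with 5 noisy samples each, so they stand in for real extracted features.

```python
import numpy as np
from sklearn.manifold import TSNE


def tsne_cosine(features, seed=0):
    """Hypothetical helper: 2-D t-SNE embedding using cosine distance between feature vectors."""
    return TSNE(
        n_components=2,
        metric="cosine",   # compare feature directions, ignoring their magnitudes
        init="pca",
        perplexity=30.0,   # must be < n_samples (500 here)
        random_state=seed,
    ).fit_transform(features)


# Synthetic stand-in for 100-class features: 5 noisy samples around each random centre
rng = np.random.default_rng(0)
num_classes, per_class, dim = 100, 5, 128
centres = rng.normal(size=(num_classes, dim))
labels = np.repeat(np.arange(num_classes), per_class)
feats = centres[labels] + 0.1 * rng.normal(size=(labels.size, dim))

emb = tsne_cosine(feats)
print(emb.shape, labels.shape)  # (500, 2) (500,)
```

Because cosine distance depends only on the angle between vectors, it can suit pooled CNN features whose overall scale varies from sample to sample.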
Getting the features and labels:
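import torch
import torch.nn.functional as F

def get_embs(model, test_loader):
    # Collect intermediate features and labels for the test set
    model.eval()
    embs = []
    labels = []
    with torch.no_grad():  # no gradients needed for feature extraction
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            batch_size = targets.size(0)
            inputs, targets = inputs.cuda(), targets.cuda()
            # The model returns intermediate feature maps (feats) and the final outputs
            feats, outputs = model(inputs)
            # Process the features the same way the network architecture does
            labels.append(targets.cpu())
            mid_feat=F.adaptive_avg_pool2d(F.relu(feats[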