对PointNet++分类模型提取到的全局特征进行t-SNE可视化,所使用的模型为PointNet++的pytorch实现。
模型训练过程会自动保存训练得到的最佳模型即best_model.pth,在测试代码中加载best_model.pth,在测试数据集中对提取的全局特征进行t-SNE可视化,查看各个类别的聚类效果。代码如下:
# t-SNE函数
def start_tsne(x_train,y_train):
print("正在进行初始输入数据的可视化...")
X_tsne=TSNE().fit_transform(x_train)
plt.figure(figsize=(10,10))
plt.scatter(X_tsne[:,0],X_tsne[:,1],c=y_train)
plt.colorbar()
plt.show()
# 混淆矩阵函数
def confusion(y_label,y_pred):
con_mat=confusion_matrix(y_label.astype(str),y_pred.astype(str))
#print(con_mat)
classes=['0','1','2','3','4','5','6','7','8','9']
#classes.sort()
plt.imshow(con_mat,cmap=plt.cm.Blues)
indices=range(len(con_mat))
plt.xticks(indices,classes)
plt.yticks(indices,classes)
plt.colorbar()
plt.xlabel('pred')
plt.ylabel('true')
for first_index in range(len(con_mat)):
for second_index in range(len(con_mat[first_index])):
plt.text(first_index,second_index,con_mat[second_index][first_index],va='center',ha='center')
plt.show()
#修改测试函数
def test(model,loader,num_class=40,vote_num=1):
mean_correct=[]
classifier=model.eval()
class_acc=np.zeros((num_class,3))
global_fea=[] #创建一个保存全局特征空列表
label=[] #创建标签空列表
predict=[] #创建保存预测结果的空列表
for j, (points,target) in tqdm(enumerate(loader),total=len(loader)):
label.append(target.cpu().numpy()) #将数据集中的标签添加到标签列表中
if not args.use_cpu:
points,target=points.cuda(),target.cuda()
points=points.transpose(2,1)
vote_pool=torch.zeros(target.size()[0],num_class).cuda()
for _ in range(vote_num):
pred,global_transfeat=classifier(points) #提取全局特征
vote_pool+=pred
global_transfeat=global_transfeat.cpu().numpy()#将全局特征格式转换为numpy
global_fea.append(global_transfeat)#将全局特征保存到全局特征列表中
pred=vote_pool/vote_num
pred_choice=pred.data.max(1)[1]
predict.append(pred_choice.cpu().numpy()) #将预测结果的numpy格式存入列表中
for cat in np.unique(target.cpu()):
classacc=pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum()
class_acc[cat,0]+=classacc.item()/float(points[target==cat].size()[0])
class_acc[cat,1]+=1
correct=pred_choice.eq(target.long().data).cpu().sum()
mean_correct.append(correct.item()/float(points.size()[0]))
global_fea=np.array(global_fea)# 将全局特征列表转换为元组
global_fea=np.concatenate(global_fea,axis=0)#将元组形状变为(n,1024)n为测试集样本数,我的为960,1024为特征维度
global_fea=np.squeeze(global_fea)
label=np.array(label)
label=label.reshape(960)#960为自己数据集样本个数
predict=np.array(predict)
predict=predict.reshape(960)
#print(len(global_fea))
class_acc[:,2]=class_acc[:,0]/class_acc[:,1]
class_acc=np.mean(class_acc[:,2])
instance_acc=np.mean(mean_correct)
return instance_acc, class_acc, global_fea, label, predict#返回4个变量,后面用的到
#主函数,画t-SNE图,混淆矩阵
def main(args):
def log_string(str):
logger.info(str)
print(str)
'''HYPERPARAMETER'''
os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
'''CREATEDIR'''
experiment_dir='log/classification/'+args.log_dir
'''LOG'''
args=parse_args()
logger=logging.getLogger("Model")
logger.setLevel(logging.INFO)
formatter=logging.Formatter('%(asctime)s-%(name)s-%(levelname)s-%(message)s')
file_handler=logging.FileHandler('%s/eval.txt'%experiment_dir)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
log_string('PARAMETER...')
log_string(args)
'''DATALOADING'''
log_string('Loaddataset...')
data_path='datdaset'
test_dataset=ModelNetDataLoader(root=data_path,args=args,split='test',process_data=False)
testDataLoader=torch.utils.data.DataLoader(test_dataset,batch_size=args.batch_size,shuffle=False,num_workers=10)
'''MODELLOADING'''
num_class=args.num_category
model_name=os.listdir(experiment_dir+'/logs')[0].split('.')[0]
model=importlib.import_module(model_name)
classifier=model.get_model(num_class,normal_channel=args.use_normals)
if not args.use_cpu:
classifier=classifier.cuda()
checkpoint=torch.load(str(experiment_dir)+'/checkpoints/best_model.pth')
classifier.load_state_dict(checkpoint['model_state_dict'])
with torch.no_grad():
instance_acc, class_acc, global_fea, label, predict=test(classifier.eval(),testDataLoader,vote_num=args.num_votes,num_class=num_class)
log_string('Test Instance Accuracy:%f, Class Accuracy:%f '%(instance_acc,class_acc))
#start_tsne(global_fea,label) #绘制t-sne图
confusion(label,predict) #绘制混淆矩阵
if__name__=='__main__':
args=parse_args()
main(args)
运行代码后输出t-SNE图以及混淆矩阵如下: