菜鸟认知,大佬误喷
网络输出格式为(num,num_class, h, w ),标签为(num,h , w ),标签不进行独热编码,标签.unique =[0,1,2,3,···,num_class],使用损失函数nn.CrossEntropyLoss()。
预测结果显示
以num=1为例
per.shape为(1,num_class, h, w )
def plt_predict(per,n_class):
per_ = torch.sigmoid(per)
per_ = per_.squeeze()
per_ =torch.argmax(per_,dim=0)
per_ = F.one_hot(per_, n_class).permute(2, 0,1).numpy()
for i in range(n_class):
per__ = per_[i,:,:]
plt.subplot(1,n_class,i+1)
plt.imshow(per__)
plt.show()
显示标签
def plt_label(label,n_class):
label_ = F.one_hot(label, n_class).permute(2, 0,1).numpy()
for i in range(n_class):
label__ = label_[i,:,:]
plt.subplot(1,n_class,i+1)
plt.imshow(label__)
plt.show()