# 分类任务中的一些函数：acc、ROC、混淆矩阵

from sklearn.metrics import roc_curve,auc
from prettytable import PrettyTable
class ConfusionMatrix(object):
    """Accumulate a confusion matrix and report per-class classification metrics.

    Layout: ``matrix[pred, true]`` -- rows are predicted classes, columns are
    ground-truth classes (see ``update``).

    Note: if the rendered figure is clipped, that is a matplotlib version issue;
    this routine was verified with matplotlib 3.2.1 (Windows and Ubuntu).
    ``summary`` additionally requires the ``prettytable`` package.
    """

    def __init__(self, num_classes: int, labels: list):
        """
        Args:
            num_classes: number of classes; the matrix is ``num_classes x num_classes``.
            labels: display name for each class index, used by ``summary`` and ``plot``.
        """
        self.matrix = np.zeros((num_classes, num_classes))
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):
        """Add one batch of (prediction, ground-truth) pairs to the matrix.

        Args:
            preds: predicted class indices (array-like of any shape).
            labels: ground-truth class indices with the same number of elements.
                Integer-valued floats (e.g. ``1.0`` from a float label tensor) and
                column vectors of shape ``(B, 1)`` are accepted.

        Raises:
            ValueError: if ``preds`` and ``labels`` contain different numbers of elements.
        """
        # NumPy refuses float indices, so cast to int; ravel so (B,) and (B, 1) both work.
        pred_idx = np.asarray(preds).astype(np.int64).ravel()
        true_idx = np.asarray(labels).astype(np.int64).ravel()
        if pred_idx.size != true_idx.size:
            raise ValueError(
                "preds and labels must have the same number of elements, "
                "got %d and %d" % (pred_idx.size, true_idx.size))
        # np.add.at accumulates correctly even when the same (pred, true) cell repeats.
        np.add.at(self.matrix, (pred_idx, true_idx), 1)

    def _class_metrics(self, i):
        """Return ``(precision, recall, specificity, f1)`` for class ``i``, each rounded to 3 decimals.

        Any metric whose denominator is zero is reported as ``0.``.
        """
        tp = self.matrix[i, i]
        fp = np.sum(self.matrix[i, :]) - tp  # predicted i, actually another class
        fn = np.sum(self.matrix[:, i]) - tp  # actually i, predicted another class
        tn = np.sum(self.matrix) - tp - fp - fn
        precision = round(tp / (tp + fp), 3) if tp + fp != 0 else 0.
        recall = round(tp / (tp + fn), 3) if tp + fn != 0 else 0.
        specificity = round(tn / (tn + fp), 3) if tn + fp != 0 else 0.
        # F1 = 2PR / (P + R) = 2TP / (2TP + FP + FN). Computed from the raw counts so it
        # is not distorted by the rounding of precision/recall above.
        f1_den = 2 * tp + fp + fn
        f1_score = round(2 * tp / f1_den, 3) if f1_den != 0 else 0.
        return precision, recall, specificity, f1_score

    def summary(self):
        """Print the overall accuracy and a per-class Precision/Recall/Specificity/F1 table.

        Returns:
            The overall accuracy (trace / total count), or ``0.`` if nothing was recorded.
        """
        total = np.sum(self.matrix)
        acc = np.trace(self.matrix) / total if total != 0 else 0.
        print("the model accuracy is ", acc)

        table = PrettyTable()
        table.field_names = ["", "Precision", "Recall", "Specificity", "f1_score"]
        for i in range(self.num_classes):
            table.add_row([self.labels[i], *self._class_metrics(i)])
        print(table)
        return acc

    def plot(self):
        """Print the matrix, then show it as a heat map annotated with per-cell counts.

        x axis: true labels (matrix columns); y axis: predicted labels (matrix rows).
        """
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)

        # Tick labels: x axis = true classes, y axis = predicted classes.
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        plt.yticks(range(self.num_classes), self.labels)
        plt.colorbar()
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')

        # Annotate every cell with its count; white text on the darker (high-count) cells.
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # imshow draws row y at vertical position y, so the cell at (x, y)
                # holds matrix[y, x], not matrix[x, y].
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()

def save_checkpoint(state, filename="/home/tlz/GCCS_0916/checkpoint"):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    ``state`` is presumably a dict holding the ``"state_dict"`` / ``"optimizer"``
    entries that ``load_checkpoint`` reads -- confirm against callers.

    NOTE(review): the default ``filename`` is a machine-specific absolute path;
    pass ``filename`` explicitly on other machines.
    """
    banner = "=> Saving checkpoint"
    print(banner)
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state, in place, from an already-loaded checkpoint mapping.

    Args:
        checkpoint: mapping with a ``"state_dict"`` entry (loaded into ``model``) and an
            ``"optimizer"`` entry (loaded into ``optimizer``).
        model: object whose ``load_state_dict`` receives ``checkpoint["state_dict"]``.
        optimizer: object whose ``load_state_dict`` receives ``checkpoint["optimizer"]``.
    """
    print("=> Loading checkpoint")
    # Model first, then optimizer: the same order (and failure point) as a plain two-liner.
    for target, key in ((model, "state_dict"), (optimizer, "optimizer")):
        target.load_state_dict(checkpoint[key])
def check_accuracy(loader, model, test, epoch, path=None):
    """Evaluate a binary classifier on ``loader``: print metrics and optionally save a ROC plot.

    The model is put in eval mode for the pass and switched back to train mode
    afterwards, also when an exception is raised during evaluation.

    Args:
        loader: iterable of ``(x, y)`` tensor batches; ``y`` holds 0/1 labels.
        model: network whose output is thresholded at 0.5. Assumes it returns one
            probability-like score per sample (e.g. a sigmoid output) -- TODO confirm.
        test: when truthy, compute the ROC curve / AUC and save it as ``<path>/<epoch>.jpg``.
        epoch: used as the ROC image file name.
        path: directory for the ROC image; required when ``test`` is truthy.

    Returns:
        None. Accuracy and per-class metrics are printed by ``ConfusionMatrix.summary``.

    Raises:
        ValueError: if ``test`` is truthy but ``path`` is None.
    """
    if test and path is None:
        # Fail before the (potentially long) evaluation pass rather than after it.
        raise ValueError("path must be given when test=True (the ROC curve is saved there)")

    confusion = ConfusionMatrix(num_classes=2, labels=[0, 1])
    model.eval()
    try:
        y_true, y_score = _collect_predictions(loader, model, confusion)
    finally:
        model.train()

    if test:
        _save_roc_curve(y_true, y_score, path, epoch)
    confusion.summary()
    return None


def _collect_predictions(loader, model, confusion):
    """Run ``model`` over ``loader`` without gradients, update ``confusion``, and return
    ``(labels, scores)`` as flat 1-D NumPy arrays (the input ``sklearn.metrics.roc_curve`` expects)."""
    # Feed inputs to whatever device the model lives on; CPU for a parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    label_parts, score_parts = [], []
    with torch.no_grad():
        for x, y in tqdm(loader):
            scores = model(x.to(device=device))
            y_np = y.cpu().numpy().ravel()
            scores_np = scores.cpu().numpy().ravel()
            # Gather per-batch arrays and concatenate once; np.append per batch is O(n^2).
            label_parts.append(y_np)
            score_parts.append(scores_np)
            # Hard prediction: class 0 if the score is below 0.5, otherwise class 1.
            predictions = np.where(scores_np < 0.5, 0, 1)
            # Integer labels: NumPy cannot index the matrix with float label values.
            confusion.update(predictions, y_np.astype(np.int64))

    if not label_parts:
        return np.array([]), np.array([])
    return np.concatenate(label_parts), np.concatenate(score_parts)


def _save_roc_curve(y_true, y_score, path, epoch):
    """Plot the ROC curve (positive class 1) with its AUC and save it to ``<path>/<epoch>.jpg``."""
    fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=1)
    roc_auc = auc(fpr, tpr)
    fig = plt.figure()
    lw = 2
    plt.plot(fpr, tpr, color='y',
             lw=lw, label='Original            (AUC = %0.4f)' % roc_auc)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(loc="lower right")
    try:
        plt.savefig(os.path.join(path, str(epoch) + '.jpg'))
    finally:
        # Always release the figure so a failed save does not leak an open figure.
        plt.close(fig)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值