unet测试评估metric脚本

该代码将 PaddleSeg 中的评估指标实现移植到 PyTorch,用于评估训练好的 UNet 模型,计算 mIoU、accuracy、class precision、class recall、kappa 和 Dice 指标。代码先通过 calculate_area 函数统计每个类别的交集、预测和标签像素数,再使用 mean_iou、class_measurement、kappa 和 dice 等函数计算评估指标,最后对验证集进行评估并打印结果。
摘要由CSDN通过智能技术生成

以下代码改写自 PaddleSeg 的评估代码,并转换为 PyTorch 实现。

import argparse
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from utils.data_loading import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask
from torch.utils.data import DataLoader, random_split
from utils.data_loading import BasicDataset, CarvanaDataset
from tqdm import tqdm
import torch.nn.functional as F


# 使用python写一个评估使用pytorch训练的unet模型的好坏,模型输出nchw格式的数据,真实标签数据为nhw格式,请计算模型的accuracy, class precision, class recall, kappa指标

EPSILON = 1e-32

def calculate_area(pred, label, num_classes, ignore_index=255):
    """
    Calculate per-class intersection, prediction and label areas (pixel counts).

    Args:
        pred (Tensor): Predicted class ids. A 4-D input is squeezed along dim 1
            (a no-op unless that dim has size 1).
        label (Tensor): Ground-truth class ids, squeezed the same way.
        num_classes (int): Number of classes; only ids in [0, num_classes) are counted.
        ignore_index (int): Label value whose pixels are excluded from the
            prediction and intersection counts. Default: 255.

    Returns:
        Tensor: Intersection area of prediction and ground truth for each class.
        Tensor: Prediction area for each class (ignored pixels excluded).
        Tensor: Ground-truth area for each class.

    Raises:
        ValueError: If `pred` and `label` shapes differ after squeezing.
    """
    if pred.dim() == 4:
        pred = torch.squeeze(pred, dim=1)
    if label.dim() == 4:
        label = torch.squeeze(label, dim=1)
    if pred.shape != label.shape:
        raise ValueError('Shape of `pred` and `label` should be equal, '
                         'but there are {} and {}.'.format(pred.shape,
                                                           label.shape))

    # Pixels labelled `ignore_index` must not count as predictions of any class.
    mask = label != ignore_index

    pred_area = []
    label_area = []
    intersect_area = []
    for i in range(num_classes):
        pred_i = torch.logical_and(pred == i, mask)
        label_i = label == i
        intersect_i = torch.logical_and(pred_i, label_i)
        pred_area.append(torch.sum(pred_i))
        label_area.append(torch.sum(label_i))
        intersect_area.append(torch.sum(intersect_i))

    return (torch.stack(intersect_area),
            torch.stack(pred_area),
            torch.stack(label_area))


def get_args():
    """Parse command-line options for evaluating a UNet checkpoint.

    ``--root`` (dataset root) and ``--num`` (class count) are required:
    previously they defaulted to ``False``, which crashed later with an
    unrelated TypeError or built a model with zero output classes.

    Returns:
        argparse.Namespace: Parsed options.
    """
    parser = argparse.ArgumentParser(description='Evaluate a trained UNet on the validation split')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--root', '-r', type=str, required=True,
                        help='root dir containing train/ and val/ sub-directories')
    parser.add_argument('--num', '-n', type=int, required=True,
                        help='num of classes (output channels of the model)')

    return parser.parse_args()


# Folder names for images and masks, looked up inside <root>/train and <root>/val.
dir_img_path, dir_mask_path = 'imgs', 'masks'

import metrics

def train_net(net,
              device,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 0.001,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False,root_dir: str = '/data/yangbo/unet/datas/data1'):
    """Evaluate ``net`` on ``<root_dir>/val`` and print segmentation metrics.

    The name and signature are inherited from the training script this file
    was adapted from; no training is performed. ``epochs``, ``learning_rate``,
    ``val_percent``, ``save_checkpoint`` and ``amp`` are accepted for
    compatibility and ignored.

    Args:
        net: Model exposing ``n_channels`` and ``n_classes``; the argmax of its
            output over dim 1 is taken as the predicted class map.
        device: Device on which images and masks are evaluated (the model is
            expected to already live there).
        batch_size: Batch size of the validation loader.
        img_scale: Image scaling factor forwarded to the dataset classes.
        root_dir: Directory containing ``train/`` and ``val/`` sub-directories,
            each holding the image and mask folders.

    Prints mIoU, accuracy, kappa, mean Dice and per-class IoU / precision /
    recall accumulated over the whole validation set.
    """
    train_dir_img = os.path.join(root_dir, 'train', dir_img_path)
    train_dir_mask = os.path.join(root_dir, 'train', dir_mask_path)

    val_dir_img = os.path.join(root_dir, 'val', dir_img_path)
    val_dir_mask = os.path.join(root_dir, 'val', dir_mask_path)

    # 1. Create datasets (CarvanaDataset first, BasicDataset as the fallback).
    try:
        train_dataset = CarvanaDataset(train_dir_img, train_dir_mask, img_scale)
        val_dataset = CarvanaDataset(val_dir_img, val_dir_mask, img_scale)
    except (AssertionError, RuntimeError):
        train_dataset = BasicDataset(train_dir_img, train_dir_mask, img_scale)
        val_dataset = BasicDataset(val_dir_img, val_dir_mask, img_scale)

    n_val = len(val_dataset)
    n_train = len(train_dataset)
    num_classes = net.n_classes

    # 2. Validation loader. drop_last must stay False: dropping the trailing
    # partial batch would silently leave images out of the metrics.
    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_args)

    logging.info(f'''Starting evaluation:
        Batch size:      {batch_size}
        Training size:   {n_train}
        Validation size: {n_val}
        Classes:         {num_classes}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    # 3. Accumulate per-class pixel counts. They live on the CPU (the metrics
    # module converts them with .numpy()) and are int64 so large totals stay
    # exact, unlike float32 which loses precision past 2**24 pixels.
    intersect_area_all = torch.zeros(num_classes, dtype=torch.int64)
    pred_area_all = torch.zeros(num_classes, dtype=torch.int64)
    label_area_all = torch.zeros(num_classes, dtype=torch.int64)
    n_images = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Evaluation', unit='batch'):
            images = batch['image']
            true_masks = batch['mask']

            assert images.shape[1] == net.n_channels, \
                f'Network has been defined with {net.n_channels} input channels, ' \
                f'but loaded images have {images.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'

            images = images.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=torch.long)

            masks_pred = torch.argmax(net(images), dim=1, keepdim=True)
            intersect_area, pred_area, label_area = calculate_area(
                masks_pred, true_masks, num_classes)
            intersect_area_all += intersect_area.cpu()
            pred_area_all += pred_area.cpu()
            label_area_all += label_area.cpu()
            n_images += images.shape[0]

    # 4. Turn the accumulated areas into metrics and report them.
    metrics_input = (intersect_area_all, pred_area_all, label_area_all)
    class_iou, miou = metrics.mean_iou(*metrics_input)
    acc, class_precision, class_recall = metrics.class_measurement(
        *metrics_input)
    kappa = metrics.kappa(*metrics_input)
    class_dice, mdice = metrics.dice(*metrics_input)
    infor = "[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format(
            n_images, miou, acc, kappa, mdice)
    print(infor)
    print("[EVAL] Class IoU: " + str(np.round(class_iou, 4)))
    print("[EVAL] Class Precision: " + str(
            np.round(class_precision, 4)))
    print("[EVAL] Class Recall: " + str(np.round(class_recall, 4)))

if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # n_channels=3 for RGB images.
    # n_classes (--num) must match the number of output channels the
    # checkpoint was trained with; it is also the class count used for metrics.
    net = UNet(n_channels=3, n_classes=args.num, bilinear=True)
    net.eval()
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    try:
        train_net(net=net,
                  epochs=0,
                  batch_size=args.batch_size,
                  learning_rate=0,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100,
                  amp=args.amp,
                  root_dir=args.root)
    except KeyboardInterrupt:
        # Evaluation never changes the weights, so nothing is saved here; the
        # old training-script behaviour of writing INTERRUPTED.pth could
        # overwrite a genuine interrupted-training checkpoint.
        logging.info('Evaluation interrupted')


metrics.py

import numpy as np
import torch
import sklearn.metrics as skmetrics

def mean_iou(intersect_area, pred_area, label_area):
    """
    Calculate per-class IoU and its mean.

    Args:
        intersect_area (Tensor | array-like): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor | array-like): The prediction area on all classes.
        label_area (Tensor | array-like): The ground truth area on all classes.
            Tensors may live on any device and may require grad; they are
            detached and copied to the CPU before conversion.

    Returns:
        np.ndarray: IoU on all classes (0 for a class whose union is empty).
        float: Mean IoU over all classes (0.0 when there are no classes).
    """
    # Tensor.numpy() rejects CUDA tensors and tensors that require grad, so
    # detach and move to the CPU first; float64 keeps large counts exact.
    intersect_area, pred_area, label_area = (
        np.asarray(a.detach().cpu().numpy() if isinstance(a, torch.Tensor) else a,
                   dtype=np.float64)
        for a in (intersect_area, pred_area, label_area))

    union = pred_area + label_area - intersect_area
    class_iou = np.zeros_like(union)
    # Classes with an empty union keep IoU 0 and still count toward the mean.
    np.divide(intersect_area, union, out=class_iou, where=union != 0)
    miou = np.mean(class_iou) if class_iou.size else 0.0
    return class_iou, miou

def class_measurement(intersect_area, pred_area, label_area):
    """
    Calculate accuracy, class precision and class recall.

    Args:
        intersect_area (Tensor | array-like): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor | array-like): The prediction area on all classes.
        label_area (Tensor | array-like): The ground truth area on all classes.
            Tensors may live on any device and may require grad; they are
            detached and copied to the CPU before conversion.

    Returns:
        float: Accuracy, i.e. total intersection / total prediction area
            (0.0 when the total prediction area is 0).
        np.ndarray: The precision of all classes (0 where a class was never predicted).
        np.ndarray: The recall of all classes (0 where a class never occurs in the labels).
    """
    # Tensor.numpy() rejects CUDA tensors and tensors that require grad, so
    # detach and move to the CPU first; float64 keeps large counts exact.
    intersect_area, pred_area, label_area = (
        np.asarray(a.detach().cpu().numpy() if isinstance(a, torch.Tensor) else a,
                   dtype=np.float64)
        for a in (intersect_area, pred_area, label_area))

    total_pred = np.sum(pred_area)
    # Guard the 0/0 case (nothing predicted at all) instead of returning nan.
    mean_acc = np.sum(intersect_area) / total_pred if total_pred > 0 else 0.0

    class_precision = np.zeros_like(pred_area)
    np.divide(intersect_area, pred_area, out=class_precision, where=pred_area != 0)
    class_recall = np.zeros_like(label_area)
    np.divide(intersect_area, label_area, out=class_recall, where=label_area != 0)

    return mean_acc, class_precision, class_recall

def kappa(intersect_area, pred_area, label_area):
    """
    Calculate the kappa coefficient.

    Args:
        intersect_area (Tensor | array-like): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor | array-like): The prediction area on all classes.
        label_area (Tensor | array-like): The ground truth area on all classes.
            Tensors may live on any device and may require grad; they are
            detached and copied to the CPU before conversion.

    Returns:
        float: Kappa coefficient (po - pe) / (1 - pe). Returns 0.0 when the
            total label area is 0. When pe == 1 (kappa would be 0/0), returns
            1.0 if the observed agreement po is also 1, else 0.0.
    """
    # Tensor.numpy() rejects CUDA tensors and tensors that require grad, so
    # detach and move to the CPU first; float64 avoids overflow in the products.
    intersect_area, pred_area, label_area = (
        np.asarray(a.detach().cpu().numpy() if isinstance(a, torch.Tensor) else a,
                   dtype=np.float64)
        for a in (intersect_area, pred_area, label_area))

    total_area = np.sum(label_area)
    if total_area == 0:
        return 0.0
    po = np.sum(intersect_area) / total_area
    pe = np.sum(pred_area * label_area) / (total_area * total_area)
    if pe == 1:
        # Chance agreement is already perfect, so the formula degenerates to 0/0.
        return 1.0 if po == 1 else 0.0
    return (po - pe) / (1 - pe)

def dice(intersect_area, pred_area, label_area):
    """
    Calculate per-class DICE and its mean.

    Args:
        intersect_area (Tensor | array-like): The intersection area of prediction and ground truth on all classes.
        pred_area (Tensor | array-like): The prediction area on all classes.
        label_area (Tensor | array-like): The ground truth area on all classes.
            Tensors may live on any device and may require grad; they are
            detached and copied to the CPU before conversion.

    Returns:
        np.ndarray: DICE on all classes (0 where prediction and label areas are both 0).
        float: Mean DICE over all classes (0.0 when there are no classes).
    """
    # Tensor.numpy() rejects CUDA tensors and tensors that require grad, so
    # detach and move to the CPU first; float64 keeps large counts exact.
    intersect_area, pred_area, label_area = (
        np.asarray(a.detach().cpu().numpy() if isinstance(a, torch.Tensor) else a,
                   dtype=np.float64)
        for a in (intersect_area, pred_area, label_area))

    union = pred_area + label_area
    class_dice = np.zeros_like(union)
    np.divide(2 * intersect_area, union, out=class_dice, where=union != 0)
    mdice = np.mean(class_dice) if class_dice.size else 0.0
    return class_dice, mdice

使用示例

python .\test2.py --root D:\pic\23\0403\851-1003339-H01\bend --scale 0.25 --load C:\Users\Admin\Desktop\fsdownload\checkpoint_epoch485.pth --num 3

结果展示

[EVAL] #Images: 74 mIoU: 0.5119 Acc: 0.9996 Kappa: 0.4405 Dice: 0.6002
[EVAL] Class IoU: [0.9997 0.4177 0.1183]
[EVAL] Class Precision: [0.9998 0.6767 0.1858]
[EVAL] Class Recall: [0.9998 0.5219 0.2456]

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值