UNet Multi-Class Segmentation

First, build the dataset code by creating a data.py:

import torch
import torch.utils.data as data

import cv2
import numpy as np
import os


# def randomHueSaturationValue(image, hue_shift_limit=(-180, 180),
#                              sat_shift_limit=(-255, 255),
#                              val_shift_limit=(-255, 255), u=0.5):
#     if np.random.random() < u:
#         image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
#         h, s, v = cv2.split(image)
#         hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1] + 1)
#         hue_shift = np.uint8(hue_shift)
#         h += hue_shift
#         sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
#         s = cv2.add(s, sat_shift)
#         val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
#         v = cv2.add(v, val_shift)
#         image = cv2.merge((h, s, v))
#         # image = cv2.merge((s, v))
#         image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
#
#     return image
#
#
# def randomShiftScaleRotate(image, mask,
#                            shift_limit=(-0.0, 0.0),
#                            scale_limit=(-0.0, 0.0),
#                            rotate_limit=(-0.0, 0.0),
#                            aspect_limit=(-0.0, 0.0),
#                            borderMode=cv2.BORDER_CONSTANT, u=0.5):
#     if np.random.random() < u:
#         height, width, channel = image.shape
#
#         angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
#         scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
#         aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
#         sx = scale * aspect / (aspect ** 0.5)
#         sy = scale / (aspect ** 0.5)
#         dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
#         dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)
#
#         cc = np.math.cos(angle / 180 * np.math.pi) * sx
#         ss = np.math.sin(angle / 180 * np.math.pi) * sy
#         rotate_matrix = np.array([[cc, -ss], [ss, cc]])
#
#         box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
#         box1 = box0 - np.array([width / 2, height / 2])
#         box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])
#
#         box0 = box0.astype(np.float32)
#         box1 = box1.astype(np.float32)
#         mat = cv2.getPerspectiveTransform(box0, box1)
#         image = cv2.warpPerspective(image, mat, (width, height), flags=cv2.INTER_LINEAR, borderMode=borderMode,
#                                     borderValue=(
#                                         0, 0,
#                                         0,))
#         mask = cv2.warpPerspective(mask, mat, (width, height), flags=cv2.INTER_LINEAR, borderMode=borderMode,
#                                    borderValue=(
#                                        0, 0,
#                                        0,))
#
#     return image, mask
#
#
# def randomHorizontalFlip(image, mask, u=0.5):
#     if np.random.random() < u:
#         image = cv2.flip(image, 1)
#         mask = cv2.flip(mask, 1)
#
#     return image, mask
#
#
# def randomVerticleFlip(image, mask, u=0.5):
#     if np.random.random() < u:
#         image = cv2.flip(image, 0)
#         mask = cv2.flip(mask, 0)
#
#     return image, mask
#
#
# def randomRotate90(image, mask, u=0.5):
#     if np.random.random() < u:
#         image = np.rot90(image)
#         mask = np.rot90(mask)
#
#     return image, mask
#
#
# def default_loader(img_path, mask_path):
#     img = cv2.imread(img_path)
#     # print("img:{}".format(np.shape(img)))
#     img = cv2.resize(img, (448, 448))
#
#     mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
#
#     mask = 255. - cv2.resize(mask, (448, 448))
#
#     img = randomHueSaturationValue(img,
#                                    hue_shift_limit=(-30, 30),
#                                    sat_shift_limit=(-5, 5),
#                                    val_shift_limit=(-15, 15))
#
#     img, mask = randomShiftScaleRotate(img, mask,
#                                        shift_limit=(-0.1, 0.1),
#                                        scale_limit=(-0.1, 0.1),
#                                        aspect_limit=(-0.1, 0.1),
#                                        rotate_limit=(-0, 0))
#     img, mask = randomHorizontalFlip(img, mask)
#     img, mask = randomVerticleFlip(img, mask)
#     img, mask = randomRotate90(img, mask)
#
#     mask = np.expand_dims(mask, axis=2)
#     #
#     # print(np.shape(img))
#     # print(np.shape(mask))
#
#     img = np.array(img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
#     mask = np.array(mask, np.float32).transpose(2, 0, 1) / 255.0
#     mask[mask >= 0.5] = 1
#     mask[mask <= 0.5] = 0
#     # mask = abs(mask-1)
#     return img, mask


def read_own_data(root_path, mode):
    # Collect (image, mask) path pairs; expects <root>/<mode>/images and <root>/<mode>/mask
    # with identical file names in both folders
    images = []
    masks = []

    image_root = os.path.join(root_path, mode, 'images')
    gt_root = os.path.join(root_path, mode, 'mask')

    for image_name in os.listdir(gt_root):
        image_path = os.path.join(image_root, image_name)
        label_path = os.path.join(gt_root, image_name)

        images.append(image_path)
        masks.append(label_path)

    return images, masks


def own_data_loader(img_path, mask_path):
    img = cv2.imread(img_path)
    img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
    mask = cv2.imread(mask_path, 0)
    mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)

    # img = randomHueSaturationValue(img,
    #                                hue_shift_limit=(-30, 30),
    #                                sat_shift_limit=(-5, 5),
    #                                val_shift_limit=(-15, 15))
    #
    # img, mask = randomShiftScaleRotate(img, mask,
    #                                    shift_limit=(-0.1, 0.1),
    #                                    scale_limit=(-0.1, 0.1),
    #                                    aspect_limit=(-0.1, 0.1),
    #                                    rotate_limit=(-0, 0))
    # img, mask = randomHorizontalFlip(img, mask)
    # img, mask = randomVerticleFlip(img, mask)
    # img, mask = randomRotate90(img, mask)

    mask = np.expand_dims(mask, axis=2)

    # Normalize the image to roughly [-1.6, 1.6] and move channels to the front (C, H, W)
    img = (np.array(img, np.float32) / 255.0 * 3.2 - 1.6).transpose(2, 0, 1)
    mask = np.array(mask, np.float32).transpose(2, 0, 1)
    return img, mask


def own_data_test_loader(img_path, mask_path):
    img = cv2.imread(img_path)
    img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
    mask = cv2.imread(mask_path, 0)
    mask = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST)
    mask = np.expand_dims(mask, axis=2)

    img = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
    mask = np.array(mask, np.float32)
    # mask[mask >= 0.5] = 1
    # mask[mask < 0.5] = 0

    img = img.transpose(2, 0, 1)
    mask = mask.transpose(2, 0, 1)

    return img, mask


class ImageFolder(data.Dataset):

    def __init__(self, root_path, mode='train'):
        self.root = root_path
        self.mode = mode
        self.images, self.labels = read_own_data(self.root, self.mode)

    def __getitem__(self, index):
        if self.mode == 'test':
            img, mask = own_data_test_loader(self.images[index], self.labels[index])
        else:
            img, mask = own_data_loader(self.images[index], self.labels[index])
        img = torch.Tensor(img)
        mask = torch.Tensor(mask)
        return img, mask

    def __len__(self):
        assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)
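
For reference, read_own_data above expects a directory layout like the following (the train/val/test folder names come from the mode argument; the image file names here are hypothetical):

data_root/
    train/
        images/
            0001.png
        mask/
            0001.png
    val/
        images/
        mask/
    test/
        images/
        mask/

Each mask must have the same file name as its image, and each mask pixel stores a class index (0, 1, 2, ...).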

Next is the training script. Beyond a normal PyTorch environment, you only need to install two extra packages, segmentation_models_pytorch and pytorch_toolbelt. In train.py, set n_classes to your own number of classes; for example, my labels are 0, 1, 2, 3, so I set it to 4.
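
Both packages are published on PyPI, so a plain pip install should work (package names as published there):

pip install segmentation-models-pytorch pytorch_toolbelt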

import time
import os
import warnings
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss, SoftCrossEntropyLoss, LovaszLoss
from pytorch_toolbelt import losses as L
from data import ImageFolder
from sklearn import metrics

warnings.filterwarnings('ignore')
torch.backends.cudnn.enabled = True

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
n_classes = 4


def cal_cm(y_true, y_pred):
    y_true = y_true.reshape(1, -1).squeeze()
    y_pred = y_pred.reshape(1, -1).squeeze()
    cm = metrics.confusion_matrix(y_true, y_pred)
    return cm


def iou_mean(pred, target, n_classes=n_classes):
    # n_classes: the number of classes in your dataset, including background (class 0),
    # which is excluded from the average below
    # pred and target are label maps (argmax output / ground truth), not probability maps
    ious = []
    iousSum = 0
    pred = pred.view(-1)
    target = target.view(-1).to(pred.device)

    # Ignore IoU for background class ("0")
    for cls in range(1, n_classes):  # classes 1 .. n_classes-1 -> class "0" is ignored
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds[target_inds]).long().sum().data.cpu().item()  # Cast to long to prevent overflows
        union = pred_inds.long().sum().data.cpu().item() + target_inds.long().sum().data.cpu().item() - intersection
        if union == 0:
            ious.append(float('nan'))  # If there is no ground truth, do not include in evaluation
        else:
            ious.append(float(intersection) / float(max(union, 1)))
            iousSum += float(intersection) / float(max(union, 1))
    return iousSum / (n_classes - 1)  # average over the n_classes - 1 foreground classes


def multi_acc(pred, label):
    probs = torch.log_softmax(pred, dim=1)
    _, tags = torch.max(probs, dim=1)
    corrects = torch.eq(tags, label).int()
    acc = corrects.sum() / corrects.numel()
    return acc


def train(EPOCHES, BATCH_SIZE, data_root, channels, optimizer_name,
          model_path, swa_model_path, loss, early_stop):
    train_dataset = ImageFolder(data_root, mode='train')
    # print(len(train_dataset))
    val_dataset = ImageFolder(data_root, mode='val')

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0)

    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0)

    # Define the model, optimizer, and loss function
    # model = smp.UnetPlusPlus(
    #         encoder_name="efficientnet-b7",
    #         encoder_weights="imagenet",
    #         in_channels=channels,
    #         classes=17,
    # )
    # model = smp.UnetPlusPlus(
    #         encoder_name="timm-resnest101e",
    #         encoder_weights="imagenet",
    #         in_channels=channels,
    #         classes=2,
    # )

    model = smp.Unet(
        encoder_name="resnext50_32x4d",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7, resnet34
        encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
        in_channels=channels,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=n_classes,  # model output channels (number of classes in your dataset)
        # Note: the smp losses used below expect raw logits and apply (log-)softmax
        # internally, so activation=None is often the safer choice here
        activation='softmax',  # for binary segmentation switch this to 'sigmoid'
    )

    model.to(DEVICE)
    # To load pretrained weights, uncomment the next line and set model_path to the weights file
    # model.load_state_dict(torch.load(model_path))
    if (optimizer_name == "sgd"):
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=1e-4, weight_decay=1e-3, momentum=0.9)
    else:
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=1e-3, weight_decay=1e-3)
    # Cosine annealing with warm restarts
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=2,  # T_0 is the number of epochs before the first restart
        T_mult=2,  # restart period multiplier: after each restart, T_0 = T_0 * T_mult
        eta_min=1e-5  # minimum learning rate
    )

    if (loss == "SoftCE_dice"):  # mode: Loss mode 'binary', 'multiclass' or 'multilabel'
        # Loss: SoftCrossEntropyLoss + DiceLoss
        # Dice loss mitigates class imbalance to some extent, but training can be unstable
        # DiceLoss_fn = DiceLoss(mode='binary')
        DiceLoss_fn = DiceLoss(mode='multiclass')  # use 'multiclass' for multi-class segmentation
        # Bceloss_fn = nn.BCELoss()
        # Soft cross entropy, i.e. cross entropy with label smoothing, improves generalization
        SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)  # for multi-class
        loss_fn = L.JointLoss(first=DiceLoss_fn, second=SoftCrossEntropy_fn, first_weight=0.8, second_weight=0.2).to(DEVICE)
    # loss_fn = smp.utils.losses.DiceLoss()
    else:
        # Loss: SoftCrossEntropyLoss + LovaszLoss
        # LovaszLoss directly optimizes mIoU via the convex Lovasz extension of submodular losses
        # LovaszLoss_fn = LovaszLoss(mode='binary')
        LovaszLoss_fn = LovaszLoss(mode='multiclass')
        # Soft cross entropy, i.e. cross entropy with label smoothing, improves generalization
        SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)  # left as the multi-class version; adjust if needed
        loss_fn = L.JointLoss(first=LovaszLoss_fn, second=SoftCrossEntropy_fn,
                              first_weight=0.5, second_weight=0.5).to(DEVICE)

    best_miou = 0
    best_miou_epoch = 0
    train_loss_epochs, val_mIoU_epochs, lr_epochs = [], [], []
    for epoch in range(1, EPOCHES + 1):
        losses = []
        start_time = time.time()
        model.train()
        for image, target in tqdm(train_data_loader, ncols=20, total=len(train_data_loader)):
            image, target = image.to(DEVICE), target.to(DEVICE)
            output = model(image)

            target = target.long()  # the smp losses expect integer class indices
            loss = loss_fn(output, target)
            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()

        val_acc = []
        val_iou = []
        model.eval()  # switch off dropout/batch-norm updates for validation
        with torch.no_grad():
            for val_img, val_mask in tqdm(val_data_loader, ncols=20, total=len(val_data_loader)):
                val_img, label = val_img.to(DEVICE), val_mask.to(DEVICE)
                predict = model(val_img)
                label = label.squeeze(1)

                acc = multi_acc(predict, label)
                val_acc.append(acc.item())

                predict = torch.argmax(predict, axis=1)
                iou = iou_mean(predict, label, n_classes)
                val_iou.append(iou)

        train_loss_epochs.append(np.array(losses).mean())
        val_mIoU_epochs.append(np.mean(val_iou))
        lr_epochs.append(optimizer.param_groups[0]['lr'])

        print('Epoch: {} Loss: {:.4f} Val_Acc: {:.4f} Val_IOU: {:.4f} Time_use: {:.2f} min'.format(
            epoch, np.array(losses).mean(), np.array(val_acc).mean(), np.mean(val_iou),
            (time.time() - start_time) / 60.0))

        if best_miou < np.mean(val_iou):
            best_miou = np.mean(val_iou)
            best_miou_epoch = epoch
            torch.save(model.state_dict(), model_path)
            print("  valid mIoU is improved. the model is saved.")
        else:
            print("")
            if (epoch - best_miou_epoch) >= early_stop:
                break

    return train_loss_epochs, val_mIoU_epochs, lr_epochs


if __name__ == '__main__':
    EPOCHES = 200
    BATCH_SIZE = 2
    loss = "SoftCE_dice"
    # loss = "SoftCE_Lovasz"
    channels = 3
    optimizer_name = "adamw"

    data_root = "F:/CT_lung_seg_and_class/seg_data/CC-CCI_clean/"
    os.makedirs("./weights", exist_ok=True)  # the checkpoint folder must exist before saving
    model_path = "./weights/CamVid_" + loss + '.pth'
    swa_model_path = model_path + "_swa.pth"
    early_stop = 400
    train_loss_epochs, val_mIoU_epochs, lr_epochs = train(EPOCHES, BATCH_SIZE, data_root, channels, optimizer_name,
                                                          model_path, swa_model_path, loss, early_stop)

    import matplotlib.pyplot as plt

    epochs = range(1, len(train_loss_epochs) + 1)
    plt.plot(epochs, train_loss_epochs, 'r', label='train loss')
    plt.plot(epochs, val_mIoU_epochs, 'b', label='val mIoU')
    plt.title('train loss and val mIoU')
    plt.legend()
    plt.savefig("train loss and val mIoU.png", dpi=300)
    plt.figure()
    plt.plot(epochs, lr_epochs, 'r', label='learning rate')
    plt.title('learning rate')
    plt.legend()
    plt.savefig("learning rate.png", dpi=300)
    plt.show()
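
Before launching a run it is worth confirming that the masks really contain only the integer class indices 0 .. n_classes-1, since the losses index directly into the output channels. A minimal check (the mask path is a placeholder):

import cv2
import numpy as np

mask = cv2.imread("path/to/your/mask.png", 0)  # hypothetical mask path
print(np.unique(mask))  # for n_classes = 4 this should print a subset of [0 1 2 3]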

 

Finally, the prediction script. Note that for training you also need to create a weights folder to store the .pth files.

import os
import cv2
import numpy as np
import torch
import segmentation_models_pytorch as smp
from torch.optim.swa_utils import AveragedModel

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'


def test_1(channels, model_path, output_dir, test_path):
    # model = smp.UnetPlusPlus(
    #         encoder_name="resnet101",
    #         encoder_weights="imagenet",
    #         in_channels=4,
    #         classes=10,
    # )
    # model = smp.DeepLabV3Plus(
    #         encoder_name="resnet101",
    #         encoder_weights="imagenet",
    #         in_channels=in_channels,
    #         classes=1,
    # )

    model = smp.Unet(
        encoder_name="resnext50_32x4d",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
        in_channels=channels,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=4,  # model output channels; must match the n_classes used in training (4 in this example)
        activation='softmax',
    )

    # If the weights come from SWA
    if ("swa" in model_path):
        model = AveragedModel(model)
    model.to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    os.makedirs(output_dir, exist_ok=True)
    im_names = os.listdir(test_path)
    for name in im_names:
        full_path = os.path.join(test_path, name)
        img = cv2.imread(full_path)
        h, w, c = img.shape
        # Resize because the training inputs were resized to 512; the output is restored to the original size below
        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
        image = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
        image = image.transpose(2, 0, 1)
        image = np.expand_dims(image, axis=0)
        image = torch.Tensor(image).to(DEVICE)
        with torch.no_grad():
            output = model(image)
        output = torch.argmax(output, axis=1).cpu().data.numpy()
        output = output.squeeze().astype(np.uint8)  # cv2.resize does not accept int64 label maps
        output = cv2.resize(output, (w, h), interpolation=cv2.INTER_NEAREST)
        save_path = os.path.join(output_dir, name)
        cv2.imwrite(save_path, output)


if __name__ == "__main__":
    data_root = "./data/CamVid/test/"
    model_path = "./weights/CamVid_SoftCE_dice.pth"
    output_dir = './data/CamVid/test_pre/'

    test_1(3, model_path, output_dir, data_root)
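
The saved predictions store raw class indices (0 to 3 for four classes), so they look almost black in an image viewer. A small sketch for colorizing them, with an arbitrary made-up palette and file name:

import cv2
import numpy as np

# Hypothetical palette: one BGR color per class index (row 0 = background)
palette = np.array([[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0]], dtype=np.uint8)

mask = cv2.imread("./data/CamVid/test_pre/example.png", 0)  # placeholder file name
color = palette[mask]  # index the palette with the label map -> (H, W, 3) BGR image
cv2.imwrite("example_color.png", color)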

UNet is a widely used deep learning model for semantic image segmentation. In semantic segmentation the goal is to classify every pixel in an image and assign it to a category; unlike ordinary image classification, the model must make pixel-level decisions.

The core idea of UNet is an encoder-decoder design that captures features at multiple scales. In the encoding stage, stacked convolution and pooling layers extract low-level features and gradually turn them into high-level semantic features. In the decoding stage, up-convolutions and skip connections rebuild the feature maps so that the per-pixel classification can be refined.

To use UNet for multi-class segmentation, a few adjustments are needed. Typically the number of output channels is set to the number of classes plus one, where one channel handles the background class and the remaining channels handle the other categories. The model can then assign every pixel in the image to one of multiple classes.

Training such a model usually relies on a cross-entropy loss to measure the gap between the model's output and the ground-truth labels; techniques such as data augmentation, transfer learning, and model ensembling can further improve performance.

In short, UNet is a deep learning model well suited to semantic segmentation and, with the adjustments and training described above, supports accurate and effective pixel-level classification over multiple categories.
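
As a concrete illustration of the shape convention this implies, here is a minimal sketch (with made-up sizes): the network emits one logit map per class, and a per-pixel cross entropy compares it against an integer label map:

import torch
import torch.nn.functional as F

n_classes = 4
logits = torch.randn(2, n_classes, 512, 512)         # (batch, classes, H, W) raw network output
labels = torch.randint(0, n_classes, (2, 512, 512))  # (batch, H, W) integer class indices
loss = F.cross_entropy(logits, labels)               # per-pixel multi-class cross entropy
print(loss.item())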
