深度学习分割任务

一、网络模型

常见分割网络

1.U-Net 3d

// U-Net3d
class UNet3D(nn.Module):
    """3-D U-Net assembled from ``ConvBlock`` / ``DownBlock`` / ``UpBlock``.

    ``params`` is a dict read with these keys:
        in_chns: number of input channels.
        feature_chns: per-level channel widths; must have 4 or 5 entries.
            With 4 entries the deepest down/up pair (``down4`` / ``up1``)
            is not built.
        class_num: number of output channels of the final 3x3x3 conv.
        trilinear: forwarded to every ``UpBlock``.
        dropout: per-level dropout values, indexed like ``feature_chns``
            (one entry is read per encoder level).
        activation: 'sigmoid', or 'softmax' (over channel dim 1); any other
            value makes ``forward`` return the final conv output unchanged.
    """

    def __init__(self, params):
        super(UNet3D, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.trilinear = self.params['trilinear']
        self.dropout = self.params['dropout']
        assert (len(self.ft_chns) == 5 or len(self.ft_chns) == 4)

        # Encoder.
        self.in_conv = ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
        self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
        self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
        self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
        if len(self.ft_chns) == 5:
            self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
            self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3],
                               dropout_p=0.0, trilinear=self.trilinear)
        # Decoder (no dropout).
        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2],
                           dropout_p=0.0, trilinear=self.trilinear)
        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1],
                           dropout_p=0.0, trilinear=self.trilinear)
        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0],
                           dropout_p=0.0, trilinear=self.trilinear)

        # NOTE(review): the author noted "bias = -2.19" here; no such bias
        # initialization is actually applied.
        self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class,
                                  kernel_size=3, padding=1)

        # None (not the string 'None') means "return the conv output as is".
        if self.params['activation'] == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif self.params['activation'] == 'softmax':
            # Explicit channel axis; the implicit-dim form is deprecated.
            self.activation = nn.Softmax(dim=1)
        else:
            self.activation = None

    def forward(self, x):
        """Run the encoder/decoder and apply the configured activation."""
        x0 = self.in_conv(x)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        if len(self.ft_chns) == 5:
            x4 = self.down4(x3)
            x = self.up1(x4, x3)
        else:
            # 4-level variant: the deepest encoder output feeds up2 directly.
            x = x3
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.up4(x, x0)
        output = self.out_conv(x)

        if self.activation is None:
            return output
        return self.activation(output)

2.U-Net 2d

// parts of unet
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

""" Parts of the U-Net model """
"""https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py"""


class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 Conv2d -> BatchNorm2d -> ReLU.

    ``padding=1`` keeps height and width unchanged; the first stage maps
    ``in_channels`` to ``out_channels`` and the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # One Sequential named ``double_conv`` so state_dict keys
        # (double_conv.0.weight, double_conv.1.running_mean, ...) stay stable.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Halve the spatial size with 2x2 max pooling, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Kept as a single Sequential so state_dict keys remain maxpool_conv.1.*.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Upscale the deeper map, pad it to the skip map's size, concatenate, DoubleConv.

    Args:
        in_channels: channel count after concatenating the upsampled map with
            the skip tensor (input width of the DoubleConv).
        out_channels: output channels of the DoubleConv.
        bilinear: use parameter-free bilinear upsampling if True, otherwise a
            learned ConvTranspose2d.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            # Channel count is unchanged by upsampling; DoubleConv reduces it.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # Learned 2x upsampling; expects the deeper map to have
            # in_channels // 2 channels.
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1`` and merge it with the skip tensor ``x2`` (both NCHW)."""
        x1 = self.up(x1)
        # When an encoder input had an odd height/width, MaxPool2d(2) floored
        # it, so the upsampled map can be smaller than the skip map. Pad x1
        # (plain ints, no tensor allocation) to match x2 exactly.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to output (class) channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        result = self.conv(x)
        return result
// full unet
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from model.parts_unet import *

""" Full assembly of the parts to form the complete network """
"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""


class UNet(nn.Module):
    """2-D U-Net: DoubleConv stem, four Down stages, four Up stages with skip
    connections, and a 1x1 OutConv. No activation is applied to the output.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output channels.
        bilinear: forwarded to every ``Up`` block.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder widths: 64 -> 128 -> 256 -> 512 -> 512.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # Decoder: first argument is the width after concatenating the skip map.
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Stack of encoder outputs; the last one is the bottleneck.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        x = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, skips.pop())

        logits = self.outc(x)
        return logits


if __name__ == '__main__':
    # Build a 3-channel-input, 1-channel-output U-Net and print its layers.
    model = UNet(n_channels=3, n_classes=1)
    print(model)

二、损失函数

1.Focal loss

def FocalLoss(predict, soft_y, softmax, alpha_num=0.01, gamma=2):
    """Binary focal loss computed on probabilities.

    Args:
        predict: predicted probabilities in [0, 1] (same input contract as
            ``binary_cross_entropy``), any shape.
        soft_y: targets with the same shape as ``predict``; elements exactly
            equal to 1 are treated as positives.
        softmax: unused; kept so existing call sites keep working.
        alpha_num: weight for positive elements; negatives get ``1 - alpha_num``.
        gamma: focusing exponent applied to ``(1 - p_t)``.

    Returns:
        Scalar tensor: mean focal loss over all elements.
    """
    with torch.no_grad():
        alpha = torch.full_like(predict, 1 - alpha_num)
        alpha[soft_y == 1] = alpha_num

    ce_loss = nn.functional.binary_cross_entropy(
        predict.float(), soft_y.float(), reduction='none')
    # For BCE, exp(-ce) is the probability assigned to the true label (p_t).
    pt = torch.exp(-ce_loss)

    loss = alpha * torch.pow(1 - pt, gamma) * ce_loss
    return loss.mean()

2.Hausdorff loss

def hd_loss(seg, gt, seg_dtm, gt_dtm, alpha=2):
    """
    Compute a Hausdorff-distance loss for binary segmentation from distance transforms.

    :param seg: segmentation results, shape (b, c, d, h, w) or (b, c, h, w); c = 1 for binary
    :param gt: ground truth, same shape as ``seg``
    :param seg_dtm: distance transform of the segmentation, same shape as ``seg``
    :param gt_dtm: distance transform of the ground truth, same shape as ``seg``
    :param alpha: exponent applied to both distance transforms (default 2)
    :return: scalar tensor, mean of (seg_dtm**alpha + gt_dtm**alpha) * (seg - gt)**2
    """
    delta = (seg - gt) ** 2
    dtm = seg_dtm ** alpha + gt_dtm ** alpha
    # Elementwise product: equivalent to einsum('bcxyz,bcxyz->bcxyz') for 5-D
    # inputs, and it also accepts 4-D (2-D image) inputs.
    multiple_d = dtm * delta
    return multiple_d.mean()

3.Dice loss

def dice_loss(predict, soft_y, softmax=False):
    """Soft Dice loss, averaged over the batch.

    Args:
        predict: predictions with batch on dim 0; all other dims are flattened.
        soft_y: targets with the same number of elements per sample as ``predict``.
        softmax: unused; kept so existing call sites keep working.

    Returns:
        Scalar tensor ``1 - mean(dice)``, where per sample
        dice = (2 * sum(p * y) + smooth) / (sum(p) + sum(y) + smooth).
    """
    smooth = 1e-5
    num = predict.size(0)
    # reshape (not view) so non-contiguous inputs, e.g. after permute/transpose, work.
    p_vol = predict.reshape(num, -1)
    y_vol = soft_y.reshape(num, -1)
    intersection = (p_vol * y_vol).sum(1)

    dice_score = (2. * intersection + smooth) / (p_vol.sum(1) + y_vol.sum(1) + smooth)
    return 1 - dice_score.mean()

三、训练一个U-Net

1.数据加载
2.模型选择,这里为U-Net
3.损失函数
4.训练
5.预测

//dataset
# !/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
from torch.utils.data import Dataset
import glob
import cv2
import random
import os



'''data loader'''
# inherit class Dataset():
class DataLoader(Dataset):
    """Image / label pair dataset for 2-D segmentation.

    Images are ``<data_path>/image/*png``; the label for each image is the
    file with the same name in ``<data_path>/label/``.
    """

    def __init__(self, data_path):
        # Collect the image file paths once; labels are derived per item.
        self.data_path = data_path
        self.imgs_path = glob.glob(os.path.join(data_path, 'image/*png'))

    def augment(self, image, flipCode):
        """Flip an HxW or HxWxC array: 1 horizontal, 0 vertical, -1 both."""
        return cv2.flip(image, flipCode)

    @staticmethod
    def _read(path, flags):
        """Read an image with OpenCV, raising instead of returning None."""
        img = cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError('cannot read image: {}'.format(path))
        return img

    def __getitem__(self, index):
        """Return ``(image, label)``: image (3, H, W), label (1, H, W) in {0, 1}."""
        image_path = self.imgs_path[index]
        # Replace only the directory, not every 'image' substring in the path.
        label_path = os.path.join(self.data_path, 'label', os.path.basename(image_path))

        # Colour image (H, W, 3) for the 3-channel UNet; label as a
        # single-channel (H, W) mask.
        image = self._read(image_path, cv2.IMREAD_COLOR)
        label = self._read(label_path, cv2.IMREAD_GRAYSCALE)

        # Random flip (flipCode 2 = no flip). Done in OpenCV's HxW(xC) layout:
        # flipping a (1, H, W) array would flip the wrong axes.
        flipCode = random.choice([-1, 0, 1, 2])
        if flipCode != 2:
            image = self.augment(image, flipCode)
            label = self.augment(label, flipCode)

        # HWC -> CHW for the image; add a channel axis to the label.
        image = image.transpose(2, 0, 1).copy()
        label = label.reshape(1, label.shape[0], label.shape[1])

        # 0/255 masks -> 0/1.
        if label.max() > 1:
            label = label / 255
        return image, label

    def __len__(self):
        # Number of samples in the training set.
        return len(self.imgs_path)

if __name__ == '__main__':
    liunj_dataset = DataLoader('data/train/')
    print('number of train data:', len(liunj_dataset))

    # Pass the dataset instance (not the DataLoader class) to torch's loader.
    train_loader = torch.utils.data.DataLoader(dataset=liunj_dataset,
                                               batch_size=2,
                                               shuffle=True)
    for image, label in train_loader:
        print(image.shape)
//train_model
# -*- coding: utf-8 -*-
from model.full_unet import UNet
from dataset import DataLoader
from torch import optim
import torch.nn as nn
import torch
import time
import matplotlib.pyplot as plt


def train(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
    """Train ``net`` with BCEWithLogitsLoss and RMSprop on the data in ``data_path``.

    The state_dict is written to 'best_model.pth' whenever a batch loss is the
    lowest seen so far. It is saved before the optimizer step, i.e. the saved
    weights are the ones that produced that loss.

    Args:
        net: model returning raw logits with the same shape as the labels;
            it must already be on ``device``.
        device: torch device for images and labels.
        data_path: directory passed to the project ``DataLoader`` dataset.
        epochs, batch_size, lr: training hyper-parameters.

    Returns:
        (step_losses, epoch_losses): per-batch losses and per-epoch mean losses
        as Python floats (NaN for an epoch with no batches).
    """
    start = time.time()

    liunj_dataset = DataLoader(data_path)
    train_loader = torch.utils.data.DataLoader(dataset=liunj_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)

    # BCEWithLogitsLoss applies the sigmoid itself, so net must output logits.
    criterion = nn.BCEWithLogitsLoss()

    # Best batch loss so far, kept as a float (not a tensor) so no autograd
    # graph stays referenced between steps.
    best_loss = float('inf')
    step_losses = []
    epoch_losses = []

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 50)

        net.train()
        running_loss = 0.0
        n_batches = 0

        for image, label in train_loader:
            optimizer.zero_grad()

            image = image.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=torch.float32)

            pred = net(image)
            loss = criterion(pred, label)
            loss_value = loss.item()
            print('Loss/train', loss_value)

            if loss_value < best_loss:
                best_loss = loss_value
                torch.save(net.state_dict(), 'best_model.pth')

            loss.backward()
            optimizer.step()

            step_losses.append(loss_value)
            running_loss += loss_value
            n_batches += 1

        epoch_losses.append(running_loss / n_batches if n_batches else float('nan'))

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    return step_losses, epoch_losses


# def acc_metric(predb, yb):
#     return (predb.argmax(dim=1) == yb.cuda()).float().mean()


if __name__ == "__main__":

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # One output channel: the labels are single-channel masks, and
    # BCEWithLogitsLoss needs logits of the same shape as the labels.
    net = UNet(n_channels=3, n_classes=1)
    net.to(device=device)

    data_path = 'data/train/'

    step_losses, epoch_losses = train(net, device, data_path)

    # Only training losses are recorded (there is no validation loop);
    # TensorBoard is an alternative way to visualise training.
    plt.figure(figsize=(12, 4))
    plt.subplot(121)
    plt.plot(step_losses)
    plt.title("train_loss")
    plt.subplot(122)
    plt.plot(epoch_losses, '-o', label="train_loss")
    plt.title("epochs_loss")
    plt.legend()
    plt.show()
# -*- coding: utf-8 -*-
import glob
import numpy as np
import torch
import cv2
from model.full_unet import UNet


if __name__ == '__main__':
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 3-channel input and a single logit channel, matching BCEWithLogitsLoss
    # training on single-channel masks; must match the saved checkpoint.
    net = UNet(n_channels=3, n_classes=1)
    net.to(device=device)

    net.load_state_dict(torch.load('best_model.pth', map_location=device))

    # Inference mode (BatchNorm uses running statistics).
    net.eval()

    # All test images, skipping '_res.png' outputs written by earlier runs.
    test_paths = [p for p in glob.glob('data/test/*.png') if not p.endswith('_res.png')]

    with torch.no_grad():
        for test_path in test_paths:
            # splitext only strips the extension, even if the path has other dots.
            save_res_path = os.path.splitext(test_path)[0] + '_res.png'

            img = cv2.imread(test_path)  # (H, W, 3) BGR
            if img is None:
                print('skipping unreadable image:', test_path)
                continue
            # HWC -> NCHW with a batch of one.
            img = np.ascontiguousarray(img.transpose(2, 0, 1))[np.newaxis]
            img_tensor = torch.from_numpy(img).to(device=device, dtype=torch.float32)

            pred = net(img_tensor)

            # The network outputs logits: convert to probabilities, then
            # binarize at 0.5 into an 8-bit 0/255 mask.
            prob = torch.sigmoid(pred)[0, 0].cpu().numpy()
            mask = np.where(prob >= 0.5, 255, 0).astype(np.uint8)

            cv2.imwrite(save_res_path, mask)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Naijiaaa

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值