Lizard Cell Segmentation from Scratch


Project Introduction

(This project runs with one click; on the first run you need to uncomment the dataset extraction code.)

The goal of this project is to build a beginner-friendly introduction to instance segmentation.

There are currently very few projects on AI Studio that build a segmentation project from scratch, or at least few that do so clearly; most rely on PaddleSeg. This article completes Lizard cell segmentation with clear, easy-to-follow code.

The dataset used in this project comes from an ICCV 2021 paper; the full citation is given at the end of the next section.

Dataset Introduction

The dataset is already mounted in this project.

The dataset contains image regions extracted from colon whole-slide images at 20x objective magnification, together with their corresponding labels.

Each label provides:

  • An instance segmentation map
  • The class of each nucleus
  • The bounding box of each nucleus
  • The centroid of each nucleus

The nucleus classes provided as part of the dataset are:

  1. Class 1: Neutrophil
  2. Class 2: Epithelial
  3. Class 3: Lymphocyte
  4. Class 4: Plasma
  5. Class 5: Eosinophil
  6. Class 6: Connective tissue

Note: this article only performs instance segmentation, so the per-class information is not used.

I'll leave that as a gap to fill; a semantic segmentation write-up will follow later.

Each label is stored as a single file. The dataset ships with sample code (read_label.py) showing how to read the labels.

If you use any part of this dataset, you must cite the following publication, which describes it in more detail:

Graham, Simon, et al. “Lizard: A Large-Scale Dataset for Colonic Nuclear Instance Segmentation and Classification.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.

Let's Get Started

First, let's be clear about the overall workflow:

  1. Data processing (the Dataset class)
  2. Model selection and construction (the Net class)
  3. Training

1 Data Processing

First, extract the data. You can see that the dataset provides example code for working with the labels:

/home/aistudio/labels/read_label.py

You can also open the following notebook to inspect the data format more clearly:

/home/aistudio/labels/read_label.ipynb

Open the files above yourself to take a look.
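
If you only want a quick peek at a label, here is a minimal sketch using scipy.io (the filename is illustrative; 'inst_map' is the field this project uses later, and .keys() lists whatever else is stored):

import scipy.io as sio

label = sio.loadmat('/home/aistudio/labels/Labels/consep_1.mat')  # illustrative filename
print(label.keys())           # list the available fields
inst_map = label['inst_map']  # H x W map: 0 = background, i = pixels of the i-th nucleus
print(inst_map.shape, inst_map.max())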

Since we want to segment the cells, training requires a segmentation label map for each original image; this is exactly an image in which the cell pixels are turned white (foreground) and everything else is black (background).

In other words, for the dataset class and network below, the input is the RGB image (3, 512, 512 after resizing) and the target is the corresponding binary mask (512, 512):

# Extract the dataset; images and labels will appear in the left-hand file panel. Run only once.
# !unzip /home/aistudio/data/data199133/Lizard.zip -d ./

import warnings

# Suppress DeprecationWarning noise
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Data preprocessing: paired transforms that keep image and instance map aligned
import os
import cv2
from PIL import Image
import numpy as np
import paddle
from paddle.io import Dataset
import matplotlib.pyplot as plt
import scipy.io as sio
import random

random_seed = 123456
paddle.seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

# Paired random rotation; the mask is resampled with NEAREST so label values stay valid
class RandomRotation:
    def __init__(self, angle_range):
        self.angle_range = angle_range

    def __call__(self, img, inst_map):
        angle = random.uniform(*self.angle_range)
        img = img.rotate(angle, resample=Image.BILINEAR)
        inst_map = Image.fromarray(inst_map)
        inst_map = inst_map.rotate(angle, resample=Image.NEAREST)
        inst_map = np.array(inst_map)
        return img, inst_map

# Paired random horizontal flip
class RandomHorizontalFlip:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, inst_map):
        if random.random() < self.p:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            inst_map = np.fliplr(inst_map)
        return img, inst_map

# Paired random scaling: BILINEAR for the image, NEAREST for the mask
class RandomScale:
    def __init__(self, scale_range):
        self.scale_range = scale_range

    def __call__(self, img, inst_map):
        scale = random.uniform(*self.scale_range)
        w, h = img.size
        new_w, new_h = int(w * scale), int(h * scale)
        img = img.resize((new_w, new_h), resample=Image.BILINEAR)
        inst_map = Image.fromarray(inst_map)
        inst_map = inst_map.resize((new_w, new_h), resample=Image.NEAREST)
        inst_map = np.array(inst_map)
        return img, inst_map

# Resize both to a fixed size (cv2.resize takes (width, height); square here, so the order is moot)
class Resize:
    def __init__(self, size=(512,512)):
        self.size = size
    
    def __call__(self, img, inst_map):
        img = img.resize(self.size, Image.BILINEAR)
        inst_map = cv2.resize(inst_map, self.size, interpolation=cv2.INTER_NEAREST)
        return img, inst_map

# Dataset class

# Split the data into train and test sets
def split_data(img_list, label_list, ratio=0.8):
    train_img = img_list[:int(len(img_list)*ratio)]
    train_label = label_list[:int(len(label_list)*ratio)]
    test_img = img_list[int(len(img_list)*ratio):]
    test_label = label_list[int(len(label_list)*ratio):]

    return train_img, train_label, test_img, test_label


# Build the dataset class
class LizardDataset(Dataset):
    def __init__(self, root: str, train: bool, transforms=None, ratio=0.8):
        super(LizardDataset, self).__init__()

        self.ratio = ratio

        data_root = root
        assert os.path.exists(
            data_root), f"path '{data_root}' does not exist."
        self.transforms = transforms
        img_names = [i for i in os.listdir(os.path.join(
            data_root, "images")) if i.endswith(".png")]

        # Sort, then shuffle with a dedicated seeded RNG so the train/test split
        # is identical every time the dataset is constructed (reusing the global
        # RNG would advance its state and make the two splits overlap)
        img_names.sort()
        random.Random(random_seed).shuffle(img_names)

        self.img_list = [os.path.join(data_root, "images", i)
                         for i in img_names]
        # label names
        self.label_list = [os.path.join(data_root, "labels", "Labels", i.split(".")[0] + ".mat")
                           for i in img_names]

        train_img, train_label, test_img, test_label = split_data(
            self.img_list, self.label_list, ratio=ratio)

        if train:
            self.img_list = train_img
            self.label_list = train_label
        else:
            self.img_list = test_img
            self.label_list = test_label

    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx]).convert('RGB')

        label = sio.loadmat(self.label_list[idx])['inst_map']

        inst_map = np.array(label)
        # Binarize: every nucleus pixel becomes 1, background stays 0
        inst_map[inst_map > 0] = 1

        # The image stays a PIL Image and the instance map a numpy array,
        # which is exactly what the paired transforms above expect

        if self.transforms is not None:
            for transform in self.transforms:
                img, inst_map = transform(img, inst_map)
        
        # Convert to numpy: CHW image for the network, HW mask as the target
        img = np.array(img)
        img = np.transpose(img, (2, 0, 1))
        inst_map = np.array(inst_map)

        return img, inst_map

    def __len__(self):
        return len(self.img_list)

if __name__ == '__main__':

    transform = [
        RandomRotation((-10, 10)),
        RandomHorizontalFlip(),
        RandomScale((0.8, 1.2)),
        Resize((512,512))
    ]

    dataset = LizardDataset(root="/home/aistudio/",
                            train=True, transforms=transform)
    img, mask = dataset[0]
    print(img.shape)
    print(mask.shape)
    plt.subplot(121)
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.subplot(122)
    plt.imshow(mask.squeeze(), cmap="gray")
    plt.show()

    print(len(dataset))


(3, 512, 512)
(512, 512)



(Figure: the sample image shown alongside its binary mask; saved as main_files/main_6_2.png)

190
train_dataset = LizardDataset(root="/home/aistudio/",train=True, transforms=transform)
val_dataset = LizardDataset(root="/home/aistudio/",train=False, transforms=transform)

train_loader = paddle.io.DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = paddle.io.DataLoader(val_dataset, batch_size=2, shuffle=True)  # smaller batch for validation

print(len(train_loader))
print(len(val_loader))
24
24
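
Note that __getitem__ returns the image as raw uint8 pixels. Here is a quick sanity check on one batch; the cast and scaling are assumptions about what a training script would do, not code from train.py:

for imgs, masks in train_loader:
    print(imgs.shape, imgs.dtype)    # [8, 3, 512, 512]
    print(masks.shape, masks.dtype)  # [8, 512, 512]
    imgs = paddle.cast(imgs, 'float32') / 255.0  # assumed scaling to [0, 1]
    break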

2 Network Construction

Here we build the network as a U-Net.

from paddle import nn

class Encoder(nn.Layer):  # Downsampling block: two convs with batch norm, then max pooling
    def __init__(self, num_channels, num_filters):
        super(Encoder,self).__init__()
        self.conv1 = nn.Conv2D(in_channels=num_channels,
                              out_channels=num_filters,
                              kernel_size=3,  # 3x3 kernel, stride 1, padding 1: spatial size [H, W] unchanged
                              stride=1,
                              padding=1)
        self.bn1   = nn.BatchNorm(num_filters,act="relu")  # batch norm with fused ReLU activation
        
        self.conv2 = nn.Conv2D(in_channels=num_filters,
                              out_channels=num_filters,
                              kernel_size=3,
                              stride=1,
                              padding=1)
        self.bn2   = nn.BatchNorm(num_filters,act="relu")
        
        self.pool  = nn.MaxPool2D(kernel_size=2,stride=2,padding="SAME")  # pooling halves the spatial size to [H/2, W/2]
        
    def forward(self,inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x_conv = x           
        x_pool = self.pool(x)
        return x_conv, x_pool
    
    
class Decoder(nn.Layer):  # Upsampling block: one transposed conv, then two convs with batch norm
    def __init__(self, num_channels, num_filters):
        super(Decoder,self).__init__()
        self.up = nn.Conv2DTranspose(in_channels=num_channels,
                                    out_channels=num_filters,
                                    kernel_size=2,
                                    stride=2,
                                    padding=0)  # doubles the spatial size to [2H, 2W]

        self.conv1 = nn.Conv2D(in_channels=num_filters*2,
                              out_channels=num_filters,
                              kernel_size=3,
                              stride=1,
                              padding=1)
        self.bn1   = nn.BatchNorm(num_filters,act="relu")
        
        self.conv2 = nn.Conv2D(in_channels=num_filters,
                              out_channels=num_filters,
                              kernel_size=3,
                              stride=1,
                              padding=1)
        self.bn2   = nn.BatchNorm(num_filters,act="relu")
        
    def forward(self,input_conv,input_pool):
        x = self.up(input_pool)
        h_diff = (input_conv.shape[2]-x.shape[2])
        w_diff = (input_conv.shape[3]-x.shape[3])
        # nn.Pad2D takes [left, right, top, bottom]: width padding first, then height
        pad = nn.Pad2D(padding=[w_diff//2, w_diff-w_diff//2, h_diff//2, h_diff-h_diff//2])
        x = pad(x)                                # pad the upsampled map to match the encoder feature map
        x = paddle.concat(x=[input_conv,x],axis=1)# skip connection: channel count doubles
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        return x
    
class UNet(nn.Layer):
    def __init__(self,num_classes=2): # num_classes=2: background and foreground
        super(UNet,self).__init__()
        self.down1 = Encoder(num_channels=  3, num_filters=64) # encoder path
        self.down2 = Encoder(num_channels= 64, num_filters=128)
        self.down3 = Encoder(num_channels=128, num_filters=256)
        self.down4 = Encoder(num_channels=256, num_filters=512)
        
        self.mid_conv1 = nn.Conv2D(512,1024,1)                 # bottleneck (1x1 convs; the original U-Net uses 3x3 here)
        self.mid_bn1   = nn.BatchNorm(1024,act="relu")
        self.mid_conv2 = nn.Conv2D(1024,1024,1)
        self.mid_bn2   = nn.BatchNorm(1024,act="relu")

        self.up4 = Decoder(1024,512)                           # decoder path
        self.up3 = Decoder(512,256)
        self.up2 = Decoder(256,128)
        self.up1 = Decoder(128,64)
        
        self.last_conv = nn.Conv2D(64,num_classes,1)           # 1x1 conv producing per-pixel class logits; softmax/argmax is applied outside the network
        
    def forward(self,inputs):
        x1, x = self.down1(inputs)
        x2, x = self.down2(x)
        x3, x = self.down3(x)
        x4, x = self.down4(x)
        
        x = self.mid_conv1(x)
        x = self.mid_bn1(x)
        x = self.mid_conv2(x)
        x = self.mid_bn2(x)
        
        x = self.up4(x4, x)
        x = self.up3(x3, x)
        x = self.up2(x2, x)
        x = self.up1(x1, x)
        
        x = self.last_conv(x)
        
        return x
paddle.summary(UNet(), (1, 3, 512, 512))
W0602 19:13:10.241631  1647 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0602 19:13:10.245633  1647 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.


---------------------------------------------------------------------------------------------------------------------
  Layer (type)                  Input Shape                              Output Shape                   Param #    
=====================================================================================================================
    Conv2D-1                 [[1, 3, 512, 512]]                       [1, 64, 512, 512]                  1,792     
   BatchNorm-1              [[1, 64, 512, 512]]                       [1, 64, 512, 512]                   256      
    Conv2D-2                [[1, 64, 512, 512]]                       [1, 64, 512, 512]                 36,928     
   BatchNorm-2              [[1, 64, 512, 512]]                       [1, 64, 512, 512]                   256      
   MaxPool2D-1              [[1, 64, 512, 512]]                       [1, 64, 256, 256]                    0       
    Encoder-1                [[1, 3, 512, 512]]             [[1, 64, 512, 512], [1, 64, 256, 256]]         0       
    Conv2D-3                [[1, 64, 256, 256]]                       [1, 128, 256, 256]                73,856     
   BatchNorm-3              [[1, 128, 256, 256]]                      [1, 128, 256, 256]                  512      
    Conv2D-4                [[1, 128, 256, 256]]                      [1, 128, 256, 256]                147,584    
   BatchNorm-4              [[1, 128, 256, 256]]                      [1, 128, 256, 256]                  512      
   MaxPool2D-2              [[1, 128, 256, 256]]                      [1, 128, 128, 128]                   0       
    Encoder-2               [[1, 64, 256, 256]]            [[1, 128, 256, 256], [1, 128, 128, 128]]        0       
    Conv2D-5                [[1, 128, 128, 128]]                      [1, 256, 128, 128]                295,168    
   BatchNorm-5              [[1, 256, 128, 128]]                      [1, 256, 128, 128]                 1,024     
    Conv2D-6                [[1, 256, 128, 128]]                      [1, 256, 128, 128]                590,080    
   BatchNorm-6              [[1, 256, 128, 128]]                      [1, 256, 128, 128]                 1,024     
   MaxPool2D-3              [[1, 256, 128, 128]]                       [1, 256, 64, 64]                    0       
    Encoder-3               [[1, 128, 128, 128]]            [[1, 256, 128, 128], [1, 256, 64, 64]]         0       
    Conv2D-7                 [[1, 256, 64, 64]]                        [1, 512, 64, 64]                1,180,160   
   BatchNorm-7               [[1, 512, 64, 64]]                        [1, 512, 64, 64]                  2,048     
    Conv2D-8                 [[1, 512, 64, 64]]                        [1, 512, 64, 64]                2,359,808   
   BatchNorm-8               [[1, 512, 64, 64]]                        [1, 512, 64, 64]                  2,048     
   MaxPool2D-4               [[1, 512, 64, 64]]                        [1, 512, 32, 32]                    0       
    Encoder-4                [[1, 256, 64, 64]]              [[1, 512, 64, 64], [1, 512, 32, 32]]          0       
    Conv2D-9                 [[1, 512, 32, 32]]                       [1, 1024, 32, 32]                 525,312    
   BatchNorm-9              [[1, 1024, 32, 32]]                       [1, 1024, 32, 32]                  4,096     
    Conv2D-10               [[1, 1024, 32, 32]]                       [1, 1024, 32, 32]                1,049,600   
  BatchNorm-10              [[1, 1024, 32, 32]]                       [1, 1024, 32, 32]                  4,096     
Conv2DTranspose-1           [[1, 1024, 32, 32]]                        [1, 512, 64, 64]                2,097,664   
    Conv2D-11               [[1, 1024, 64, 64]]                        [1, 512, 64, 64]                4,719,104   
  BatchNorm-11               [[1, 512, 64, 64]]                        [1, 512, 64, 64]                  2,048     
    Conv2D-12                [[1, 512, 64, 64]]                        [1, 512, 64, 64]                2,359,808   
  BatchNorm-12               [[1, 512, 64, 64]]                        [1, 512, 64, 64]                  2,048     
    Decoder-1      [[1, 512, 64, 64], [1, 1024, 32, 32]]               [1, 512, 64, 64]                    0       
Conv2DTranspose-2            [[1, 512, 64, 64]]                       [1, 256, 128, 128]                524,544    
    Conv2D-13               [[1, 512, 128, 128]]                      [1, 256, 128, 128]               1,179,904   
  BatchNorm-13              [[1, 256, 128, 128]]                      [1, 256, 128, 128]                 1,024     
    Conv2D-14               [[1, 256, 128, 128]]                      [1, 256, 128, 128]                590,080    
  BatchNorm-14              [[1, 256, 128, 128]]                      [1, 256, 128, 128]                 1,024     
    Decoder-2      [[1, 256, 128, 128], [1, 512, 64, 64]]             [1, 256, 128, 128]                   0       
Conv2DTranspose-3           [[1, 256, 128, 128]]                      [1, 128, 256, 256]                131,200    
    Conv2D-15               [[1, 256, 256, 256]]                      [1, 128, 256, 256]                295,040    
  BatchNorm-15              [[1, 128, 256, 256]]                      [1, 128, 256, 256]                  512      
    Conv2D-16               [[1, 128, 256, 256]]                      [1, 128, 256, 256]                147,584    
  BatchNorm-16              [[1, 128, 256, 256]]                      [1, 128, 256, 256]                  512      
    Decoder-3     [[1, 128, 256, 256], [1, 256, 128, 128]]            [1, 128, 256, 256]                   0       
Conv2DTranspose-4           [[1, 128, 256, 256]]                      [1, 64, 512, 512]                 32,832     
    Conv2D-17               [[1, 128, 512, 512]]                      [1, 64, 512, 512]                 73,792     
  BatchNorm-17              [[1, 64, 512, 512]]                       [1, 64, 512, 512]                   256      
    Conv2D-18               [[1, 64, 512, 512]]                       [1, 64, 512, 512]                 36,928     
  BatchNorm-18              [[1, 64, 512, 512]]                       [1, 64, 512, 512]                   256      
    Decoder-4     [[1, 64, 512, 512], [1, 128, 256, 256]]             [1, 64, 512, 512]                    0       
    Conv2D-19               [[1, 64, 512, 512]]                        [1, 2, 512, 512]                   130      
=====================================================================================================================
Total params: 18,472,450
Trainable params: 18,460,674
Non-trainable params: 11,776
---------------------------------------------------------------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 2796.00
Params size (MB): 70.47
Estimated Total Size (MB): 2869.47
---------------------------------------------------------------------------------------------------------------------






{'total_params': 18472450, 'trainable_params': 18460674}

3 Training

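The actual training code lives in train.py rather than in the notebook. Below is a rough sketch of what such a loop can look like, reusing the UNet, train_loader, and nn defined above; the loss (per-pixel cross-entropy), optimizer (Adam), learning rate, and input scaling are assumptions, and the real train.py may differ:

# Minimal training-loop sketch (assumed loss/optimizer; the real train.py may differ)
model = UNet(num_classes=2)
optimizer = paddle.optimizer.Adam(learning_rate=1e-3, parameters=model.parameters())
criterion = nn.CrossEntropyLoss(axis=1)  # takes [N, C, H, W] logits and [N, H, W] int64 labels

for epoch in range(50):
    model.train()
    for imgs, masks in train_loader:
        imgs = paddle.cast(imgs, 'float32') / 255.0   # assumed input scaling
        masks = paddle.cast(masks, 'int64')
        logits = model(imgs)                          # [N, 2, 512, 512]
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
    print(f"epoch {epoch}: loss {float(loss):.4f}")
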
# Tune the hyperparameters yourself
!python train.py
start set parameters...
start load dataset...
start train...
W0602 19:13:15.872524  1810 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0602 19:13:15.876902  1810 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
2023-06-02 19:13:58 [TRAIN] epoch: [0/50], loss: 0.07847500592470169
2023-06-02 19:14:08 [EVAL] epoch: [0/50], mIoU is: 0.3811, Dice is: 0.4938, loss is: 0.0734
2023-06-02 19:14:08 [EVAL] Class IoU: [0.6523 0.1099]
2023-06-02 19:15:00 [TRAIN] epoch: [1/50], loss: 0.03849150240421295
2023-06-02 19:15:13 [EVAL] epoch: [1/50], mIoU is: 0.5515, Dice is: 0.6821, loss is: 0.0445
2023-06-02 19:15:13 [EVAL] Class IoU: [0.7808 0.3222]
^C

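For reference, the mIoU, Dice, and per-class IoU reported above follow the standard confusion-matrix definitions. A minimal numpy sketch over integer label maps (the metric code actually used in train.py may differ):

import numpy as np

def segmentation_metrics(pred, target, num_classes=2):
    # Confusion matrix: rows are ground-truth classes, columns are predictions
    cm = np.bincount(target.flatten() * num_classes + pred.flatten(),
                     minlength=num_classes ** 2).reshape(num_classes, num_classes)
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp   # predicted as class c, actually something else
    fn = cm.sum(axis=1) - tp   # actually class c, predicted as something else
    iou = tp / (tp + fp + fn + 1e-10)           # per-class IoU
    dice = 2 * tp / (2 * tp + fp + fn + 1e-10)  # per-class Dice
    return iou, iou.mean(), dice.mean()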

Project Summary

This is an instance segmentation project built from scratch.

It is implemented entirely with hand-written data preprocessing, a hand-built model, and a custom training script.

About the Author

iterhui

I've reached Bronze level on AI Studio and lit 0 badges. Follow me and I'll follow back~ https://aistudio.baidu.com/aistudio/personalcenter/thirdview/643467

This article is a repost.
Original project link
