Cellpose论文和项目结构初探

注:本篇博客是笔者在之前做细胞分割工作时学习参考cellpose论文及代码所作笔记,内容不够深入,但可以带领读者大致了解cellpose论文的工作亮点以及项目结构,和代码中部分接口的实现,适合小白入门。如有不准确的地方还请在评论区或私信交流~

目录

Cellpose论文学习

技术背景(需求来源)

模型设计

预测过程

网络设计

训练集介绍

数据来源

可视化

​Cellpose代码复现

项目代码解读


论文地址:cellpose论文(点击download下载)

项目源码地址:cellpose项目源码

Cellpose论文学习

技术背景(需求来源)

现有的方法专对训练集庞大的图像有用,此文提出的是一个通用性更强的的、基于深度学习细胞图像分割方法——CellPose。Cellpose的训练集包含了超过70000种被分割的实体。

遇到的问题:细胞紧密连接时较难分离,而现有方法会在灵活性和自动化上make trade-off. 现有的方法:手工标注数据->用户自定义pipeline->全自动化,自动更新参数。缺陷:数据集受限

工作基于Data Science Bowl的approach获得并手动分割大量带荧光标记的数据集(以前的方法在此数据集表现不佳,适合开发一个具有更好表达能力的新模型,即cellpose)

模型设计

传统方法的思路:基于分水岭,使用图像灰度值绘制拓扑图,这在分割对象具有从中心平滑衰减的强度剖面时是能工作的,但由于荧光标记物排除了核以及沿细胞边界的突起不均匀分布,许多类型的细胞形成多个强度盆地,cellpose致力于解决这个问题

cellpose的思路:基于模拟扩散生成拓扑图(使用mask),训练一个神经网络来确定拓扑图的水平和垂直梯度和二值图像(此二值图像可以确定像素是否在ROI中)

预测过程

预测时,预测的梯度向量场被用来构建一个具有不动点的动力系统,其吸引域代表预测的mask。无形中,每个像素'跟踪梯度'到它们的最终固定点。所有收敛到相同固定点的像素被分配到相同的mask

如何跟踪梯度?

  • 首先,预测水平和垂直梯度后,形成向量场并通过梯度追踪这个场,通过这样的方法可以把一个给定的细胞的所有像素路由到其中心

  • 随后,将收敛到同一点的像素分组,恢复单个细胞及其精确形状

训练一个神经网络来预测水平和垂直梯度,以及一个像素是否属于任何单元格。将三个预测图组合成一个梯度向量场👇

网络设计

U-Net:用于预测空间梯度,非常对称的一个网络,具有相同大小的层之间的跳跃连接和全局跳跃连接,从最低分辨率计算的图像样式到所有连续的计算

 

cellpose工作亮点:

  1. 在上采样阶段,U-Net会“混合”(通过特征级联)下采样通道得到的卷积图。通过直接求和取代了特征级联,减少了参数数量;

  2. 将U-Net的标准构建块换成残差块,性能更好,并将网络深度增加了一倍;

  3. 在最小的卷积图上使用全局平均池化来获得图像的"style"表示

以上工作均显著提高了性能。

训练集介绍

数据来源

  • 互联网搜索关键词得到("细胞质" , "细胞显微镜" , "荧光细胞"等),该数据集主要由荧光标记的蛋白质组成,这些蛋白质定位于细胞质中,在一个单独的通道中

  • 还包括来自明场显微镜( n = 50)的细胞图像和膜标记细胞( n = 58)的图像。

  • 包括一组来自其他类型显微镜( n = 86)的小幅图像,以及一组包含大量重复物体如水果、岩石和水母( n = 98)的小幅非显微镜图像

目标:将这些图像包含在训练集中将使网络能够更广泛和更健壮地泛化。

可视化

通过将t分布随机近邻嵌入( t-SNE )应用于神经网络学习到的图像风格来可视化该数据集的结构,并允许用户贡献自己的数据集。

 

在608幅图像的不同数据集中,100幅图像的子集被预分割为细胞图像库的一部分。这些双通道图像(细胞质和细胞核)来自单一的实验准备,包含形状复杂的细胞。基于其一般性,此数据集很适合作表达能力的benchmark。

下面是cellpose对36张测试图像的分割结果:

 与Stardist和Mask rcnn做对比:

Cellpose代码复现

将源码git clone到本地,载入pycharm:

项目代码解读

  1. ./cellpose/models.py

主要类:

  • Cellpose():用于combine Sizemodel和CellposeModel
  • CellposeModel(UnetModel):即基于UnetModel的Cellpose模型,用于处理输入图片
  • SizeModel:一种线性回归模型,在输入cellpose模型前确定待重新调节图片中物体的尺寸

   2. /cellpose/metrics.py

import的库: 

import numpy as np
from . import utils, dynamics
from numba import jit
from scipy.optimize import linear_sum_assignment
from scipy.ndimage import convolve, mean

np.zeros():返回来一个给定形状和类型的用0填充的数组;

zeros(shape, dtype=float, order=‘C’) shape:形状 dtype:数据类型,可选参数,默认numpy.float64 order:可选参数,c代表与c语言类似,行优先;F代表列优先

mask_iou():返回最匹配的mask

def mask_ious(masks_true, masks_pred):
    """ return best-matched masks """
    iou = _intersection_over_union(masks_true, masks_pred)[1:,1:]
    n_min = min(iou.shape[0], iou.shape[1])    # 得到面积最小的iou
    costs = -(iou >= 0.5).astype(float) - iou / (2*n_min)    # 代价函数
    true_ind, pred_ind = linear_sum_assignment(costs)
    iout = np.zeros(masks_true.max())    # 全0的numpy.float64数组,大小为真实mask数量的最大值?
    iout[true_ind] = iou[true_ind,pred_ind]
    preds = np.zeros(masks_true.max(), 'int')
    preds[true_ind] = pred_ind+1
    return iout, preds

boundary_scores():计算boundary_scores(precision/recall/fscore)

Precision = \frac{TP}{TP+FP}\qquad Recall=\frac{TP}{TP+FN}\qquad F-score=\frac{2*Precision*Recall}{Precision+Recall}

def boundary_scores(masks_true, masks_pred, scales):
    """ boundary precision / recall / Fscore """
    diams = [utils.diameters(lbl)[0] for lbl in masks_true]
    precision = np.zeros((len(scales), len(masks_true)))
    recall = np.zeros((len(scales), len(masks_true)))
    fscore = np.zeros((len(scales), len(masks_true)))
    for j, scale in enumerate(scales):
        for n in range(len(masks_true)):
            diam = max(1, scale * diams[n])
            rs, ys, xs = utils.circleMask([int(np.ceil(diam)), int(np.ceil(diam))])
            filt = (rs <= diam).astype(np.float32)
            otrue = utils.masks_to_outlines(masks_true[n])
            otrue = convolve(otrue, filt)
            opred = utils.masks_to_outlines(masks_pred[n])
            opred = convolve(opred, filt)
            tp = np.logical_and(otrue==1, opred==1).sum()
            fp = np.logical_and(otrue==0, opred==1).sum()
            fn = np.logical_and(otrue==1, opred==0).sum()
            precision[j,n] = tp / (tp + fp)
            recall[j,n] = tp / (tp + fn)
        fscore[j] = 2 * precision[j] * recall[j] / (precision[j] + recall[j])
    return precision, recall, fscore

aggregated_jaccard_index():计算aji值。

AJI = \frac {maskIntersection}{maskUnion} 

def aggregated_jaccard_index(masks_true, masks_pred):
    """ AJI = intersection of all matched masks / union of all masks 
    
    Parameters
    ------------
    
    masks_true: list of ND-arrays (int) or ND-array (int) 
        where 0=NO masks; 1,2... are mask labels
    masks_pred: list of ND-arrays (int) or ND-array (int) 
        ND-array (int) where 0=NO masks; 1,2... are mask labels

    Returns
    ------------

    aji : aggregated jaccard index for each set of masks

    """

aji = np.zeros(len(masks_true))
for n in range(len(masks_true)):
    iout, preds = mask_ious(masks_true[n], masks_pred[n])
    inds = np.arange(0, masks_true[n].max(), 1, int)
    overlap = _label_overlap(masks_true[n], masks_pred[n])
    union = np.logical_or(masks_true[n]>0, masks_pred[n]>0).sum()
    overlap = overlap[inds[preds>0]+1, preds[preds>0].astype(int)]
    aji[n] = overlap.sum() / union
return aji 

 average_precision():计算ap

def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
    """ average precision estimation: AP = TP / (TP + FP + FN)"""
    not_list = False
if not isinstance(masks_true, list):
    masks_true = [masks_true]
    masks_pred = [masks_pred]
    not_list = True
if not isinstance(threshold, list) and not isinstance(threshold, np.ndarray):
    threshold = [threshold]

if len(masks_true) != len(masks_pred):
    raise ValueError('metrics.average_precision requires len(masks_true)==len(masks_pred)')

ap  = np.zeros((len(masks_true), len(threshold)), np.float32)
tp  = np.zeros((len(masks_true), len(threshold)), np.float32)
fp  = np.zeros((len(masks_true), len(threshold)), np.float32)
fn  = np.zeros((len(masks_true), len(threshold)), np.float32)
n_true = np.array(list(map(np.max, masks_true)))
n_pred = np.array(list(map(np.max, masks_pred)))

for n in range(len(masks_true)):
    #_,mt = np.reshape(np.unique(masks_true[n], return_index=True), masks_pred[n].shape)
    if n_pred[n] > 0:
        iou = _intersection_over_union(masks_true[n], masks_pred[n])[1:, 1:]
        for k,th in enumerate(threshold):
            tp[n,k] = _true_positive(iou, th)
    fp[n] = n_pred[n] - tp[n]
    fn[n] = n_true[n] - tp[n]
    ap[n] = tp[n] / (tp[n] + fp[n] + fn[n])  
    
if not_list:
    ap, tp, fp, fn = ap[0], tp[0], fp[0], fn[0]
return ap, tp, fp, fn

 后面的_lable_overlap、_intersection_over_union、_true_positive都不涉及网络结构等获取mask的过程,在此不作赘述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值