注:本篇博客是笔者在之前做细胞分割工作时学习参考cellpose论文及代码所作笔记,内容不够深入,但可以带领读者大致了解cellpose论文的工作亮点以及项目结构,和代码中部分接口的实现,适合小白入门。如有不准确的地方还请在评论区或私信交流~
目录
项目源码地址:cellpose项目源码
Cellpose论文学习
技术背景(需求来源)
现有的方法专对训练集庞大的图像有用,此文提出的是一个通用性更强的的、基于深度学习细胞图像分割方法——CellPose。Cellpose的训练集包含了超过70000种被分割的实体。
遇到的问题:细胞紧密连接时较难分离,而现有方法会在灵活性和自动化上make trade-off. 现有的方法:手工标注数据->用户自定义pipeline->全自动化,自动更新参数。缺陷:数据集受限
工作基于Data Science Bowl的approach获得并手动分割大量带荧光标记的数据集(以前的方法在此数据集表现不佳,适合开发一个具有更好表达能力的新模型,即cellpose)
模型设计
传统方法的思路:基于分水岭,使用图像灰度值绘制拓扑图,这在分割对象具有从中心平滑衰减的强度剖面时是能工作的,但由于荧光标记物排除了核以及沿细胞边界的突起不均匀分布,许多类型的细胞形成多个强度盆地,cellpose致力于解决这个问题。
cellpose的思路:基于模拟扩散生成拓扑图(使用mask),训练一个神经网络来确定拓扑图的水平和垂直梯度和二值图像(此二值图像可以确定像素是否在ROI中)。
预测过程
预测时,预测的梯度向量场被用来构建一个具有不动点的动力系统,其吸引域代表预测的mask。无形中,每个像素'跟踪梯度'到它们的最终固定点。所有收敛到相同固定点的像素被分配到相同的mask
如何跟踪梯度?
-
首先,预测水平和垂直梯度后,形成向量场并通过梯度追踪这个场,通过这样的方法可以把一个给定的细胞的所有像素路由到其中心。
-
随后,将收敛到同一点的像素分组,恢复单个细胞及其精确形状
训练一个神经网络来预测水平和垂直梯度,以及一个像素是否属于任何单元格。将三个预测图组合成一个梯度向量场👇
网络设计
U-Net:用于预测空间梯度,非常对称的一个网络,具有相同大小的层之间的跳跃连接和全局跳跃连接,从最低分辨率计算的图像样式到所有连续的计算
cellpose工作亮点:
-
在上采样阶段,U-Net会“混合”(通过特征级联)下采样通道得到的卷积图。通过直接求和取代了特征级联,减少了参数数量;
-
将U-Net的标准构建块换成残差块,性能更好,并将网络深度增加了一倍;
-
在最小的卷积图上使用全局平均池化来获得图像的"style"表示
以上工作均显著提高了性能。
训练集介绍
数据来源
-
互联网搜索关键词得到("细胞质" , "细胞显微镜" , "荧光细胞"等),该数据集主要由荧光标记的蛋白质组成,这些蛋白质定位于细胞质中,在一个单独的通道中
-
还包括来自明场显微镜( n = 50)的细胞图像和膜标记细胞( n = 58)的图像。
-
包括一组来自其他类型显微镜( n = 86)的小幅图像,以及一组包含大量重复物体如水果、岩石和水母( n = 98)的小幅非显微镜图像
目标:将这些图像包含在训练集中将使网络能够更广泛和更健壮地泛化。
可视化
通过将t分布随机近邻嵌入( t-SNE )应用于神经网络学习到的图像风格来可视化该数据集的结构,并允许用户贡献自己的数据集。
在608幅图像的不同数据集中,100幅图像的子集被预分割为细胞图像库的一部分。这些双通道图像(细胞质和细胞核)来自单一的实验准备,包含形状复杂的细胞。基于其一般性,此数据集很适合作表达能力的benchmark。
下面是cellpose对36张测试图像的分割结果:
与Stardist和Mask rcnn做对比:
Cellpose代码复现
将源码git clone到本地,载入pycharm:
项目代码解读
-
./cellpose/models.py
主要类:
- Cellpose():用于combine Sizemodel和CellposeModel
- CellposeModel(UnetModel):即基于UnetModel的Cellpose模型,用于处理输入图片
- SizeModel:一种线性回归模型,在输入cellpose模型前确定待重新调节图片中物体的尺寸
2. /cellpose/metrics.py
import的库:
import numpy as np
from . import utils, dynamics
from numba import jit
from scipy.optimize import linear_sum_assignment
from scipy.ndimage import convolve, mean
np.zeros():返回来一个给定形状和类型的用0填充的数组;
zeros(shape, dtype=float, order=‘C’) shape:形状 dtype:数据类型,可选参数,默认numpy.float64 order:可选参数,c代表与c语言类似,行优先;F代表列优先
mask_iou():返回最匹配的mask
def mask_ious(masks_true, masks_pred):
""" return best-matched masks """
iou = _intersection_over_union(masks_true, masks_pred)[1:,1:]
n_min = min(iou.shape[0], iou.shape[1]) # 得到面积最小的iou
costs = -(iou >= 0.5).astype(float) - iou / (2*n_min) # 代价函数
true_ind, pred_ind = linear_sum_assignment(costs)
iout = np.zeros(masks_true.max()) # 全0的numpy.float64数组,大小为真实mask数量的最大值?
iout[true_ind] = iou[true_ind,pred_ind]
preds = np.zeros(masks_true.max(), 'int')
preds[true_ind] = pred_ind+1
return iout, preds
boundary_scores():计算boundary_scores(precision/recall/fscore)
def boundary_scores(masks_true, masks_pred, scales):
""" boundary precision / recall / Fscore """
diams = [utils.diameters(lbl)[0] for lbl in masks_true]
precision = np.zeros((len(scales), len(masks_true)))
recall = np.zeros((len(scales), len(masks_true)))
fscore = np.zeros((len(scales), len(masks_true)))
for j, scale in enumerate(scales):
for n in range(len(masks_true)):
diam = max(1, scale * diams[n])
rs, ys, xs = utils.circleMask([int(np.ceil(diam)), int(np.ceil(diam))])
filt = (rs <= diam).astype(np.float32)
otrue = utils.masks_to_outlines(masks_true[n])
otrue = convolve(otrue, filt)
opred = utils.masks_to_outlines(masks_pred[n])
opred = convolve(opred, filt)
tp = np.logical_and(otrue==1, opred==1).sum()
fp = np.logical_and(otrue==0, opred==1).sum()
fn = np.logical_and(otrue==1, opred==0).sum()
precision[j,n] = tp / (tp + fp)
recall[j,n] = tp / (tp + fn)
fscore[j] = 2 * precision[j] * recall[j] / (precision[j] + recall[j])
return precision, recall, fscore
aggregated_jaccard_index():计算aji值。
def aggregated_jaccard_index(masks_true, masks_pred):
""" AJI = intersection of all matched masks / union of all masks
Parameters
------------
masks_true: list of ND-arrays (int) or ND-array (int)
where 0=NO masks; 1,2... are mask labels
masks_pred: list of ND-arrays (int) or ND-array (int)
ND-array (int) where 0=NO masks; 1,2... are mask labels
Returns
------------
aji : aggregated jaccard index for each set of masks
"""
aji = np.zeros(len(masks_true))
for n in range(len(masks_true)):
iout, preds = mask_ious(masks_true[n], masks_pred[n])
inds = np.arange(0, masks_true[n].max(), 1, int)
overlap = _label_overlap(masks_true[n], masks_pred[n])
union = np.logical_or(masks_true[n]>0, masks_pred[n]>0).sum()
overlap = overlap[inds[preds>0]+1, preds[preds>0].astype(int)]
aji[n] = overlap.sum() / union
return aji
average_precision():计算ap
def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
""" average precision estimation: AP = TP / (TP + FP + FN)"""
not_list = False
if not isinstance(masks_true, list):
masks_true = [masks_true]
masks_pred = [masks_pred]
not_list = True
if not isinstance(threshold, list) and not isinstance(threshold, np.ndarray):
threshold = [threshold]
if len(masks_true) != len(masks_pred):
raise ValueError('metrics.average_precision requires len(masks_true)==len(masks_pred)')
ap = np.zeros((len(masks_true), len(threshold)), np.float32)
tp = np.zeros((len(masks_true), len(threshold)), np.float32)
fp = np.zeros((len(masks_true), len(threshold)), np.float32)
fn = np.zeros((len(masks_true), len(threshold)), np.float32)
n_true = np.array(list(map(np.max, masks_true)))
n_pred = np.array(list(map(np.max, masks_pred)))
for n in range(len(masks_true)):
#_,mt = np.reshape(np.unique(masks_true[n], return_index=True), masks_pred[n].shape)
if n_pred[n] > 0:
iou = _intersection_over_union(masks_true[n], masks_pred[n])[1:, 1:]
for k,th in enumerate(threshold):
tp[n,k] = _true_positive(iou, th)
fp[n] = n_pred[n] - tp[n]
fn[n] = n_true[n] - tp[n]
ap[n] = tp[n] / (tp[n] + fp[n] + fn[n])
if not_list:
ap, tp, fp, fn = ap[0], tp[0], fp[0], fn[0]
return ap, tp, fp, fn
后面的_lable_overlap、_intersection_over_union、_true_positive都不涉及网络结构等获取mask的过程,在此不作赘述