3D Self-Supervised Methods for Medical Imaging
论文信息
- Paper: [NeurIPS2020] 3D Self-Supervised Methods for Medical Imaging
- Link: https://papers.nips.cc/paper/2020/file/d2dc6368837861b42020ee72b0896182-Paper.pdf
- Code: https://github.com/HealthML/self-supervised-3d-tasks
背景梳理
深度学习在医学影像分析上的应用需要依赖于大量的有标注数据,而医学影像数据多为3D图像(CT,MRI),标注需要专业的医生花费大量的时间和精力。近期许多自监督的方法的提出,通过数据本身的结构或者特性,设计自监督任务,利用生成的伪标签进行训练来学习图像表征。经过自监督学习预训练得到的模型,在目标任务上只需要少量的标注数据对模型进行微调即可获得较好的效果。近期的多数关于自监督的工作都是基于2D图像设计的,无法直接在3D医疗图像上使用,今天要介绍的文章将五种自监督任务扩展到3D图像上,相比于直接训练(train from scratch)和2D的自监督学习方法,在3D医疗图像分割的任务上取得了较好的效果。
主要贡献
方法设计
基于现有的2D自监督方法,提出了五种3D自监督方法,包括:3D Contrastive Predictive Coding, 3D Rotation prediction, 3D Jigsaw puzzles, Relative 3D patch location, and 3D Exemplar networks。
实验结果
本文在三个不同的任务上,从数据有效性,性能和收敛速度三个角度进行了分析,比较了提出的3D自监督方法。
方法
3D Contrastive Predictive Coding(3D-CPC)
对比预测编码方法[1]由DeepMind首次提出,核心的思想也是对比学习(Contrastive Learning),通过编码序列数据中数据之间的关联信息来学习表征,具体的做法是预测未来(下一个或者相邻)样本的潜在空间(latent space)。
这篇文章将CPC扩展到了3D,对每个输入数据,如下图所示,crop相同大小的重叠的3D patch
x
i
,
j
,
k
x_{i,j,k}
xi,j,k,通过编码器
g
e
n
c
g_{enc}
genc得到在隐空间的表达
z
u
,
v
,
w
=
g
e
n
c
(
x
i
,
j
,
k
)
z_{u,v,w}=g_{enc}(x_{i,j,k})
zu,v,w=genc(xi,j,k)。然后通过上下文网络(context network)
g
c
x
t
g_{cxt}
gcxt 进行汇总得到上下文向量(context vector)
c
i
,
j
,
k
=
g
c
x
t
(
{
z
u
,
v
,
w
}
u
≤
i
,
v
,
w
)
c_{i,j,k}=g_{cxt}(\{z_{u,v,w}\}_{u\leq i,v,w})
ci,j,k=gcxt({zu,v,w}u≤i,v,w)。由于向量
c
i
,
j
,
k
c_{i,j,k}
ci,j,k捕捉了相对于
x
i
,
j
,
k
x_{i,j,k}
xi,j,k上下文的更高层的内容,因此可以预测下一个patch的隐空间的表达
z
i
+
l
,
j
,
k
,
l
≥
0
z_{i+l,j,k},l\geq0
zi+l,j,k,l≥0。
z
i
+
l
,
j
,
k
z_{i+l,j,k}
zi+l,j,k作为positive的样本,从输入图像其他位置随机crop的patch
z
n
z_{n}
zn作为negative的样本,通过CPC损失函数进行优化:
KaTeX parse error: Undefined control sequence: \bbox at position 2: \̲b̲b̲o̲x̲[white, 3px]{ \…
Relative 3D patch location (3D-RPL)
RPL是预测3D patch位置的任务,利用图像中的空间上下文信息作为监督。
如下图所示,将输入图像划分为由
N
N
N个非重叠patch组成的3D网格
{
x
i
}
i
∈
{
1
,
2
,
.
.
.
,
N
}
\{x_i\}_{i \in \{1,2,...,N\}}
{xi}i∈{1,2,...,N}。以网格中心的patch作为参照,从周围的patch中随机选取一个query patch
x
q
x_q
xq,3D-RPL的任务即为预测query patch
x
q
x_q
xq 的位置
y
q
y_q
yq。将问题建模为一个
N
−
1
N-1
N−1的分类问题,损失函数定义为:
KaTeX parse error: Undefined control sequence: \bbox at position 2: \̲b̲b̲o̲x̲[white, 3px]{\m…
3D Jigsaw puzzle Solving (3D-Jig)
Jigsaw就是将输入图像打乱后进行拼图任务。
如下图所示,将输入图像划分成一个
n
×
n
×
n
n\times n\times n
n×n×n的3D 网格,然后从一组预设的排列组合中随机选择一种,进行打乱操作。大小为
P
P
P的排列组合是从
n
3
!
n^3!
n3!中可能的排列组合中选择出来的,每个排列组合被分配一个索引
y
p
∈
{
1
,
2
,
.
.
.
.
,
P
}
y_p\in\{1,2,....,P\}
yp∈{1,2,....,P}。因此这个问题就被建模为一个P-way分类任务,最小化交叉熵损失函数
L
J
i
g
(
y
p
k
,
y
^
p
k
)
\mathcal{L}_{Jig}(y_{p}^{k},\hat{y}_{p}^{k})
LJig(ypk,y^pk),其中
k
∈
{
1
,
.
.
.
,
K
}
k\in \{1,...,K\}
k∈{1,...,K} 是从提取的K个puzzles中选取的一个任意的3D puzzle。
3D Rotation prediction (3D-Rot)
Rotation 通过简单地预测输入图像的旋转角度来学习视觉表征。
如下图所示,输入图像在
R
R
R个考虑的度数中随机旋转一个度数
r
∈
{
1
,
.
.
.
,
R
}
r\in \{ 1,...,R\}
r∈{1,...,R}。本文中作者使用的度数为90度的倍数,沿着3D坐标系
(
x
,
y
,
z
)
(x,y,z)
(x,y,z) 的每个轴。每个轴有4种可能的旋转,总共12种可能的旋转。沿着3个轴旋转0度的结果相同,所以共有10种旋转方式。这种情况下问题被建模为一个10分类问题,最小化
L
R
o
t
(
r
k
,
r
^
k
)
\mathcal{L}_{Rot}(r_{k},\hat{r}_{k})
LRot(rk,r^k),其中
k
∈
{
1
,
.
.
.
,
K
}
k\in \{1,...,K\}
k∈{1,...,K}。
3D Exemplar networks (3D-Exe)
定义训练集
X
=
{
x
1
,
.
.
.
,
x
N
}
X=\{x_1,...,x_N\}
X={x1,...,xN},和包含
K
K
K种图像变换的集合
T
=
{
T
1
,
.
.
.
,
T
K
}
\mathcal{T}=\{T_1,...,T_K\}
T={T1,...,TK},通过对每个训练样本进行变换可以得到新的surrogate类别
S
x
i
=
T
x
i
=
{
T
x
i
|
T
∈
T
}
{S_{x}}_{i} = \mathcal{T} x_i = \{ Tx_i | T \in \mathcal{T} \}
Sxi=Txi={Txi|T∈T}。最初的Exemplar网络是对新的类别进行分类的任务,然而随着数据集的增加,类别数也会很大,分类任务会变得很难。
本文提出的3D-Exe参考对比学习的思想,通过triplet损失函数进行学习,如下图所示,假设
x
i
x_i
xi是一个训练样本,
z
i
z_i
zi是其对应的embedding vector,
x
i
+
x_i^{+}
xi+是
x
i
x_i
xi的变换(正样本),
x
i
−
x_i^{-}
xi−是数据集中的另一个样本,损失函数定义如下:
KaTeX parse error: Undefined control sequence: \bbox at position 2: \̲b̲b̲o̲x̲[white, 3px]{ \…
其中
D
(
.
)
D(.)
D(.)是
L
2
L_2
L2距离。
实验结果
Brain Tumor Segmentation Results
对于脑部肿瘤分割任务,本文使用UK Biobank的数据集进行自监督任务的训练,该数据集包含约22k 3D 脑部MRI数据,使用BraTS 2018的脑肿瘤分割数据集(285个训练数据,66个验证数据)来评估有效性。与2D的自监督方法相比较,并且还对比了使用不同比例训练数据的结果,如下图所示。
从图中可以看出以下几点:
- 2D自监督任务的结果与3D网络直接训练差不多,可见对于该任务3D模型的必要性(这里应该再有一个2D-From scratch)的结果。
- 使用3D自监督任务进行预训练的结果要明显好于直接训练,尤其是对于标注数据量很少的情况以及比较难分割的部位。
Pancreas Tumor Segmentation Results
对于胰腺肿瘤分割任务,本文使用medical decathlon benchmarks的胰腺数据集,包含420个CT扫描,划分为训练集和测试集。先使用训练集的数据进行自监督任务的训练,再对目标任务进行微调,结果如下图所示。
Diabetic Retinopathy Results
作者还在糖尿病视网膜病变2019年Kaggle挑战赛的2D数据集上进行训练,该数据集包含5590张2D眼底图像,将2D自监督任务与2D baseline对比,结果如上图所示。
分析与总结
本文将常见的自监督方法扩展到3D,应用于医疗影像分析,可以作为后续3D医疗影像自监督学习/预训练模型研究的baseline,从结果来看自监督学习可以有效地利用大量的无标注数据,尤其是在标注样本很少的情况下会有很好的performance。
但从方法角度看,本文只是将已有的2D自然图像上的自监督学习的任务进行一个简单的扩展,没有根据医学图像和目标任务的特点进行设计,并不是每个自监督任务都能学习到有用的特征表达,甚至有的自监督任务预训练后会降低性能。
从实验设计上看,在Brats2018上引入了相当多额外的数据用于预训练,但结果相较于SOTA提升并不多,如果有只使用Brats2018的training set进行预训练的实验结果会怎样?对于一个预训练模型,数据量和自监督方法对效果的影响仍有待研究。
参考文献
[1] Representation Learning with Contrastive Predictive Coding (ECCV2020)
Link: https://arxiv.org/pdf/1807.03748.pdf
Code: https://github.com/jefflai108/Contrastive-Predictive-Coding-PyTorch