psa-for-prostate代码笔记

psa-for-prostate代码梳理

psa(Pixel-level Semantic Affinity)这篇文章应用到prostate图像当中,用于对prostate图像进行弱监督语义分割。代码主要分为四个部分:

  1. 训练分类网络模型
  2. 生成不同形式的CAMs以供使用
  3. 训练affinity网络模型
  4. 对原CAMs进行fine-tuning

下面对代码的整个结构以及其中的一些细节进行记录和整理。

1.训练分类网络模型

2.生成不同形式的CAMs以供使用

3.训练affinity网络模型

3.1 数据预处理

这一部分的训练数据在data.py文件中由AffDataset类定义,该类继承ImageDataset类。训练数据包括原图片和affinity标签(关于affinity标签的定义参照论文,此处不祥说),其中原图片易取(包括一些裁剪、翻转、归一化等预处理操作)而affinity标签难得。因为affinity标签需要根据之前生成的各种CAMs来生成,具体生成过程见data.pyExtractAffinityLabelInRadius类。生成affinity标签所需要的数据有降采样后尺寸cropsize=cropsize//8,范围半径radius=radius和阈值化后的CAMs标签(label_la,label_ha)。其中,
cropsize是降采样后的尺寸,注意这里的操作都是在8倍降采样(平均池化,步长为8)后的尺寸里操作的;
范围半径radius是指在选取affinity标签对的时候所考虑的点对之间的距离是在半径radius内的;
阈值化后的CAMs标签包括确定的前景,确定的背景和非确定区域。label_la是将背景区域值偏大估计的二维mask,label_ha是将背景区域偏小估计的二维mask,通过这两个mask可以形成前面所提到的阈值化后的CAMs标签,这个阈值化后的mask有以下几个特点:1.是降采样后的尺寸,2.包括确定的背景和确定的前景,3.是单通道的,不同类别用不同像素值表示。然后对affinity标签的提取过程详细说明如下:
首先在给定半径范围选择符合要求的搜索点对,这里假设半径为5,则符合要求的点对有34个;
然后根据将采样后的尺寸结合半径确定之后标签提取过程在(crop_height,cop_width)的尺寸上进行;
接着确定起始点和目的点的坐标,分别用labels_from和labels_to来表示;
最后根据affinity标签的定义用逻辑与或非等操作确定背景正样本、前景正样本和负样本。具体代码如下:

bg_pos_affinity_label = np.logical_and(pos_affinity_label, np.equal(bc_labels_from, 0)).astype(np.float32)
fg_pos_affinity_label = np.logical_and(np.logical_and(pos_affinity_label, np.not_equal(bc_labels_from, 0)), concat_valid_pair).astype(np.float32)
neg_affinity_label = np.logical_and(np.logical_not(pos_affinity_label), concat_valid_pair).astype(np.float32)

3.2 网络模型

使用的backbone网络还是vgg16结构,然后对后三个卷积层的输出进行连接,再进行一个1x1卷积操作得到最后的多通道的affinity特征图,生成的特征图用于后续的affinity以及loss的计算。

3.3 模型训练

训练的前向传播过程先生成上一部分提到的多通道(448)的特征图,然后利用3.1中确定的搜索点对坐标确定特征图层面的起始点和目的点矩阵ff,ft。

ff = torch.index_select(x, dim=2, index=ind_from.cuda(non_blocking=True))
ft = torch.index_select(x, dim=2, index=ind_to.cuda(non_blocking=True))

之后利用公式

aff = torch.exp(-torch.mean(torch.abs(ft-ff), dim=1))

得到34通道的affinity矩阵aff,每个通道表示到不同距离(指定半径范围内)的点的相似度情况。利用这个aff,shape为((1,34,1/8原图尺寸))结合3.1中得到的affinity标签,即可通过以下公式计算出loss,随后便可进行梯度下降训练。

bg_loss = torch.sum(- bg_label * torch.log(aff + 1e-5)) / bg_count
fg_loss = torch.sum(- fg_label * torch.log(aff + 1e-5)) / fg_count
neg_loss = torch.sum(- neg_label * torch.log(1. + 1e-5 - aff)) / neg_count
loss = bg_loss / 4 + fg_loss / 4 + neg_loss / 2

3.4 模型推理

利用训练得到的affinity模型,对相同的训练图片进行一个infer的过程。整个前向传播的过程和训练时的过程完全一样,就是在返回aff的时候需要对该矩阵进行稀疏化,变成(1/8原图高1/8原图宽,1/8原图高1/8原图宽)的正方形矩阵,对角线元素全为1,表示自己和自己肯定相似度最高,相关代码如下:

aff_mat = sparse.FloatTensor(torch.cat([indices, indices_id, indices_tp], dim=1),
                                         torch.cat([aff, torch.ones([area]), aff])).to_dense().cuda()

4.对原CAMs进行fine-tuning

利用第3部分最后推理得到的affinity矩阵aff_mat,将其转换成概率矩阵(归一化),然后对8倍降采样后的CAMs(第2部分得到)进行random walk处理(简单的矩阵相乘多次),变得到最后的fine-tuned的结果。

aff_mat = torch.pow(model.forward(img.cuda(), True), args.beta)
trans_mat = aff_mat / torch.sum(aff_mat, dim=0, keepdim=True)
for _ in range(args.logt):
     trans_mat = torch.matmul(trans_mat, trans_mat)
cam_full_arr = torch.from_numpy(cam_full_arr)
cam_full_arr = F.avg_pool2d(cam_full_arr, 8, 8)
cam_vec = cam_full_arr.view(2, -1)
cam_rw = torch.matmul(cam_vec.cuda(), trans_mat)
cam_rw = cam_rw.view(1, 2, dheight, dwidth)
cam_rw = torch.nn.Upsample((img.shape[2], img.shape[3]), mode='bilinear')(cam_rw)
_, cam_rw_pred = torch.max(cam_rw, 1)
res = np.uint8(cam_rw_pred.cpu().data[0])[:orig_shape[2], :orig_shape[3]]
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值