Weakly-Supervised Segmentation for Disease Localization in Chest X-Ray Images
Abstract
- Pixel-level annotations are costly to obtain and usually require domain expertise
- A novel method for semantic segmentation of chest X-ray images supervised only by image-level class labels
- A supervised classification model, enhanced with a regularization procedure, generates pseudo segmentation labels for the abnormal regions of the training images
- The obtained activation maps are post-processed and propagated into a second classification model, the Inter-pixel Relation Network (IRNet), which refines the boundaries between different object classes
- The resulting pseudo-labels are used to train the proposed fully supervised segmentation model
- Code link
Method
Three consecutive steps: class activation map (CAM) generation, map refinement with the Inter-pixel Relation Network (IRNet), and segmentation. After each step we also apply one or more post-processing techniques, such as conditional random fields (CRF), thresholding, and noise filtering (removing small regions with low confidence).
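A minimal sketch of the thresholding and noise-filtering steps on a normalized activation map (the CRF step is omitted here, and the threshold and minimum region size are illustrative values, not those of the paper):

import numpy as np
from scipy import ndimage

def postprocess_map(cam, threshold=0.3, min_region=200):
    # Binarize the normalized activation map.
    mask = cam >= threshold
    # Remove small connected components, which are typically low-confidence noise.
    labeled, num_regions = ndimage.label(mask)
    for region_id in range(1, num_regions + 1):
        region = labeled == region_id
        if region.sum() < min_region:
            mask[region] = False
    return mask.astype(np.uint8)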
CAM generation
- First, we train a fully supervised classification model on the image-level labels. The two architectures tested at this step are ResNet50 and VGG with three additional convolutional layers, each followed by a ReLU activation. In the last convolutional layers we also replace stride with dilation to enlarge the final feature map, reducing the output stride from 32 to 8.
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torchvision import models

# resnet50_url and Flatten are defined elsewhere in the project:
# resnet50_url points to the ImageNet-pretrained weights, and Flatten
# flattens the feature map before the fully connected layer.


def resnet50(pretrained, num_classes, replace_stride_with_dilation, weights_path):
    # A binary task uses a single output logit.
    model = ResNet50(
        num_classes if num_classes != 2 else 1, replace_stride_with_dilation
    )
    if pretrained:
        if weights_path is not None:
            state_dict = torch.load(weights_path)["model"]
        else:
            state_dict = load_state_dict_from_url(resnet50_url, progress=True)
        if num_classes != 1000:  # if not the default configuration
            del state_dict["fc.weight"]
            del state_dict["fc.bias"]
        model.load_state_dict(state_dict, strict=False)
    return model


class ResNet50(nn.Sequential):
    def __init__(self, num_classes, replace_stride_with_dilation):
        model_modules = list(
            models.resnet50(
                pretrained=False,
                num_classes=num_classes,
                replace_stride_with_dilation=replace_stride_with_dilation,
            ).named_children()
        )
        model_modules.insert(-1, ("flatten", Flatten()))
        super().__init__(OrderedDict(model_modules))

    def get_params_groups(self):
        # Separate the backbone parameters from the classifier parameters
        # so that they can be trained with different learning rates.
        feature_extractor_params = []
        for name, module in self.named_children():
            if name != "fc":
                feature_extractor_params.extend(module.parameters())
            else:
                fc_params = module.parameters()
        return [feature_extractor_params, fc_params]
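A hypothetical usage of the factory above, with the last two stages dilated so the output stride drops from 32 to 8 and the two parameter groups given separate learning rates (the learning-rate values are illustrative):

model = resnet50(
    pretrained=True,
    num_classes=2,  # binary task: a single logit is used internally
    replace_stride_with_dilation=[False, True, True],
    weights_path=None,
)
backbone_params, fc_params = model.get_params_groups()
optimizer = torch.optim.SGD(
    [{"params": backbone_params, "lr": 1e-3}, {"params": fc_params, "lr": 1e-2}],
    momentum=0.9,
)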
- Classification performance is improved by adding a regularization term, inspired by FickleNet. We use DropBlock, a dropout technique that, to the best of our knowledge, has not been tried in previous work on weakly-supervised segmentation. Grad-CAM++ is then applied to the trained model to retrieve the activation maps. The obtained maps serve as pseudo-labels for the segmentation task.
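For reference, a minimal DropBlock sketch (not the authors' implementation); the drop probability and block size are illustrative values:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DropBlock2d(nn.Module):
    # Drops contiguous square regions of the feature map instead of
    # independent activations, which is the core idea of DropBlock.
    def __init__(self, drop_prob=0.1, block_size=7):
        super().__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x
        _, _, h, w = x.shape
        # Seed probability chosen so the expected dropped fraction is drop_prob.
        gamma = (self.drop_prob / self.block_size ** 2) * (h * w) / (
            (h - self.block_size + 1) * (w - self.block_size + 1)
        )
        seeds = (torch.rand_like(x) < gamma).float()
        # Expand each seed into a block_size x block_size square and invert.
        block_mask = 1.0 - F.max_pool2d(
            seeds, kernel_size=self.block_size, stride=1, padding=self.block_size // 2
        )
        # Rescale so the expected activation magnitude is preserved.
        return x * block_mask * block_mask.numel() / block_mask.sum().clamp(min=1.0)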
IRNet
IRNet takes the generated CAMs and trains two output branches that predict a displacement vector field and a class boundary map, respectively.
Feature maps are taken from all five levels of the same shared ResNet50 backbone. The main advantage of IRNet is its ability to refine the boundaries between different object classes. We train it on the generated maps, so no additional supervision is required. This step lets us obtain better pseudo-labels before segmentation. To the best of our knowledge, this approach has not previously been used in the medical imaging domain.
import torch
import torch.nn as nn
import torch.nn.functional as F

from net import resnet50


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # backbone
        self.resnet50 = resnet50.resnet50(pretrained=True, strides=[2, 2, 2, 1])
        self.stage1 = nn.Sequential(self.resnet50.conv1, self.resnet50.bn1, self.resnet50.relu, self.resnet50.maxpool)
        self.stage2 = nn.Sequential(self.resnet50.layer1)
        self.stage3 = nn.Sequential(self.resnet50.layer2)
        self.stage4 = nn.Sequential(self.resnet50.layer3)
        self.stage5 = nn.Sequential(self.resnet50.layer4)
        self.mean_shift = Net.MeanShift(2)

        # branch: class boundary detection
        self.fc_edge1 = nn.Sequential(
            nn.Conv2d(64, 32, 1, bias=False),
            nn.GroupNorm(4, 32),
            nn.ReLU(inplace=True),
        )
        self.fc_edge2 = nn.Sequential(
            nn.Conv2d(256, 32, 1, bias=False),
            nn.GroupNorm(4, 32),
            nn.ReLU(inplace=True),
        )
        self.fc_edge3 = nn.Sequential(
            nn.Conv2d(512, 32, 1, bias=False),
            nn.GroupNorm(4, 32),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.ReLU(inplace=True),
        )
        self.fc_edge4 = nn.Sequential(
            nn.Conv2d(1024, 32, 1, bias=False),
            nn.GroupNorm(4, 32),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
            nn.ReLU(inplace=True),
        )
        self.fc_edge5 = nn.Sequential(
            nn.Conv2d(2048, 32, 1, bias=False),
            nn.GroupNorm(4, 32),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
            nn.ReLU(inplace=True),
        )
        self.fc_edge6 = nn.Conv2d(160, 1, 1, bias=True)

        # branch: displacement field
        self.fc_dp1 = nn.Sequential(
            nn.Conv2d(64, 64, 1, bias=False),
            nn.GroupNorm(8, 64),
            nn.ReLU(inplace=True),
        )
        self.fc_dp2 = nn.Sequential(
            nn.Conv2d(256, 128, 1, bias=False),
            nn.GroupNorm(16, 128),
            nn.ReLU(inplace=True),
        )
        self.fc_dp3 = nn.Sequential(
            nn.Conv2d(512, 256, 1, bias=False),
            nn.GroupNorm(16, 256),
            nn.ReLU(inplace=True),
        )
        self.fc_dp4 = nn.Sequential(
            nn.Conv2d(1024, 256, 1, bias=False),
            nn.GroupNorm(16, 256),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.ReLU(inplace=True),
        )
        self.fc_dp5 = nn.Sequential(
            nn.Conv2d(2048, 256, 1, bias=False),
            nn.GroupNorm(16, 256),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.ReLU(inplace=True),
        )
        self.fc_dp6 = nn.Sequential(
            nn.Conv2d(768, 256, 1, bias=False),
            nn.GroupNorm(16, 256),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.ReLU(inplace=True),
        )
        self.fc_dp7 = nn.Sequential(
            nn.Conv2d(448, 256, 1, bias=False),
            nn.GroupNorm(16, 256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 2, 1, bias=False),
            self.mean_shift
        )

        self.backbone = nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4, self.stage5])
        self.edge_layers = nn.ModuleList([self.fc_edge1, self.fc_edge2, self.fc_edge3, self.fc_edge4, self.fc_edge5, self.fc_edge6])
        self.dp_layers = nn.ModuleList([self.fc_dp1, self.fc_dp2, self.fc_dp3, self.fc_dp4, self.fc_dp5, self.fc_dp6, self.fc_dp7])

    class MeanShift(nn.Module):
        def __init__(self, num_features):
            super(Net.MeanShift, self).__init__()
            self.register_buffer('running_mean', torch.zeros(num_features))

        def forward(self, input):
            if self.training:
                return input
            return input - self.running_mean.view(1, 2, 1, 1)

    def forward(self, x):
        x1 = self.stage1(x).detach()
        x2 = self.stage2(x1).detach()
        x3 = self.stage3(x2).detach()
        x4 = self.stage4(x3).detach()
        x5 = self.stage5(x4).detach()

        edge1 = self.fc_edge1(x1)
        edge2 = self.fc_edge2(x2)
        edge3 = self.fc_edge3(x3)[..., :edge2.size(2), :edge2.size(3)]
        edge4 = self.fc_edge4(x4)[..., :edge2.size(2), :edge2.size(3)]
        edge5 = self.fc_edge5(x5)[..., :edge2.size(2), :edge2.size(3)]
        edge_out = self.fc_edge6(torch.cat([edge1, edge2, edge3, edge4, edge5], dim=1))

        dp1 = self.fc_dp1(x1)
        dp2 = self.fc_dp2(x2)
        dp3 = self.fc_dp3(x3)
        dp4 = self.fc_dp4(x4)[..., :dp3.size(2), :dp3.size(3)]
        dp5 = self.fc_dp5(x5)[..., :dp3.size(2), :dp3.size(3)]
        dp_up3 = self.fc_dp6(torch.cat([dp3, dp4, dp5], dim=1))[..., :dp2.size(2), :dp2.size(3)]
        dp_out = self.fc_dp7(torch.cat([dp1, dp2, dp_up3], dim=1))

        return edge_out, dp_out

    def trainable_parameters(self):
        return (tuple(self.edge_layers.parameters()),
                tuple(self.dp_layers.parameters()))

    def train(self, mode=True):
        super().train(mode)
        self.backbone.eval()


class AffinityDisplacementLoss(Net):

    path_indices_prefix = "path_indices"

    def __init__(self, path_index):
        super(AffinityDisplacementLoss, self).__init__()

        self.path_index = path_index

        self.n_path_lengths = len(path_index.path_indices)
        for i, pi in enumerate(path_index.path_indices):
            self.register_buffer(AffinityDisplacementLoss.path_indices_prefix + str(i), torch.from_numpy(pi))

        self.register_buffer(
            'disp_target',
            torch.unsqueeze(torch.unsqueeze(torch.from_numpy(path_index.search_dst).transpose(1, 0), 0), -1).float())

    def to_affinity(self, edge):
        aff_list = []
        edge = edge.view(edge.size(0), -1)

        for i in range(self.n_path_lengths):
            ind = self._buffers[AffinityDisplacementLoss.path_indices_prefix + str(i)]
            ind_flat = ind.view(-1)
            dist = torch.index_select(edge, dim=-1, index=ind_flat)
            dist = dist.view(dist.size(0), ind.size(0), ind.size(1), ind.size(2))
            aff = torch.squeeze(1 - F.max_pool2d(dist, (dist.size(2), 1)), dim=2)
            aff_list.append(aff)
        aff_cat = torch.cat(aff_list, dim=1)

        return aff_cat

    def to_pair_displacement(self, disp):
        height, width = disp.size(2), disp.size(3)
        radius_floor = self.path_index.radius_floor

        cropped_height = height - radius_floor
        cropped_width = width - 2 * radius_floor

        disp_src = disp[:, :, :cropped_height, radius_floor:radius_floor + cropped_width]

        disp_dst = [disp[:, :, dy:dy + cropped_height, radius_floor + dx:radius_floor + dx + cropped_width]
                    for dy, dx in self.path_index.search_dst]
        disp_dst = torch.stack(disp_dst, 2)

        pair_disp = torch.unsqueeze(disp_src, 2) - disp_dst
        pair_disp = pair_disp.view(pair_disp.size(0), pair_disp.size(1), pair_disp.size(2), -1)

        return pair_disp

    def to_displacement_loss(self, pair_disp):
        return torch.abs(pair_disp - self.disp_target)

    def forward(self, *inputs):
        x, return_loss = inputs
        edge_out, dp_out = super().forward(x)

        if return_loss is False:
            return edge_out, dp_out

        aff = self.to_affinity(torch.sigmoid(edge_out))
        pos_aff_loss = (-1) * torch.log(aff + 1e-5)
        neg_aff_loss = (-1) * torch.log(1. + 1e-5 - aff)

        pair_disp = self.to_pair_displacement(dp_out)
        dp_fg_loss = self.to_displacement_loss(pair_disp)
        dp_bg_loss = torch.abs(pair_disp)

        return pos_aff_loss, neg_aff_loss, dp_fg_loss, dp_bg_loss


class EdgeDisplacement(Net):

    def __init__(self, crop_size=512, stride=4):
        super(EdgeDisplacement, self).__init__()
        self.crop_size = crop_size
        self.stride = stride

    def forward(self, x):
        feat_size = (x.size(2) - 1) // self.stride + 1, (x.size(3) - 1) // self.stride + 1

        x = F.pad(x, [0, self.crop_size - x.size(3), 0, self.crop_size - x.size(2)])
        edge_out, dp_out = super().forward(x)
        edge_out = edge_out[..., :feat_size[0], :feat_size[1]]
        dp_out = dp_out[..., :feat_size[0], :feat_size[1]]

        edge_out = torch.sigmoid(edge_out[0] / 2 + edge_out[1].flip(-1) / 2)
        dp_out = dp_out[0]

        return edge_out, dp_out
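A rough training-step sketch of how the four loss terms returned by AffinityDisplacementLoss might be combined; in IRNet the masks selecting same-class (positive), different-class (negative), foreground and background pixel pairs come from the CAM pseudo-labels, which is only indicated by placeholders here, and the equal weighting is illustrative:

# model is an AffinityDisplacementLoss instance; images is a batch of inputs.
pos_aff_loss, neg_aff_loss, dp_fg_loss, dp_bg_loss = model(images, True)

# Placeholders: in practice these binary masks are derived from the CAM
# pseudo-labels and match the shapes of the corresponding loss tensors.
pos_mask = torch.ones_like(pos_aff_loss)
neg_mask = torch.ones_like(neg_aff_loss)
fg_mask = torch.ones_like(dp_fg_loss)
bg_mask = torch.ones_like(dp_bg_loss)

aff_loss = (pos_aff_loss * pos_mask).sum() / pos_mask.sum().clamp(min=1) + \
           (neg_aff_loss * neg_mask).sum() / neg_mask.sum().clamp(min=1)
dp_loss = (dp_fg_loss * fg_mask).sum() / fg_mask.sum().clamp(min=1) + \
          (dp_bg_loss * bg_mask).sum() / bg_mask.sum().clamp(min=1)

loss = (aff_loss + dp_loss) / 2
loss.backward()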
Segmentation
For the segmentation step we train DeepLabv3+ and U-Net models with different backbones, which have been shown to produce reliable results in fully supervised semantic segmentation of medical images. The backbones used include ResNet50 and SE-ResNeXt50. We modify the binary cross-entropy (BCE) loss during segmentation by adding a weight to the positive class, to prevent overfitting towards the normal cases.
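A minimal sketch of the positive-class-weighted BCE loss with single-channel logits; the weight value, input size, and the use of the segmentation_models_pytorch package to build DeepLabv3+ are illustrative assumptions, not details confirmed by the text:

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

# DeepLabv3+ with a ResNet50 encoder and a single (abnormal-region) output channel.
model = smp.DeepLabV3Plus(encoder_name="resnet50", encoder_weights="imagenet", classes=1)

# Weighting the positive class counteracts the dominance of normal pixels.
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0]))

logits = model(torch.randn(2, 3, 512, 512))             # (N, 1, H, W)
target = torch.randint(0, 2, (2, 1, 512, 512)).float()  # pseudo-label masks
loss = criterion(logits, target)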