文章介绍:
文章使用了交互式分割的半监督学习方法,对视频目标进行分割。与处理三维数据的3DU-Net等3D网络不同,对视频目标进行分割时,会从每个视频中取出相邻或相近的固定个帧,训练时只使用这几个帧进行训练,测试时将全部帧加载进模型中进行分割,因此网络中test和train部分相差较大。文章中的交互式分割是在分割完成后,根据人为提示内容,对分割进行进一步修改。
代码介绍:
train_SAQ.py:
加载数据时,会根据不同用途加载不同量的数据。train时会加载sample_per_volume的帧,该变量在option中设置,而test时会加载一个volume的全部数据。这个设置也会在网络的初始化时进行,网络创建时设置内部phase为train或test。
net = STM(opt.keydim, opt.valdim, 'train',
mode=opt.mode, iou_threshold=opt.iou_threshold)
train_loss = train(trainloader,
model=net,
criterion=criterion,
optimizer=optimizer,
epoch=epoch,
use_cuda=True,
iter_size=opt.iter_size,
mode=opt.mode,
threshold=opt.iou_threshold)
if (epoch + 1) % opt.epoch_per_test == 0:
net.module.phase = 'test'
test_loss = test(testloader,
model=net.module,
criterion=criterion,
epoch=epoch,
use_cuda=True)
在train和test函数中,被选中的帧frame与它的mask分别传入model,计算损失函数
out, quality, ious = model(frame=frames, mask=masks, num_objects=objs, criterion=mask_iou_loss)
model_SAQ.py:
获得数据后,模型的训练过程分为以下几步:
memorize:
具体来说,就是使用上一个帧的分割结果以及内容,指导当前帧的分割,当没有上一个帧时,使用的就是当前帧和当前帧的mask,即当前为参考帧
if t - 1 == 0 or self.mode == 'mask':
tmp_mask = mask[idx, t - 1:t]
elif self.mode == 'recurrent':
tmp_mask = out
else:
pred_mask = out[0, 1:num_object + 1]
iou = mask_iou(pred_mask, mask[idx, t - 1, 1:num_object + 1])
if iou > self.iou_threshold:
tmp_mask = out
else:
tmp_mask = mask[idx, t - 1:t]
key, val, _ = self.memorize(frame=frame[idx, t - 1:t], masks=tmp_mask,
num_objects=num_object)
在memorize函数中,分别获取帧,标签以及背景,送入内存编码器
frame_batch = []
mask_batch = []
bg_batch = []
# print('\n')
# print(num_objects)
try:
for o in range(1, num_objects + 1): # 1 - no
frame_batch.append(frame)
mask_batch.append(masks[:, o])
for o in range(1, num_objects + 1):
bg_batch.append(torch.clamp(1.0 - masks[:, o], min=0.0, max=1.0))
# make Batch
frame_batch = torch.cat(frame_batch, dim=0)
mask_batch = torch.cat(mask_batch, dim=0)
bg_batch = torch.cat(bg_batch, dim=0)
except RuntimeError as re:
print(re)
print(num_objects)
raise re
r4, _, _, _ = self.Encoder_M(frame_batch, mask_batch, bg_batch) # no, c, h, w
从编码器中获得特征,并由特征获得键与值
k4, v4 = self.KV_M_r4(memfeat)
k4 = k4.permute(0, 2, 3, 1).contiguous().view(num_objects, -1, self.keydim)
v4 = v4.permute(0, 2, 3, 1).contiguous().view(num_objects, -1, self.valdim)
return k4, v4, r4
segment:
根据得到的key和value,使用Encoder_Q对当前帧进行分割
# segment
tmp_key = torch.cat(batch_keys, dim=1)
tmp_val = torch.cat(batch_vals, dim=1)
logits, ps, r4 = self.segment(frame=frame[idx, t:t + 1], keys=tmp_key, values=tmp_val,
num_objects=num_object, max_obj=max_obj)
r4s.append(r4)
out = torch.softmax(logits, dim=1)
tmp_out.append(out)
将所有的分割结果全部堆叠起来,获得batch_out,并根据batch_out获得当前分割的质量以及损失函数,将这三者返回后,一轮训练结束