【视频分割】【深度学习】MiVOS官方Pytorch代码-S2M模块解析
MiVOS模型将交互到掩码和掩码传播分离,从而实现更高的泛化性和更好的性能。单独训练的交互模块将用户交互转换为对象掩码,本博客将讲解S2M(用户交互产生分割图)。
【解析代码地址】
文章目录
前言
在详细解析MiVOS代码之前,首要任务是成功运行MiVOS代码【win10下参考教程】,后续学习才有意义。
本博客讲解S2M(用户交互产生分割图)的功能模块,暂时不考虑Propagation(掩码传播)的功能模块,因此interactive_gui.py文件只保留了与S2M功能相关的代码。
【代码:用interactive_gui_3.1.py代替interactive_gui.py】
博主将各功能模块的代码在不同的博文中进行了详细的解析,点击【win10下参考教程】,博文的目录链接放在前言部分。
用户界面新增S2M模块
主函数新增代码
在主函数中,实例化S2M对象并加载权重。
with torch.cuda.amp.autocast(enabled=not args.no_amp):
s2m_saved = torch.load(args.s2m_model)
s2m_model = S2M().cuda().eval()
s2m_model.load_state_dict(s2m_saved)
S2MController封装了S2M模型为一个控制器。
s2m_controller = S2MController(s2m_model, num_objects, ignore_class=255)
app = QApplication(sys.argv)
ex = App(s2m_controller, images, num_objects)
__init__函数新增代码
self.undo_button = QPushButton('Undo')
self.undo_button.clicked.connect(self.on_undo)
self.reset_button = QPushButton('Reset Frame')
self.reset_button.clicked.connect(self.on_reset)
on_press函数关键代码讲解
用户鼠标点击主屏幕的图片时会触发事件,开始一次交互,交互需要创建一个交互对象,源码中有多种交互类,也就可以实例化多种交互对象,这里只介绍关于S2M模型的ScribbleInteraction交互。
在切换一个新的交互对象时,需要对上一次交互对象做complete_interaction处理。
if self.curr_interaction == 'Scribble':
# 第一次实例化交互对象或者上次交互对象不是Scribble类
if last_interaction is None or type(last_interaction) != ScribbleInteraction:
# 结束上一个交互对象
self.complete_interaction()
# Scribble类实例化
new_interaction = ScribbleInteraction(image, prev_hard_mask, (h, w),
self.s2m_controller, self.num_objects)
on_motion函数关键代码讲解
正确记录目标在轨迹图上的轨迹,push_point方法
if self.curr_interaction == 'Scribble':
# 确认当前的目标序号
obj = 0 if self.right_click else self.current_object
# 在运动轨迹图上更新运动轨迹
self.vis_map, self.vis_alpha = self.interaction.push_point(ex, ey, obj, (self.vis_map, self.vis_alpha))
on_release函数关键代码讲解
鼠标点击释放完成一次交互,保存交互产生的有用数据(end_path方法),并使用S2M模型预测(predict方法)当前目标的mask。
if self.curr_interaction == 'Scribble':
self.on_motion(event)
interaction.end_path()
self.interacted_mask = interaction.predict()
self.update_interacted_mask()
on_undo函数
回退操作,当前当前交互对象的交互记录回退(undo方法),还有不同交互对象交互记录回退(this_frame_interactions)。
其实源码中还有in_local_mode(本地模式),这里做了简化,只是普通模式。具体区别在于,本地模式中不同交互对象可以被回退到每一次交互的结果,普通模式中不同交互对象只允许被回退到最后一次交互的结果。
def on_undo(self):
if self.interaction is None: # 没有交互对象
if len(self.this_frame_interactions) > 1: # 当前图片上切换的交互对象次数
self.this_frame_interactions = self.this_frame_interactions[:-1] # 去除最近一个交互对象
self.interacted_mask = self.this_frame_interactions[-1].predict() # 获得最后一个交互对象记录产生的mask
else: # 只有或只剩一次交互记录 直接重置初始化
self.reset_this_interaction() # 重置初始化所有交互对象及其交互动作
self.interacted_mask = self.Pmasks[self.cursur].zero_() # 对第一次mask清零
else:
if self.interaction.can_undo(): # 判断当前操作是否支持回退
self.interacted_mask = self.interaction.undo() # 回退
else:
if len(self.this_frame_interactions) > 0: # 只剩一个历史交互对象
self.interaction = None # 没有交互对象了
self.interacted_mask = self.this_frame_interactions[-1].predict()
else:
self.reset_this_interaction()
self.interacted_mask = self.Pmasks[self.cursur].zero_()
on_reset函数
重置清除所有交互。
Pmasks和Pnp_masks是为方便单独拆分S2M模块的临时变量,源码中没有,后续加上Propagation模块会被该模块的变量代替。它们是同一变量的不同数据类型,一个是为了输入到S2M模型中参与预测,一个是为了方便主屏幕中显示,注意区分不要混淆。
def on_reset(self):
# 清除当前图片的mask的
self.Pmasks[self.cursur].zero_()
self.Pnp_masks[self.cursur].fill(0)
self.current_mask[self.cursur].fill(0)
# 重置初始化所有交互对象及其交互动作
self.reset_this_interaction()
self.show_current_frame()
reset_this_interaction函数
重置初始化所有交互对象及其交互动作
def reset_this_interaction(self):
self.complete_interaction()
self.clear_visualization()
self.interaction = None
self.this_frame_interactions = []
complete_interaction函数
源码中有多个交互类,因此可以实例化多个交互对象,一个交互对象有可以有多次交互行为,它本身会保留自己产生的交互记录,切换不同交互对象时,为了分清交互属于哪个对象产生,需要将交互对象也作为整体作历史记录。
这里直接讲解封装S2M模型的交互类ScribbleInteraction类。
def complete_interaction(self):
if self.interaction is not None:
self.clear_visualization()
self.interactions['annotated_frame'].append(self.cursur)
self.interactions['interact'][self.cursur].append(self.interaction)
self.this_frame_interactions.append(self.interaction)
self.interaction = None
ScribbleInteraction类关键代码讲解
在interact/interaction.py文件中。
__init__函数
ScribbleInteraction继承Interaction交互父类,ScribbleInteraction交互类的对象初始化。
def __init__(self, image, prev_mask, true_size, controller, num_objects):
"""
# image->当前帧/图
# prev_mask->上一次交互产生的mask
# true_size->帧/图宽高
# controller->控制器(分割的算法)
"""
super().__init__(image, prev_mask, true_size, controller)
self.K = num_objects # 目标个数
self.drawn_map = np.empty((self.h, self.w), dtype=np.uint8) # 记录鼠标的运动轨迹图
self.drawn_map.fill(255) # np.empty()返回一个随机元素的矩阵, fill全部设置为255
# background + k
self.curr_path = [[] for _ in range(self.K + 1)] # 对当前图片一次交互: 不同的目标分开记录不同的鼠标运动轨迹坐标
self.all_paths = [self.curr_path] # 对当前图片所有交互
self.size = 3
self.surplus_history = False
push_point函数
def push_point(self, x, y, k, vis=None):
if vis is not None:
vis_map, vis_alpha = vis
# 记录不同的目标运动轨迹坐标
selected = self.curr_path[k]
selected.append((x, y))
# 除了背景,至少需要一个目标
if len(selected) >= 2:
# 记录不同目标在轨迹图上正确的轨迹,用于S2M模型的输入
self.drawn_map = cv2.line(self.drawn_map,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
k, thickness=self.size)
# 记录不同目标在主屏幕图片上正确的轨迹,用于main_canvas的输出
if vis is not None:
if k == 0: # 表示当前目标k是背景
vis_map = cv2.line(vis_map,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
color_map[k], thickness=self.size)
else: # 表示当前目标k是其他前景目标
vis_map = cv2.line(vis_map,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
color_map[k], thickness=self.size)
vis_alpha = cv2.line(vis_alpha,
(int(round(selected[-2][0])), int(round(selected[-2][1]))),
(int(round(selected[-1][0])), int(round(selected[-1][1]))),
0.75, thickness=self.size)
# Optional vis return
if vis is not None:
return vis_map, vis_alpha
end_path函数
def end_path(self):
# 将上一次交互记录的不同目标运动坐标保存并清空
self.all_paths.append(self.curr_path)
self.curr_path = [[] for _ in range(self.K + 1)]
# 备份当前运动轨迹图到历史记录,下次交互会覆盖之前的运动轨迹图
self.history.append(self.drawn_map.copy())
self.surplus_history = True
predict函数
交互对象完成一次交互时,根据上一个交互对象产生的mask图并结合它交互一次所更新的运动轨迹图,在原始图片上预测出新的mask图
def predict(self):
# 使用S2MController预测mask掩膜
self.out_prob = self.controller.interact(self.image, self.prev_mask, self.drawn_map)
self.out_mask = aggregate_wbg(self.out_prob, keep_bg=True, hard=True)
return self.out_mask
can_undo函数
判断是否支持回退操作。源码中是包含多个交互类,可以产生不同的交互对象,当每个交互对象的历史记录只有一次时不支持回退,因为回退操作是用上一次交互的预测覆盖当前预测,只有一次交互说明就上一次不存在交互或者产生交互动作的不是该对象。
def can_undo(self):
return (len(self.history) > 0) and not (
self.surplus_history and (len(self.history) < 2))
undo函数
ScribbleInteraction类的history中记录的运动轨迹图,回退历史运动轨迹图、目标运动坐标,并重新预测结果。
def undo(self):
if self.surplus_history:
self.history.pop() # 将当前的运动轨迹图从历史中删除
self.surplus_history = False
# 用上一次运动轨迹图覆盖当前运动轨迹图
self.drawn_map = self.history.pop()
# 用上一次交互记录的不同目标运动坐标覆盖此次的交互记录
self.all_paths = self.all_paths[:-2]
# 重置目标运动坐标
self.curr_path = [[] for _ in range(self.K + 1)]
# 重新预测
return self.predict()
S2MController类关键代码讲解
在interact/s2m_controller.py文件中。
__init__函数
def __init__(self, s2m_net: S2M, num_objects, ignore_class, device='cuda:0'):
self.s2m_net = s2m_net
self.num_objects = num_objects # 设定目标个数
self.ignore_class = ignore_class # 图片中可忽略的像素值,通常是255
self.device = device
interact函数
S2M模型每次只能针对一个目标预测合理的mask,因此对多个目标需要进行多轮预测。
def interact(self, image, prev_mask, scr_mask):
image = image.to(self.device, non_blocking=True)
prev_mask = prev_mask.to(self.device, non_blocking=True) # 上一次交互得到的mask掩膜
h, w = image.shape[-2:]
# 针对不同目标的mask掩膜
unaggre_mask = torch.zeros((self.num_objects, 1, h, w), dtype=torch.float32, device=image.device)
for ki in range(1, self.num_objects+1): # 不从0开始 0代表背景 不需要mask
# 针对目标k的前景,即只包括k的部分
p_srb = (scr_mask == ki).astype(np.uint8)
# 针对目标k的背景,即不包括k的部分
n_srb = ((scr_mask != ki) * (scr_mask != self.ignore_class)).astype(np.uint8)
Rs = torch.from_numpy(np.stack([p_srb, n_srb], 0)).unsqueeze(0).float().to(image.device)
# Rs [1,2,h,w] 2表示前景和背景
Rs, _ = pad_divide_by(Rs, 16, Rs.shape[-2:])
inputs = torch.cat([image, (prev_mask == ki).float().unsqueeze(0), Rs], 1)
# inputs [1,6,h,w] 6=3+1+2
unaggre_mask[ki-1] = torch.sigmoid(self.s2m_net(inputs))
# unaggre_mask [k,1,h,w]
return unaggre_mask
总结
尽可能简单、详细的介绍MiVOS中S2M模块的代码。后续会讲解S2M的网络原理和代码(deeplabv3plus_resnet50)以及MiVOS的其它模块的代码。