基于transformer的视频实例分割网络VisTR
可执行案例参考:VisTR可执行案例
背景介绍
实例分割是计算机视觉中的基础问题之一。在静态图中的实例分割已经有很多的研究了,但对视频的实例分割研究相对较少。在实际应用场景上来说,像是自动驾驶,摄像头接受到的都是视频而非图片,因此研究对视频建模的模型有重要的意义。
本文是由美团无人车配送团队在CVPR2021上发表的一篇Oral论文:End-to-End Video Instance Segmentation with Transformers的介绍和在mindspore复现中部分API的介绍。
图像的实例分割指的是对静态图像中感兴趣的物体进行检测和分割的任务。视频是包含多帧图像的信息载体,相对于静态图像来说,视频的信息更为丰富,因而建模也更为复杂。不同于静态图像仅含有空间的信息,视频同时含有时间维度的信息,因而更接近对真实世界的刻画。其中,视频的实例分割指的是对视频中感兴趣的物体进行检测、分割和跟踪的任务。如图1所示,第一行为给定视频的多帧图像序列,第二行为视频实例分割的结果,其中相同颜色对应同一个实例。视频实例分割不光要对单帧图像中的物体进行检测和分割,而且要在多帧的维度下找到每个物体的对应关系,即对其进行关联和跟踪。
VisTR算法介绍
首先,相较于单帧图像,视频含有关于每个实例更为完备和丰富的信息,比如不同实例的轨迹和运动模态,这些信息能够帮助克服单帧实例分割任务中一些比较困难的问题,比如外观相似、物体邻近或者存在遮挡的情形等。另一方面,多帧所提供的关于单个实例更好的特征表示也有助于模型对物体进行更好的跟踪。因此,作者想要实现一个端到端对视频实例目标进行建模的框架。作者认为可以借鉴自然语言处理任务的思想,把视频实例分割建模为序列到序列的任务,即给定多帧图像作为输入,直接输出多帧的分割mask序列。其次,分割本身是像素特征之间相似度的学习,而跟踪本质是实例特征之间相似度的学习,因此理论上他们可以统一到同一个相似度学习的框架之下。基于以上的思考,作者选取了一个同时能够进行序列的建模和相似度学习的模型,即自然语言处理中的transformers模型。transformers本身可以用于seq2seq的任务,即给定一个序列,可以输入一个序列。并且该模型十分擅长对长序列进行建模,因此非常适合应用于视频领域对多帧序列的时序信息进行建模。其次,transformers的核心机制,自注意力模块(Self-attention) ,可以基于两两之间的相似度来进行特征的学习和更新,使得将像素特征之间相似度以及实例特征之间相似度统一在一个框架内实现成为可能。以上的特性使得transformers成为VIS任务的恰当选择。另外transformers已经有被应用于计算机视觉中进行目标检测的实践DETR。因此作者基于transformers设计了视频实例分割(VIS)的模型VisTR。
VisTR整体框架如图2所示。图中最左边表示输入的多帧原始图像序列(以三帧为例),右边表示输出的实例预测序列,其中相同形状对应同一帧图像的输出,相同颜色对应同一个物体实例的输出。给定多帧图像序列,首先利用卷积神经网络(CNN)进行初始图像特征的提取,然后将多帧的特征结合作为特征序列输入transformer进行建模,实现序列的输入和输出。不难看出,首先,VisTR是一个端到端的模型,即同时对多帧数据进行建模。建模的方式即:将其变为一个seq2seq的任务,输入多帧图像序列,模型可以直接输出预测的实例序列。虽然在时序维度多帧的输入和输出是有序的,但是单帧输入的实例的序列在初始状态下是无序的,这样仍然无法实现实例的跟踪关联,因此作者强制使得每帧图像输出的实例的顺序是一致的(用图中同一形状的符号有着相同的颜色变化顺序表示),这样只要找到对应位置的输出,便可自然而然实现同一实例的关联,无需任何后处理操作。为了实现此目标,需要对属于同一个实例位置处的特征进行序列维度的建模。针对性地,为了实现序列级别的监督,作者提出了Instance Sequence Matching的模块。同时为了实现序列级别的分割,作者提出了Instance Sequence Segmentation的模块。端到端的建模将视频的空间和时间特征当做一个整体,可以从全局的角度学习整个视频的信息,同时transformer所建模的密集特征序列又能够较好的保留细节的信息。
VisTR的详细网络结构如图3所示,以下是对网络的各个组成部分的介绍:
Backbone: 主要用于初始图像特征的提取。针对序列的每一帧输入图像,首先利用CNN的backbone进行初始图像特征的提取,提取的多帧图像特征沿时序和空间维度序列化为多帧特征序列。由于序列化的过程损失了像素原始的空间和时序信息,而检测和分割的任务对于位置信息十分敏感,因此作者将其原始的空间和水平位置进行编码,作为positional encoding叠加到提取的序列特征上,以保持原有的位置信息。positional encoding的方式遵照Image Transformer[7]的方式,只是将二维的原始位置信息变为了三维的位置信息,这部分在论文中有详细的说明。
Encoder: 主要用于对多帧特征序列进行整体的建模和更新。输入前面的多帧特征序列,transformer的encoder模块利用self-attention模块,通过点和点之间相似度的学习,进行序列中所有特征的融合和更新。该模块通过对时序和空间特征的整体建模,能够对属于同一个实例的特征进行更好的学习和增强。
Decoder: 主要用于解码输出预测的实例特征序列。由于encoder输入decoder的是密集的像素特征序列,为了解码出稀疏的实例特征,作者参考DETR的方式,引入instance query进行代表性的实例特征的解码。Instance query是网络自身学习的embedding参数,用于和密集的输入特征序列进行attention运算选取能够代表每个实例的特征。以处理3帧图像,每帧图像预测4个物体为例,模型一共需要12个instance query,用于解码12个实例预测。和前面的表示一致,用同样的形状表示对应同一帧图像的预测,同样的颜色表示同一个物体实例在不同帧的预测。通过这种方式,作者可以构造出每个实例的预测序列,对应为图3中的instance 1…instance 4,后续过程中模型都将单个物体实例的序列看作整体进行处理。
Instance Sequence Matching: 主要用于对输入的预测结果进行序列级别的匹配和监督。前面已经介绍了从序列的图像输入到序列的实例预测的过程。但是预测序列的顺序其实是基于一个假设的,即在帧的维度保持帧的输入顺序,而在每帧的预测中,不同实例的输出顺序保持一致。帧的顺序比较容易保持,只要控制输入和输出的顺序一致即可,但是不同帧内部实例的顺序其实是没有保证的,因此作者需要设计专门的监督模块来维持这个顺序。在通用目标检测之中,在每个位置点会有它对应的anchor,因此对应每个位置点的ground truth监督是分配好的,而在作者的模型中,实际上是没有anchor和位置的显式信息,因此对于每个输入点作者没有现成的属于哪个实例的监督信息。为了找到这个监督,并且直接在序列维度进行监督,作者提出了Instance Sequence Matching的模块,这个模块将每个实例的预测序列和标注数据中每个实例的ground truth序列进行二分匹配,利用匈牙利匹配的方式找到每个预测最近的标注数据,作为它的groud truth进行监督,进行后面的loss计算和学习。
Instance Sequence Segmentation: 主要用于获取最终的分割结果序列。前面已经介绍了seq2seq的序列预测过程,作者的模型已经能够完成序列的预测和跟踪关联。但是到目前为止,作者为每个实例找到的只是一个代表性的特征向量,而最终要解决的是分割的任务,如何将这个特征向量变为最终的mask序列,就是instance sequence segmentation模块要解决的问题。前面已经提到,实例分割本质是像素相似度的学习,因此作者初始计算mask的方式就是利用实例的预测和encode之后的特征图计算self-attention相似度,将得到的相似度图作为这个实例对应帧的初始attention mask特征。为了更好的利用时序的信息,作者将属于同一个实例的多帧的attention mask 作为mask序列输入3D卷积模块进行分割,直接得到最终的分割序列。这种方式通过利用多帧同一实例的特征对单帧的分割结果进行增强,可以最大化的发挥时序的优势。
VisTR损失函数
根据前面的描述,网络学习中需要计算损失的主要有两个地方,一个是Instance Sequence Matching阶段的匹配过程,一个是找到监督之后最终整个网络的损失函数计算过程。
Instance Sequence Matching过程的计算公式如式1所示:由于matching阶段只是用于寻找监督,而计算mask之间的距离运算比较密集,因此在此阶段作者只考虑box和预测的类别c两个因素。第一行中的yi表示对应第i个实例的ground truth序列,其中c表示类别,b表示bounding box,T表示帧数,即T帧该实例对应的类别和boundingbox序列。第二行和第三行分别表示预测序列的结果,其中p表示在ci这个类别的预测的概率,b表示预测的boundingbox。序列之间距离的运算是通过两个序列对应位置的值两两之间计算损失函数得到的,图中用Lmatch表示,对于每个预测的序列,找到Lmatch最低那个ground truth序列作为它的监督。根据对应的监督信息,就可以计算整个网络的损失函数。
由于作者的方法是将分类、检测、分割和跟踪做到一个端到端网络里,因此最终的损失函数也同时包含类别、boundingbox和mask三个方面,跟踪通过直接对序列计算损失函数体现。式2表示分割的损失函数,得到了对应的监督结果之后,作者计算对应序列之间的Dice loss和Focal loss作为mask的损失函数。
最终的损失函数如式3所示,为同时包含分类(类别概率)、检测(bounding box)以及分割(mask)的序列损失函数之和。
实验结果
为了验证方法的效果,作者在广泛使用的视频实例分割数据集YouTube-VIS上进行了实验,该数据集包含2238个训练视频,302个验证视频以及343个测试视频,以及40个物体类别。模型的评估标准包含AP和AR,以视频维度多帧mask之间的IOU作为阈值。
API调用说明
通过继承Dataset类创建的Ytvos类来加载YouTube-Vis数据集,其中对读取数据的处理使用到了pycocotools接口。数据处理主要包括了RandomHorizontalFlip,RandomResize,PhotometricDistort,RandomSizeCrop和Normalize,其中PhotometricDistort是指光度失真,主要包括了调整图像亮度,色度,对比度,饱和度以及加入噪点。
VisTR模型根据选取的backbone不同,可以分为VisTR-r50和VisTR-r101。VisTR模型主要由resnet backbone,encoder,decoder和,instance matching和instance segmentation组成。
backbone就是由一个普通的resnet组成,在该模型中只用到了resnet50和resnet101。
encoder和decoder构造类似,每个encoder都是由六层TransformerEncoderLayer组成,其中包括了MultiheadAttention类,两个Dense层和两个LayerNorm。每个decoder也是由六层TransformerDecoderLayer组成,其中包括了两个MultiheadAttention类,两个Dense层和三个LayerNorm。
instance matching部分主要就是HungarianMatcher类组成,其中有Multiou类和Hungarian类,前者主要来计算Iou,后者则是匈牙利匹配算法的实现。
instance segmentation部分分为两个阶段,image level和instance level。前者主要使用到了box attention,具体包括了MHAttentionMsp类(2d 注意力模块)和MaskHeadSmallConv类。MaskHeadSmallConv类中包括了一个dcn模块,主要通过ops.Custom自定义算子构成了TorchDeformConv类来实现。后者则是通过四个三维卷积来实现。
loss部分,主要定义了SetCriterion类来计算类别,bbox和mask loss。其中还单独定义了Dice Loss类和SigmoidFocalLoss类
class VistrCom(nn.Cell):
"""
Vistr Architecture.
"""
def __init__(self,
name: str = 'ResNet50',
train_embeding: bool = True,
num_queries: int = 360,
num_pos_feats: int = 64,
num_frames: int = 36,
temperature: int = 10000,
normalize: bool = True,
scale: float = None,
hidden_dim: int = 384,
d_model: int = 384,
nhead: int = 8,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
dim_feedforward: int = 2048,
dropout: int = 0.1,
activation: str = "relu",
normalize_before: bool = False,
return_intermediate_dec: bool = True,
aux_loss: bool = True,
num_class: int = 41):
super().__init__()
# input constant used in construct
self.num_queries = num_queries
self.num_frames = num_frames
self.aux_loss = aux_loss
num_pos_feats = hidden_dim // 3
self.normalize = normalize
dim_t = [temperature ** (2 * (i // 2) / num_pos_feats) for i in range(num_pos_feats)]
self.dim_t = Tensor(dim_t, msp.float32)
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
# embed init
if name == "ResNet101":
embeding = resnet.ResNet101()
if name == "ResNet50":
embeding = resnet.ResNet50()
self.query_embed = nn.Embedding(num_queries, hidden_dim)
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
self.input_proj = nn.Conv2d(num_channels, hidden_dim, kernel_size=1,
pad_mode='valid', has_bias=True)
for params in embeding.get_parameters():
if (not train_embeding or 'layer2' not in params.name and
'layer3' not in params.name and
'layer4' not in params.name):
params.requires_grad = False
if 'beta' in params.name:
params.requires_grad = False
if 'gamma' in params.name:
params.requires_grad = False
self.embed1 = nn.SequentialCell([embeding.conv1,
embeding.pad,
embeding.max_pool,
embeding.layer1])
self.embed2 = embeding.layer2
self.embed3 = embeding.layer3
self.embed4 = embeding.layer4
hidden_dim = d_model
# encoder
encoder_layers = nn.CellList([
TransformerEncoderLayer(d_model,
nhead,
dim_feedforward,
dropout,
activation,
normalize_before)
for _ in range(num_encoder_layers)
])
encoder_norm = nn.LayerNorm([d_model], epsilon=1e-5) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layers, encoder_norm)
# decoder
decoder_layers = nn.CellList([TransformerDecoderLayer(d_model,
nhead,
dim_feedforward,
dropout,
activation,
normalize_before)
for _ in range(num_decoder_layers)])
decoder_norm = nn.LayerNorm([d_model], epsilon=1e-5)
self.decoder = TransformerDecoder(decoder_layers,
decoder_norm,
return_intermediate=return_intermediate_dec)
# embed
self.class_embed = nn.Dense(hidden_dim,
num_class+1,
weight_init=HeUniform(),
bias_init=UniformBias([num_class+1, hidden_dim]))
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
# head init
hidden_dim = d_model
self.bbox_attention = mh_attention_map.MHAttentionMsp(hidden_dim,
hidden_dim,
nhead,
dropout=0.0)
self.mask_head = maskheadsmallconv.MaskHeadSmallConv(hidden_dim + nhead,
[1024, 512, 256],
hidden_dim)
self.insmask_head = nn.SequentialCell([nn.Conv3d(24, 12, 5,
pad_mode='pad',
has_bias=True,
padding=2),
GroupNorm3d(4, 12),
nn.ReLU(),
nn.Conv3d(12, 12, 5,
pad_mode='pad',
has_bias=True,
padding=2),
GroupNorm3d(4, 12),
nn.ReLU(),
nn.Conv3d(12, 12, 5,
pad_mode='pad',
has_bias=True,
padding=2),
GroupNorm3d(4, 12),
nn.ReLU(),
nn.Conv3d(12, 1, 1,
pad_mode='pad',
has_bias=True)
])
# ops init
self.cast = ops.Cast()
self.cumsum = ops.CumSum()
self.reshape = ops.Reshape()
self.stack = ops.Stack(axis=5)
self.concat4 = ops.Concat(axis=4)
self.transpose = ops.Transpose()
self.zeros = ops.Zeros()
self.sin = ops.Sin()
self.cos = ops.Cos()
self.fill = ops.Fill()
self.concat0 = ops.Concat(axis=0)
self.concat1 = ops.Concat(axis=1)
self.squeeze = ops.Squeeze(0)
self.concat_1 = ops.Concat(axis=-1)
self.tile = ops.Tile()
self.expand_dim = ops.ExpandDims()
self.zeros_like = ops.ZerosLike()
self.sigmoid = ops.Sigmoid()
def construct(self, x):
"""embed construct"""
x = x[0]
mask = self.zeros((x.shape[0], x.shape[2], x.shape[3]), msp.float32)
src_list = []
pos_list = []
features = []
src = self.embed1(x)
src_list.append(src)
src = self.embed2(src)
src_list.append(src)
src = self.embed3(src)
src_list.append(src)
src = self.embed4(src)
src_list.append(src)
for src in src_list:
interpolate = P.ResizeNearestNeighbor(src.shape[-2:])
ms = interpolate(mask[None])
ms = self.cast(ms, msp.bool_)[0]
features.append((src, ms))
features_pos = self.PositionEmbeddingSine(ms)
pos_list.append(features_pos)
src, ms = features[-1]
src_proj = self.input_proj(src)
src_copy = src_proj.copy()
n, c, h, w = src_proj.shape
src_proj = self.reshape(src_proj, (n//self.num_frames, self.num_frames, c, h, w))
src_proj = self.transpose(src_proj, (0, 2, 1, 3, 4))
src_proj = self.reshape(src_proj,
(src_proj.shape[0],
src_proj.shape[1],
src_proj.shape[2],
src_proj.shape[3]*src_proj.shape[4]))
ms = self.reshape(ms, (n//self.num_frames, self.num_frames, h*w))
pos_embed = self.transpose(pos_list[-1], (0, 2, 1, 3, 4))
pos_embed = self.reshape(pos_embed,
(pos_embed.shape[0],
pos_embed.shape[1],
pos_embed.shape[2],
pos_embed.shape[3]*pos_embed.shape[4]))
query_embed = self.query_embed.embedding_table
# backbone construct
bs, c, h, w = src_proj.shape
src_proj = self.transpose(self.reshape(src_proj, (bs, c, h * w)), (2, 0, 1))
pos_embed = self.transpose(self.reshape(pos_embed, (bs, c, h * w)), (2, 0, 1))
query_embed = self.tile(self.expand_dim(query_embed, 1), (1, bs, 1))
mask = self.reshape(ms, (bs, h * w))
tgt = self.zeros_like(query_embed)
memory = self.encoder(src_proj,
src_key_padding_mask=mask,
pos=pos_embed)
hs = self.decoder(tgt,
memory,
memory_key_padding_mask=mask,
pos=pos_embed,
query_pos=query_embed)
hs_t = self.transpose(hs, (0, 2, 1, 3))
memory = self.reshape(self.transpose(memory, (1, 2, 0)), (bs, c, h, w))
outputs_class = self.class_embed(hs_t)
outputs_coord = self.sigmoid(self.bbox_embed(hs_t))
if self.aux_loss:
output = self.concat_1([outputs_class, outputs_coord])
else:
output = self.concat_1([outputs_class[-1, ...], outputs_coord[-1, ...]])
# head construct
_, c, s_h, s_w = src_copy.shape
src = []
bs_f = features[-1][0].shape[0]//self.num_frames
for i in range(3):
_, c_f, h, w = features[i][0].shape
feature = self.reshape(features[i][0], (bs_f, self.num_frames, c_f, h, w))
src.append(feature)
n_f = self.num_queries//self.num_frames
outputs_seg_masks = []
# image level processing using box attention
for i in range(self.num_frames):
hs_f = hs_t[-1][:, i*n_f:(i+1)*n_f, :]
memory_f = self.reshape(memory[:, :, i, :], (bs_f, c, s_h, s_w))
mask_f = self.reshape(ms[:, i, :], (bs_f, s_h, s_w))
bbox_mask_f = self.bbox_attention(hs_f, memory_f, mask=mask_f)
seg_masks_f = self.mask_head(memory_f,
bbox_mask_f,
[src[2][:, i], src[1][:, i], src[0][:, i]])
outputs_seg_masks_f = self.reshape(seg_masks_f,
(bs_f,
n_f,
24,
seg_masks_f.shape[-2],
seg_masks_f.shape[-1]))
outputs_seg_masks.append(outputs_seg_masks_f)
frame_masks = self.concat0(outputs_seg_masks)
outputs_seg_masks = []
# instance level processing using 3D convolution
for i in range(frame_masks.shape[1]):
mask_ins = self.expand_dim(frame_masks[:, i], 0)
mask_ins = self.transpose(mask_ins, (0, 2, 1, 3, 4))
outputs_seg_masks.append(self.insmask_head(mask_ins))
outputs_seg_masks = self.transpose(self.squeeze(self.concat1(outputs_seg_masks)),
(1, 0, 2, 3))
outputs_seg_masks = self.reshape(outputs_seg_masks,
(1, self.num_queries,
outputs_seg_masks.shape[-2],
outputs_seg_masks.shape[-1]))
return output, outputs_seg_masks
def PositionEmbeddingSine(self, mask):
"""Sine encoding
"""
n, h, w = mask.shape
mask = self.reshape(mask, (n//self.num_frames, self.num_frames, h, w))
not_mask = ~mask
not_mask = self.cast(not_mask, msp.float32)
z_embed = self.cumsum(not_mask, 1)
y_embed = self.cumsum(not_mask, 2)
x_embed = self.cumsum(not_mask, 3)
if self.normalize:
eps = 1e-6
z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale
y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale
pos_x = x_embed[:, :, :, :, None] / self.dim_t
pos_y = y_embed[:, :, :, :, None] / self.dim_t
pos_z = z_embed[:, :, :, :, None] / self.dim_t
pos_x = self.stack([self.sin(pos_x[:, :, :, :, 0::2]),
self.cos(pos_x[:, :, :, :, 1::2])])
pos_x = self.reshape(pos_x, (pos_x.shape[0], pos_x.shape[1], pos_x.shape[2],
pos_x.shape[3], pos_x.shape[4]*pos_x.shape[5]))
pos_y = self.stack([self.sin(pos_y[:, :, :, :, 0::2]),
self.cos(pos_y[:, :, :, :, 1::2])])
pos_y = self.reshape(pos_y, (pos_y.shape[0], pos_y.shape[1], pos_y.shape[2],
pos_y.shape[3], pos_y.shape[4]*pos_y.shape[5]))
pos_z = self.stack([self.sin(pos_z[:, :, :, :, 0::2]),
self.cos(pos_z[:, :, :, :, 1::2])])
pos_z = self.reshape(pos_z, (pos_z.shape[0], pos_z.shape[1], pos_z.shape[2],
pos_z.shape[3], pos_z.shape[4]*pos_z.shape[5]))
pos = self.concat4((pos_z, pos_y, pos_x))
pos = self.transpose(pos, (0, 1, 4, 2, 3))
return pos