1 TSM: paper and GitHub code
2 Reference: https://blog.csdn.net/qq_18644873/article/details/89305928
3 Most current work on action recognition revolves around how to better describe temporal features. Building on TSN, this paper proposes the Temporal Shift Module (TSM), which stays as efficient as a 2D network while reaching high accuracy. TSM borrows the idea of "Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions" (which studies replacing spatial convolution with shift operations; I have not fully worked through that paper yet) and applies the shift along the temporal dimension: in the offline setting, 1/8 of the channels are shifted one step toward the past and another 1/8 one step toward the future; in the online setting, 1/4 of the channels are all shifted from past to future, since future frames are unavailable. Placing the shift inside the residual branch both limits data-movement cost and improves accuracy.
The reason given in the paper is that shifting enlarges the temporal receptive field and thus allows more complex temporal modeling (For each inserted temporal shift module, the temporal receptive field will be enlarged by 2, as if running a convolution with the kernel size of 3 along the temporal dimension. Therefore, our TSM model has a very large temporal receptive field to conduct highly complicated temporal modeling.)
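For intuition, here is a minimal sketch of the online uni-directional variant. This is my own illustration, not the repo's implementation: instead of borrowing channels from the next frame, each frame reuses cached channels from the previous frame, so inference stays causal; fold_div=4 matches the 1/4 proportion quoted above, and the name shift_online is made up.

import torch

def shift_online(x, cache, fold_div=4):
    """Uni-directional temporal shift for online inference (a sketch).
    x: (1, c, h, w) features of the current frame; cache: the shifted
    channels saved from the previous frame (None on the first frame)."""
    c = x.size(1)
    fold = c // fold_div                 # fraction of channels to shift
    out = x.clone()
    if cache is not None:
        out[:, :fold] = cache            # reuse channels from the past frame
    new_cache = x[:, :fold].detach()     # save this frame's channels for the next step
    return out, new_cache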
4 For the shift itself, the first hyperparameter is how many channels to move; the final choice is to shift 1/8 of the channels left and 1/8 right. The chosen variant is residual TSM: in every residual block, conv1 is wrapped so that the features are shifted first and then passed through the original conv1.
5 For the added Non-local blocks, the placement follows the original Non-local paper: for ResNet-50, in the stage with 4 residual blocks (res3) a Non-local block is inserted after the 1st and 3rd blocks, and in the stage with 6 residual blocks (res4) one is inserted after the 1st, 3rd, and 5th blocks.
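To make the positions concrete, a sketch of the insertion follows; insert_non_local, nl_factory, and make_nl are illustrative names of mine, not the repo's API.

import torch.nn as nn

def insert_non_local(stage, positions, nl_factory):
    # rebuild a ResNet stage, appending a non-local block after the
    # residual blocks at the given 0-based positions (a sketch)
    blocks = []
    for i, block in enumerate(stage.children()):
        blocks.append(block)
        if i in positions:
            blocks.append(nl_factory(block))  # NL block matching the output channels
    return nn.Sequential(*blocks)

# res3 (4 blocks): after the 1st and 3rd  -> positions {0, 2}
# res4 (6 blocks): after the 1st, 3rd, 5th -> positions {0, 2, 4}
# resnet.layer2 = insert_non_local(resnet.layer2, {0, 2}, make_nl)
# resnet.layer3 = insert_non_local(resnet.layer3, {0, 2, 4}, make_nl)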
6 A few things worth noting in the code:
a. The shift operation mixes each frame's features with those of its neighboring frames, which enlarges the temporal receptive field; as noted above, the shift is placed inside the residual branch.
import torch
import torch.nn as nn


class TemporalShift(nn.Module):
    def __init__(self, net, n_segment=3, n_div=8, inplace=True):
        super(TemporalShift, self).__init__()
        self.net = net                # the wrapped module (e.g. conv1 of a residual block)
        self.n_segment = n_segment    # frames per clip
        self.fold_div = n_div         # shift c // n_div channels per direction
        self.inplace = inplace
        if inplace:
            print('=> Using in-place shift...')
        print('=> Using fold div: {}'.format(self.fold_div))

    def forward(self, x):
        # shift first, then run the wrapped module (residual TSM)
        x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
        return self.net(x)

    @staticmethod
    def shift(x, n_segment, fold_div=3, inplace=False):
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w)

        fold = c // fold_div
        if inplace:
            # InplaceShift is a custom autograd Function defined alongside this class in the repo
            out = InplaceShift.apply(x, fold)
        else:
            out = torch.zeros_like(x)
            out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left: frame t receives channels from frame t+1
            out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right: frame t receives channels from frame t-1
            out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # the remaining channels are not shifted

        return out.view(nt, c, h, w)
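A minimal usage sketch (my own, assuming a torchvision ResNet-50 and 8-frame clips): wrapping conv1 of every residual block reproduces the residual-TSM placement described above.

import torch
import torchvision

resnet = torchvision.models.resnet50()
for stage in [resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4]:
    for block in stage.children():
        # shift, then run the block's original conv1 (residual TSM)
        block.conv1 = TemporalShift(block.conv1, n_segment=8, n_div=8, inplace=False)

# frames are flattened into the batch: (batch * n_segment, C, H, W)
x = torch.randn(2 * 8, 3, 224, 224)
logits = resnet(x)   # (16, 1000): per-frame scores, averaged later as in TSN
print(logits.shape)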
b. Sparse sampling vs. dense sampling of the input frames:
import numpy as np
from numpy.random import randint


def _sample_indices(self, record):
    """
    :param record: VideoRecord
    :return: list
    """
    if self.dense_sample:  # i3d dense sample: 64-frame window, stride 64 // num_segments
        sample_pos = max(1, 1 + record.num_frames - 64)
        t_stride = 64 // self.num_segments
        start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
        offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
        return np.array(offsets) + 1
    else:  # normal sample: split the video into num_segments chunks, pick one frame per chunk
        average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
        if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)), average_duration) + \
                      randint(average_duration, size=self.num_segments)
        elif record.num_frames > self.num_segments:
            offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
        else:
            offsets = np.zeros((self.num_segments,))
        return offsets + 1
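A quick illustration of what the two branches return (a standalone sketch with made-up numbers: a 100-frame video and num_segments=8; new_length is ignored for brevity):

import numpy as np

num_frames, num_segments = 100, 8

# sparse (TSN-style): one random frame from each of 8 equal chunks
average_duration = num_frames // num_segments            # 12
sparse = np.arange(num_segments) * average_duration + \
         np.random.randint(average_duration, size=num_segments)
print(sparse + 1)   # e.g. [ 5 17 27 ...], spread over the whole video

# dense (I3D-style): 8 frames with stride 8 inside a random 64-frame window
start = np.random.randint(0, num_frames - 64)
dense = (np.arange(num_segments) * (64 // num_segments) + start) % num_frames
print(dense + 1)    # e.g. [23 31 39 ...], a contiguous local window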
c. Standard image augmentation: training uses GroupMultiScaleCrop, while testing scales the image first and then takes a center crop.
import random
from PIL import Image


class GroupMultiScaleCrop(object):
    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):
        im_size = img_group[0].size
        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
        ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
                         for img in crop_img_group]
        return ret_img_group

    def _sample_crop_size(self, im_size):
        image_w, image_h = im_size[0], im_size[1]
        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]

        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= self.max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))  # upper left
        ret.append((4 * w_step, 0))  # upper right
        ret.append((0, 4 * h_step))  # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))  # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center

            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower right quarter

        return ret
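A quick usage sketch with dummy frames; the key property is that every frame in a clip gets the identical crop window:

frames = [Image.new('RGB', (320, 240)) for _ in range(8)]   # dummy 8-frame clip
transform = GroupMultiScaleCrop(input_size=224)
cropped = transform(frames)
print(cropped[0].size, len(cropped))   # (224, 224) 8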
d. For pretrained models, partial BN is used: every BatchNorm layer except the first is frozen, so only the first BN layer keeps updating its statistics and parameters.
def train(self, mode=True):
    """
    Override the default train() to freeze the BN parameters
    :return:
    """
    super(TSN, self).train(mode)
    count = 0
    if self._enable_pbn and mode:
        print("Freezing BatchNorm2D except the first one.")
        for m in self.base_model.modules():
            if isinstance(m, nn.BatchNorm2d):
                count += 1
                if count >= (2 if self._enable_pbn else 1):
                    m.eval()
                    # shutdown update in frozen mode
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False
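A quick sanity-check sketch (model is assumed to be a TSN instance built with partial BN enabled; base_model is the attribute used in the repo):

model.train()
bns = [m for m in model.base_model.modules() if isinstance(m, nn.BatchNorm2d)]
print(bns[0].training, bns[0].weight.requires_grad)   # True True   (first BN still trains)
print(bns[1].training, bns[1].weight.requires_grad)   # False False (later BNs frozen)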
e. Different layers are trained with different learning rates (see the reference link):
def get_optim_policies(self):
    first_conv_weight = []
    first_conv_bias = []
    normal_weight = []
    normal_bias = []
    lr5_weight = []
    lr10_bias = []
    bn = []
    custom_ops = []

    conv_cnt = 0
    bn_cnt = 0
    for m in self.modules():
        if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d):
            ps = list(m.parameters())
            conv_cnt += 1
            if conv_cnt == 1:
                first_conv_weight.append(ps[0])
                if len(ps) == 2:
                    first_conv_bias.append(ps[1])
            else:
                normal_weight.append(ps[0])
                if len(ps) == 2:
                    normal_bias.append(ps[1])
        elif isinstance(m, torch.nn.Linear):
            ps = list(m.parameters())
            if self.fc_lr5:
                lr5_weight.append(ps[0])
            else:
                normal_weight.append(ps[0])
            if len(ps) == 2:
                if self.fc_lr5:
                    lr10_bias.append(ps[1])
                else:
                    normal_bias.append(ps[1])
        elif isinstance(m, torch.nn.BatchNorm2d):
            bn_cnt += 1
            # later BN's are frozen
            if not self._enable_pbn or bn_cnt == 1:
                bn.extend(list(m.parameters()))
        elif isinstance(m, torch.nn.BatchNorm3d):
            bn_cnt += 1
            # later BN's are frozen
            if not self._enable_pbn or bn_cnt == 1:
                bn.extend(list(m.parameters()))
        elif len(m._modules) == 0:
            if len(list(m.parameters())) > 0:
                raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m)))

    return [
        {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1,
         'name': "first_conv_weight"},
        {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0,
         'name': "first_conv_bias"},
        {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1,
         'name': "normal_weight"},
        {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0,
         'name': "normal_bias"},
        {'params': bn, 'lr_mult': 1, 'decay_mult': 0,
         'name': "BN scale/shift"},
        {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1,
         'name': "custom_ops"},
        # for fc
        {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1,
         'name': "lr5_weight"},
        {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0,
         'name': "lr10_bias"},
    ]
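For completeness, a sketch of how these policy groups are consumed, mirroring the pattern in the repo's main.py (where adjust_learning_rate rescales each group by lr_mult / decay_mult); base_lr and base_wd here are example values:

import torch

policies = model.get_optim_policies()
base_lr, base_wd = 0.01, 5e-4
for group in policies:
    group['lr'] = base_lr * group['lr_mult']                # per-group learning rate
    group['weight_decay'] = base_wd * group['decay_mult']   # per-group weight decay
optimizer = torch.optim.SGD(policies, lr=base_lr, momentum=0.9)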
TRN Model
1 TRN's backbone also follows TSN. Using the code as an example: the backbone again treats 8 frames as one clip and extracts an 8x256 feature. The multi-scale TRN module then builds sub-modules for frame-subset sizes [8, 7, 6, 5, 4, 3, 2]; for size 2, for instance, 2 of the 8 frames are sampled at random (keeping temporal order) as that sub-module's input. Each sub-module applies two fully connected layers (first mapping the concatenated features to 256 channels, then to num_class), so with, say, 10 classes each sub-module yields a batch x 10 prediction; the predictions of all sub-modules are summed to give the final classification. A minimal sketch follows after this list.
2 This module does not scale well, however: for a large number of input frames, say 64, there are far too many sub-modules (and frame subsets) to train.
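A minimal sketch of the multi-scale relation idea, simplified from TRNmodule.py in the TRN repo: the real code samples several subsets per scale and averages them, while this sketch samples just one, and the class and variable names are my own.

import random
import torch
import torch.nn as nn

class MultiScaleTRNSketch(nn.Module):
    """Simplified multi-scale TRN head: one relation MLP per scale,
    predictions summed over scales (a sketch, not the original code)."""
    def __init__(self, img_feature_dim=256, num_frames=8, num_class=10):
        super().__init__()
        self.scales = list(range(num_frames, 1, -1))   # [8, 7, ..., 2]
        self.num_frames = num_frames
        self.fc_fusion = nn.ModuleList([
            nn.Sequential(
                nn.ReLU(),
                nn.Linear(scale * img_feature_dim, 256),  # first FC: -> 256
                nn.ReLU(),
                nn.Linear(256, num_class),                # second FC: -> num_class
            )
            for scale in self.scales
        ])

    def forward(self, x):                    # x: (batch, num_frames, img_feature_dim)
        logits = 0
        for scale, fc in zip(self.scales, self.fc_fusion):
            # sample one ordered subset of `scale` frames (the real code averages several)
            idx = sorted(random.sample(range(self.num_frames), scale))
            feat = x[:, idx, :].flatten(1)   # (batch, scale * img_feature_dim)
            logits = logits + fc(feat)       # sum the sub-module predictions
        return logits                        # (batch, num_class)

# usage sketch: 8-frame clip features -> class logits
feats = torch.randn(4, 8, 256)
head = MultiScaleTRNSketch()
print(head(feats).shape)    # torch.Size([4, 10])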