ARN理解与MindSpore框架下的实现

ARN理解与MindSpore框架下的实现

总体介绍

  1. ARN是用于少样本动作分类的网络,实现少样本分类的关键就在于相似度学习模块的构建。这一模块采用relation network,自动学习度量函数,输入查询集(query set)样本的二维特征图和支持集(support set)中每一类样本的二维特征图沿着channel维度作级联后得到的特征对,即可输出每一个样本对的相似度得分。具体的,这一相似度度量模块主要由四层Conv3d构成,进行线性变换。

  1. ARN包含embedding、backbone、neck、head几部分,大致可以归纳为如下几点:

  • embedding部分:主要是一个类似于C3D结构的特征提取器,通过四层Conv3d进行线性变换。这一部分支持对C3D主干网络的复用,也支持直接使用Unit3d类(包含Conv3d、norm、activation)进行构建,仅需在入参is_c3d中指定True或False即可。

  • backbone部分:主要使用空间注意力模块对上一步embedding提取得到的特征进行增强,后将特征维度变换为C * N便于后续处理。空间注意力模块主要由三层Conv3d构成,进行线性变换。

  • neck部分:主要进行带有功率归一化(power normalization)的二阶池化(second-order pooling)处理,有助于实现序列无关以及少样本条件下的性能可靠度。后将特征维度变换为C * C便于特征图的拼接和特征对的构建,以实现最后基于度量的分类。

  • head部分:主要相似度学习模块,采用relation network实现,自动学习度量函数。具体如上第一点中所述。

上图为ARN网络模型的总体框架图。

细节介绍

1. 特征编码器

特征提取对任何一个深度神经网络模型都是至关重要的一步,特征提取的准确性和针对性直接影响了模型对输入数据的采样和感知能力,一定程度上决定了模型表现的“上限”。ARN采用基于C3D结构的特征编码器Conv-4-64,即4层卷积、64输出通道,是对C3D的简化和优化。输入视频数据 V (如20帧)经过编码器特征提取操作 f 后输出维度为 C*T*H*W 的特征向量 Φ ,具体如下公式所示:

具体而言,编码器由四层卷积块组成,前两层卷积块计算后分别插入一层kernel和stride均为2*2*2的最大池化计算使输入数据在 T H W 维度上均缩小为原来的1/4。其中,每个卷积块依此包含卷积核大小为3*3*3且输出通道为64的卷积计算、批标准化(Batch Normalization)处理和ReLU激活函数处理。

2. 自监督学习模块

将自监督学习模块应用于视频少样本分类,主要是为了鼓励基本网络模型学习附加的辅助任务,从而促使模型学到通用、稳定的特征提取方法,以更好地理解输入图像或视频等数据中蕴含的语义等信息。ARN的工作同时在时域和空域上引入自监督学习任务,在基于数据增强的设计思路指引下创建辅助任务(pretext)的学习。

具体的,在时域上,不同于以往工作对每一帧进行乱序操作,是对不重叠的、固定长度的帧块进行乱序操作,即随机打乱动作块在时域上的先后顺序。这样做的好处在于更不容易拆解、打破判别性强的动作块所在相关联的帧,更好地保留了连贯的动作行为信息。在空域上,可以对视频中每一帧图像进行拼图或旋转的增强操作,如拼图操作将每一帧图像分为不重叠的四个区块再随机打乱,旋转操作则是将每一帧图像随机旋转0°、90°、180°或270°。以旋转为例,设使用判别器 d 的自监督学习的目标函数为Lrot, D 为判别器 d 的参数,则该自监督任务如下公式所示:

3. 注意力机制

ARN中通过注意力机制来来实现为不同贡献度的动作块分配不同关注度,即不同的注意力权重向量。具体的,考虑到计算资源的限制,时空注意力单元被分为时间和空间两个模块,分别对每一帧图像的不同部分和不同时间段的动作块的不同贡献度作出描述,如下图所示:

将维度为 C*T*H*W 的特征向量 输入到时间注意力单元,得到 1*T*1*1 维度的注意力向量,输入到空间注意力单元则得到 1*1*H*W 维度的注意力向量。时间和空间注意力作用影响程度通过αt、αs两个参数来进行调控。设 ts 分别表示时间、空间注意力单元, Φ* 表示注意力特征图,则注意力机制的作用如下公式所示:

4. 基于增强和对齐的注意力机制

基于增强和对齐的注意力机制(Augmentation-guided attention by alignment)是ARN工作中的主要创新点,模型实现的序列无关效果也主要得益于这个机制。如下图中所示,这一机制主要通过基于数据增强的自监督任务来加强时间和空间的注意力。具体的,通过将1)原始数据经过特征编码得到的表征先经过注意力单元的处理、再进行增强操作(如拼图或旋转)得到的“增强了的注意力向量”,2)增强后的视频数据经过特征编码得到的表征经由注意力单元处理后进一步得到注意力向量(即“增强数据的注意力向量”),这两个向量进行对齐学习来实施。这样打乱和对齐的做法旨在通过数据增强建立动作块和特定注意力权重之间的映射,使注意力机制能够适应输入序列的变化,调节、拓宽模型感受的范围,从而达到序列无关的提升效果。

由此,Augmentation-guided attention by alignment的损失函数可依据以下公式

定义为公式:

5. 特征图聚合

ARN中的聚合阶段采用带有功率归一化(Power Normalization, PN)的二阶池化(Second-order Pooling)操作。二阶池化通过算子 g 对特征向量进行降维操作得到特征图 ψ

具体的,经过特殊编码器提取特征和注意力单元增强特征后得到特征向量维度为 C*T*H*W ,在将这些向量输入到二阶池化函数之前首先对其进行变形,成为 C*N ,其中 N = T*H*W ,便于后续池化计算。经过二阶池化的处理(核心为向量乘上本身的转置向量,即一个整形操作)输出得到维度为 C*C 的特征向量,即二维特征图,在实现降维的同时也通过 NT*H*W 维度上的信息,体现了对输入数据的序列无关性,且对于输入视频帧长也不具约束。

特别的,此处二阶池化操作还带有Power Normalization。从直观上看,加上PN操作的好处在于可以降低频繁出现的视觉元素对结果的贡献,而增加出现频率较低的视觉元素对结果的贡献,这也更加符合人类观察学习的特点。PN关注图像中的视觉元素是否共同出现,而非统计共同出现了几次,这大大降低了后续比较器需要记忆的数据量,而这也正是学习分类任务的自然特性之一。所以,PN引入二阶池化很适合少样本任务的学习。PN的计算如下公式所示:

6. 少样本学习——关系网络

ARN中的比较器采用关系网络(Relation Network, RN)的结构,这一部分也是该网络结构适用于少样本学习的直接体现。在ARN网络结构中,比较器RN的输入为关系描述(relation descriptors)向量,即一对对的特征对(Feature Pair)。这些特征对由二阶池化后得到的二维特征图经过在通道维度上的进行级联拼接得到,维度为 2*C*C 。具体的,将拼接操作记为 θ ,是指查询集中的特征图和支持集中每一类别的特征图分别做通道维度上的级联的过程。特别的,如果不是n-way 1shot任务,即支持集中每一个类别包含不止一个样本,则先对这些样本在通道维度上做平均(或最大)级联,得到与各个样本具有相同维度 C*C 的特征图代表这一类别的图像特征,再与查询集中样本做上述级联。

这些拼接好的特征对输入到RN(记为 r )中进行相似度的比较,输出相似度得分高的类别作为预测类别从而实现对视频的分类,如下公式所示:

RN采用均方误差(Mean Square Error, MSE)作为损失函数对网络进行训练,如下公式所示:

基于MindSpore框架的ARN模型复现

数据集

案例使用数据集:UCF101

  • UCF101数据集来自于Youtube,包含13320个动作视频段和101个动作类别。在这项工作中,数据集中的70个类被用于训练,10个类用于验证,剩余21个类用于测试。

.
└─ucf101                                    // contains 101 file folder
  ├── ApplyEyeMakeup                        // contains 145 videos
  │   ├── v_ApplyEyeMakeup_g01_c01.avi      // video file
  │   ├── v_ApplyEyeMakeup_g01_c02.avi      // video file
  │    ...
  ├── ApplyLipstick                         // contains 114 image files
  │   ├── v_ApplyLipstick_g01_c01.avi       // video file
  │   ├── v_ApplyLipstick_g01_c02.avi       // video file
  │    ...
  ├── ucfTrainTestlist                      // contains category files
  │   ├── classInd.txt                      // Category file.
  │   ├── testlist01.txt                    // split file
  │   ├── trainlist01.txt                   // split file
  ...

重要API接口介绍

API

一级目录

二级目录

三级目录

功能描述

类别(类/函数)

接口参数

参数说明

返回值

ARN

model

arn.py

/

ARN模型结构

support_num_per_class

支持集中每一类包含的样本数

模型

ARN

model

arn.py

/

ARN模型结构

query_num_per_class

查询集中每一类包含的样本数

模型

ARN

model

arn.py

/

ARN模型结构

class_num

类别数

模型

ARN

model

arn.py

/

ARN模型结构

is_c3d

特征提取是否复用c3d网络结构

模型

ARN

model

arn.py

/

ARN模型结构

in_channels

特征提取部分输入通道数

模型

ARN

model

arn.py

/

ARN模型结构

out_channels

特征提取部分输出通道数

模型

ARN

model

arn.py

/

ARN模型结构

jigsaw

时空域上jigsaw增强任务的判别器输出维度

模型

ARN

model

arn.py

/

ARN模型结构

sigma

sigma参数控制power normalization的坡度

模型

SpatialAttention

model

arn.py

/

构建空间注意力模块

in_channels

空间注意力模块的输入通道数

空间注意力向量

SpatialAttention

model

arn.py

/

构建空间注意力模块

out_channels

空间注意力模块的输出通道数

空间注意力向量

SimilarityNetwork

model

arn.py

/

构建ARN的相似度学习模块

in_channels

输入特征的通道数

表征相似度的特征向量

SimilarityNetwork

model

arn.py

/

构建ARN的相似度学习模块

out_channels

输出特征的通道数

表征相似度的特征向量

SimilarityNetwork

model

arn.py

/

构建ARN的相似度学习模块

input_size

输入特征图的尺寸

表征相似度的特征向量

SimilarityNetwork

model

arn.py

/

构建ARN的相似度学习模块

hidden_size

全连接层之间隐藏层的通道数

表征相似度的特征向量

ARN网络模型可执行训练案例

from mindspore import context, load_checkpoint, load_param_into_net
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.communication.management import init, get_rank, get_group_size

from src.utils.check_param import Validator, Rel
from src.utils.config import parse_args, Config
from src.utils.task_acc import TaskAccuracy
from src.loss.builder import build_loss
from src.schedule.builder import get_lr
from src.optim.builder import build_optimizer
from src.data.builder import build_dataset, build_transforms
from src.models import build_model


def main(pargs):
    # set config context
    config = Config(pargs.config)
    context.set_context(**config.context)

    # run distribute
    if config.train.run_distribute:
        if config.device_target == "Ascend":
            init()
        else:
            init("nccl")
        context.set_auto_parallel_context(device_num=get_group_size(),
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        ckpt_save_dir = config.train.ckpt_path + "ckpt_" + str(get_rank()) + "/"
    else:
        ckpt_save_dir = config.train.ckpt_path

    # perpare dataset
    transforms = build_transforms(config.data_loader.train.map.operations)
    data_set = build_dataset(config.data_loader.train.dataset)
    data_set.transform = transforms
    dataset_train = data_set.run()
    Validator.check_int(dataset_train.get_dataset_size(), 0, Rel.GT)
    batches_per_epoch = dataset_train.get_dataset_size()

    # set network
    network = build_model(config.model)

    # set loss
    network_loss = build_loss(config.loss)
    # set lr
    lr_cfg = config.learning_rate
    lr_cfg.steps_per_epoch = int(batches_per_epoch / config.data_loader.group_size)
    lr = get_lr(lr_cfg)

    # set optimizer
    config.optimizer.params = network.trainable_params()
    config.optimizer.learning_rate = lr
    network_opt = build_optimizer(config.optimizer)

    if config.train.pre_trained:
        # load pretrain model
        param_dict = load_checkpoint(config.train.pretrained_model)
        load_param_into_net(network, param_dict)

    # set checkpoint for the network
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=config.train.save_checkpoint_steps,
        keep_checkpoint_max=config.train.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix=config.model_name,
                                    directory=ckpt_save_dir,
                                    config=ckpt_config)

    # init the whole Model
    model = Model(network,
                  network_loss,
                  network_opt,
                  metrics={"Accuracy": TaskAccuracy()})

    # begin to train
    print('[Start training `{}`]'.format(config.model_name))
    print("=" * 80)
    model.train(config.train.epochs,
                dataset_train,
                callbacks=[ckpt_callback, LossMonitor()],
                dataset_sink_mode=config.dataset_sink_mode)
    print('[End of training `{}`]'.format(config.model_name))


if __name__ == '__main__':
    args = parse_args()
    main(args)

ARN网络模型可执行测试案例

from mindspore import context, nn, load_checkpoint, load_param_into_net
from mindspore.train import Model

from src.utils.check_param import Validator, Rel
from src.utils.config import parse_args, Config
from src.utils.task_acc import TaskAccuracy
from src.utils.callbacks import EvalLossMonitor
from src.loss.builder import build_loss
from src.data.builder import build_dataset, build_transforms
from src.models import build_model


def infer(pargs):
    # set config context
    config = Config(pargs.config)
    context.set_context(**config.context)

    # perpare dataset
    transforms = build_transforms(config.data_loader.eval.map.operations)
    data_set = build_dataset(config.data_loader.eval.dataset)
    data_set.transform = transforms
    dataset_eval = data_set.run()
    Validator.check_int(dataset_eval.get_dataset_size(), 0, Rel.GT)

    # set network
    network = build_model(config.model)

    # set loss
    network_loss = build_loss(config.loss)

    # load pretrain model
    param_dict = load_checkpoint(config.infer.pretrained_model)
    load_param_into_net(network, param_dict)

    # Define eval_metrics.
    eval_metrics = {"Accuracy": TaskAccuracy()}
    
    # init the whole Model
    model = Model(network,
                  network_loss,
                  metrics=eval_metrics)

    # Begin to eval.
    result = model.eval(dataset_eval, callbacks=[EvalLossMonitor(model)])

    return result


if __name__ == '__main__':
    args = parse_args()
    result = infer(args)
    print(result)

参考资料

论文:

Few-shot Action Recognition with Permutation-invariant Attention (arxiv.org)

代码仓:

视频套件代码总仓(Gitee)

ARN网络模型代码子仓(Github)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值