用MindSpore复现VAN(Visual Attention Network)


这是昇腾AI创新大赛2022-昇思赛道参赛踩坑记录的第三篇。(代码在比赛结束后会开源)

VAN 是我参赛第一个成功提交的作品,从6月初参赛开始接触 MindSpore 到提交自己的第一份作品,很开心。

VAN 文章的链接以及原作者用 PyTorch 实现的代码仓链接如下:
paper: Visual Attention Network

code: Visual Attention Network

1、VAN 简述

VAN,即视觉注意力网络,利用自注意和大核卷积的优点,提出分解一个大核卷积运算来获取远程关系。

目的是解决图像的二维特性给计算机视觉中的自我注意应用带来的三个挑战:(1) 将图像视为一维序列,忽略了其二维结构;(2)二次复杂度对于高分辨率图像来说过于昂贵;(3)仅捕获空间适应性,忽略了通道适应性。

基于上述问题,作者提出了一种新的大核注意(LKA)模型,在避免上述问题的同时,实现自我注意中的自适应和远程相关性。并进一步介绍了一种基于 LKA 的新型神经网络,即视觉注意网络(VAN)。

2、如何复现?

之前昇腾开了一个赛题解读和复现指导的直播,邀请好名字大佬做了分享,这是大佬写的文章:
昇腾AI创新大赛2022-昇思赛道——参赛指南

文章里涵盖了赛事解读、代码迁移、权重迁移、混合精度训练、数据预处理、openi平台的使用以及大佬的一些避坑指南,强烈推荐去看看。要是能早点发该多好,我就能少踩一些坑(狗头)。

言归正传,复现最快的方法就是站在巨人的肩膀上。

MindSpore 的 ModelZoo 库里面已经有许多现成的框架,可以在里面找一个与赛题相近的模型,基于此进行修改。

下面是代码的大概目录结构(我的代码是基于 ModelZoo 里的 Swin-transformer 做的修改):

├── Visual_Attention_Network
  ├── README_CN.md                        // Visual Attention Network相关说明
  ├── src
      ├──configs                          // Visual Attention Network的配置文件
      ├──data                             // 数据集配置文件
          ├──imagenet.py                  // imagenet配置文件
          ├──augment                      // 数据增强函数文件
          ┕──data_utils                   // modelarts运行时数据集复制函数文件
  │   ├──models                           // 模型定义文件夹
          ┕──van                          // Visual Attention Network定义文件
  │   ├──trainers                         // 自定义TrainOneStep文件
  │   ├──tools                            // 工具文件夹
          ├──callback.py                  // 自定义回调函数,训练结束测试
          ├──cell.py                      // 一些关于cell的通用工具函数
          ├──criterion.py                 // 关于损失函数的工具函数
          ├──get_misc.py                  // 一些其他的工具函数
          ├──optimizer.py                 // 关于优化器和参数的函数
          ┕──schedulers.py                // 学习率衰减的工具函数
  ├── train.py                            // 训练文件
  ├── eval.py                             // 评估文件
  ├── export.py                           // 导出模型文件
  ├── postprocess.py                      // 推理计算精度文件
  ├── preprocess.py                       // 推理预处理图片文件

大致分为以下几步:

① 模型的修改

首先是修改模型,这一部分主要涉及 PyTorch 和 MindSpore 算子的 API 映射,可以参考我之前写的一篇文章:

MindSpore和PyTorch API映射(昇腾AI创新大赛2022-昇思赛道参赛踩坑记录)

或者 MindSpore 官方的 PyTorch 与 MindSpore API 映射

迁移第一步,可以使用 MindConverter 工具,将 PyTorch 代码一键转化为 MindSpore 代码,这是初步的迁移,能将一些比较简单的 API 映射关系进行转换,后续再进行校对调整就可以了。(当然纯手动更改也行,费点时间,这工具我也没用过)

VAN 模型迁移过程需要注意的算子主要有 nn.Conv2d,nn.Dropout,nn.BatchNorm2d 以及 PyTorch 使用的一些 Timm 库的代码。

  • nn.Conv2d 移植
# PyTorch
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)

# MindSpore
self.fc1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, pad_mode='valid', has_bias=True)
self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, pad_mode="pad", padding=9, dilation=3, group=dim, has_bias=True)

"""
PyTorch默认bias=True,而MindSpore默认has_bias=False
PyTorch默认pad_mode="zero",而MindSpore默认pad_mode="same"
所以当padding=0时,MindSpore的pad_mode要指定为"valid",当padding≠0时,Mindspore的pad_mode要指定为"pad"
"""
  • nn.Dropout 移植
# PyTorch
self.drop = nn.Dropout(drop)

# MindSpore
self.drop = nn.Dropout(keep_prob=1.0-drop)

"""
MindSpore:keep_prob 是指输入被保留的概率,1-keep_prob 是指输入被置 0 的概率
PyTorch:p 是指输入被置 0 的概率,与 MindSpore 相反
keep_prob = 1 - p
"""
  • nn.BatchNorm2d 比较

这个需要注意MindSpore 的 nn.BatchNorm2d 只对每个设备内的数据进行规范化,采用多卡训练要考虑 Batch Normalization 的范围。

mindspore.nn.SyncBatchNorm 是跨设备同步的 Batch Normalization。

# 判断多卡还是单卡,多卡采用nn.SyncBatchNorm,单卡采用nn.BatchNorm2d
if os.getenv("DEVICE_TARGET") == "Ascend" and int(os.getenv("DEVICE_NUM")) > 1:
    BatchNorm2d = nn.SyncBatchNorm
else:
    BatchNorm2d = nn.BatchNorm2d
  • PyTorch 的 nn.ModuleList 可以用 MindSpore 的 nn.CellList 代替

  • nn.Identity 和 Timm 库的 DropPath 移植

class Identity(nn.Cell):
    """Identity"""
    def construct(self, x):
        return x


class DropPath(nn.Cell):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob, ndim):
        super(DropPath, self).__init__()
        self.drop = nn.Dropout(keep_prob=1 - drop_prob)
        shape = (1,) + (1,) * (ndim + 1)
        self.ndim = ndim
        self.mask = Tensor(np.ones(shape), dtype=mstype.float32)

    def construct(self, x):
        if not self.training:
            return x
        mask = ops.Tile()(self.mask, (x.shape[0],) + (1,) * (self.ndim + 1))
        out = self.drop(mask)
        out = out * x
        return out


class DropPath1D(DropPath):
    def __init__(self, drop_prob):
        super(DropPath1D, self).__init__(drop_prob=drop_prob, ndim=1)


class DropPath2D(DropPath):
    def __init__(self, drop_prob):
        super(DropPath2D, self).__init__(drop_prob=drop_prob, ndim=2)
  • 模型检验

在成功迁移完成后,可以对比下 PyTorch 的模型和 MindSpore 的模型的参数量,做个粗略的验证。

# PyTorch
if __name__ == "__main__":
    import torch

    data = torch.tensor(np.ones([2, 3, 224, 224]), dtype=torch.float32)
    model = VAN(
        img_size=224, in_chans=3, num_classes=1000,
        embed_dims=[64, 128, 320, 512], mlp_ratios=[8, 8, 4, 4],
        depths=[3, 3, 12, 3])

    # 验证前向传播
    out = model(data)
    print(out.shape)

    # 验证参数量
    params = 0.
    for name, param in model.named_parameters():
        # 获取参数, 获取名字
        params += np.prod(param.shape)
        print(name)
    print("参数总量", params)

# MindSpore
if __name__ == "__main__":
    from mindspore import context
    
    context.set_context(mode=context.PYNATIVE_MODE)

    data = Tensor(np.ones([2, 3, 224, 224]), dtype=mindspore.float32)
    model = VAN(
        img_size=224, in_chans=3, num_classes=1000,
        embed_dims=[64, 128, 320, 512], mlp_ratios=[8, 8, 4, 4],
        depths=[3, 3, 12, 3])

    # 验证前向传播
    out = model(data)
    print(out.shape)

    # 验证参数量
    params = 0.
    for name, param in model.parameters_and_names():
    	# 获取参数, 获取名字
        params += np.prod(param.shape)
        print(name)
    print("参数总量", params)

② 参数初始化

这里先挖个坑,有空单独开篇文章讲 PyTorch 和 MindSpore 的参数初始化对比,这里就先直接给个模板,可以直接用。

# PyTorch
class VAN(nn.Module):
    def __init__(self, img_size=224, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                mlp_ratios=[4, 4, 4, 4], drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], num_stages=4, flag=False):
        super().__init__()
        
        self.apply(self._init_weights)  # 这里调用的时module的apply方法

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

# MindSpore
class VAN(nn.Cell):
    def __init__(self, img_size=224, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                mlp_ratios=[4, 4, 4, 4], drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], num_stages=4, flag=False):
        super().__init__()

        self._init_weights()  # 直接在构造函数__init__里面调用自己实现_init_weights方法进行参数初始化

    def _init_weights(self):
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02),
                                                             cell.weight.shape,
                                                             cell.weight.dtype))
                if isinstance(cell, nn.Dense) and cell.bias is not None:
                    cell.bias.set_data(weight_init.initializer(weight_init.Zero(),
                                                               cell.bias.shape,
                                                               cell.bias.dtype))
            elif isinstance(cell, (nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
                cell.gamma.set_data(weight_init.initializer(weight_init.One(),
                                                            cell.gamma.shape,
                                                            cell.gamma.dtype))
                cell.beta.set_data(weight_init.initializer(weight_init.Zero(),
                                                           cell.beta.shape,
                                                           cell.beta.dtype))
                # torch 这几个算子的 weight,bias 对应于 mindspore 的 gamma 和 beta
            elif isinstance(cell, nn.Conv2d):
                fan_out = cell.kernel_size[0] * cell.kernel_size[1] * cell.out_channels
                fan_out //= cell.group
                cell.weight.set_data(weight_init.initializer(weight_init.Normal(sigma=math.sqrt(2.0 / fan_out), mean=0),
                                                             cell.weight.shape,
                                                             cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(weight_init.initializer(weight_init.Zero(),
                                                               cell.bias.shape,
                                                               cell.bias.dtype))

MindSpore 可以直接在类的构造函数 init 里面直接调用参数初始化函数。Timm 库的 trunc_normal_,constant_,normal_ 等方法在 mindspore.common.initializer 都有可代替的。

③ 数据预处理

VAN 用的 imagenet,ModelZoo 里面提供的代码基本都能用,主要改一下对应的数据集路径,对齐下 PyTorch 那边的数据增强方式等就可以。

数据增强:插值方式、MixUp、CutMix、RandomErasing 参考 ModelZoo Swin-transformer 写法。


④ 优化器、学习策略

def get_param_groups(network):
    """ get param groups """
    decay_params = []
    no_decay_params = []
    for x in network.trainable_params():
        parameter_name = x.name
        if parameter_name.endswith(".weight"):
            # Dense or Conv's weight using weight decay
            decay_params.append(x)
        else:
            # all bias not using weight decay
            # bn weight bias not using weight decay, be carefully for now x not include LN
            no_decay_params.append(x)
    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]


def _warmup_lr(warmup_lr, base_lr, warmup_length, epoch):
    """Linear warmup"""
    return epoch / warmup_length * (base_lr - warmup_lr) + warmup_lr


def cosine_lr(args, batch_num):
    """Get cosine lr"""
    learning_rate = []

    def _lr_adjuster(epoch):
        if epoch < args.warmup_length:
            lr = _warmup_lr(args.warmup_lr, args.base_lr, args.warmup_length, epoch)
        else:
            e = epoch - args.warmup_length
            es = args.epochs - args.warmup_length
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * args.base_lr
        return lr

params = get_param_groups(model)
batch_num = data.train_dataset.get_dataset_size()  # 返回一个epoch中的batch数
learning_rate = cosine_lr(args, batch_num)
step = int(args.start_epoch * batch_num)
accumulation_step = int(args.accumulation_step)  # args.accumulation=1
learning_rate = learning_rate[step::accumulation_step]
learning_rate = learning_rate * args.batch_size * int(os.getenv("DEVICE_NUM", args.device_num)) / 512.  # lr跟随batch_size和device_num改变

optim = AdamWeightDecay(
        params=params,
        learning_rate=learning_rate,
        beta1=args.beta[0],
        beta2=args.beta[1],
        eps=args.eps,
        weight_decay=args.weight_decay
)

上述代码提炼了优化器和学习策略设置的大概。

lr,weight_decay 等超参数可以通过 PyTorch 代码以及论文的实验过程超参数的设置获得。

还有一个小技巧是看原作者的代码仓是否提供训练好的模型,一般是 pth 文件,可以通过以下代码查看 pth 模型中的保留的状态信息,看看是否有相关超参数的值。

import torch

checkpoint = torch.load('van_base_828.pth.tar')	 # 加载模型

print(checkpoint.keys())  # 查看字典的键列表
# Out
# dict_keys(['epoch', 'arch', 'state_dict', 'optimizer', 'version', 'args', 'amp_scaler', 'metric'])

print(checkpoint['args'])  # 可以查看每个键保存的信息

⑤ 多卡训练设置

from mindspore import context
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode

def set_device(args):
    """Set device and ParallelMode(if device_num > 1)"""
    rank = 0
    # set context and device
    device_target = args.device_target
    device_num = int(os.environ.get("DEVICE_NUM", 1))  # 获取卡的数量

    if device_target == "Ascend":
        if device_num > 1:
            context.set_context(device_id=int(os.getenv('DEVICE_ID', 0)))
            init(backend_name='hccl')
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            # context.set_auto_parallel_context(pipeline_stages=2, full_batch=True)

            rank = get_rank()  # 获取卡号
        else:
            context.set_context(device_id=args.device_id)
    elif device_target == "GPU":
        if device_num > 1:
            init(backend_name='nccl')
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            rank = get_rank()
        else:
            context.set_context(device_id=args.device_id)
    else:
        raise ValueError("Unsupported platform.")

    return rank

mode = {
    0: context.GRAPH_MODE,
    1: context.PYNATIVE_MODE
}
context.set_context(mode=mode[args.graph_mode], device_target=args.device_target)  # 设置GRAPH_MODE或者PYNATIVE_MODE,以及设置设备类型:Ascend、GPU或者CPU
context.set_context(enable_graph_kernel=False)

rank = set_device(args)  # 分布式训练,多卡并行设置,返回当前卡号
set_seed(args.seed + rank)  # 设置随机种子,每个卡设置不同的随机种子

关于分布式并行,可以参考 MindSpore官方教程——分布式并行

⑥ 混合精度对齐(O0,O2,O3 以及自己设置黑白名单 O1)

MindSpore 混合精度的相关介绍可以参考:混合精度

PyTorch 的混合精度标准参考:PyTorch 默认混合精度设置

根据 PyTorch 的混合精度标准进行对齐就可以。

注意:Ascend910上,nn.Dense 算子和卷积算子不支持 fp32 的运算,其他类别的算子,例如 BatchNorm、LayerNorm、GELU 等,最好保持 fp32,否则可能会导致精度有明显下降。

import mindspore.nn as nn
from mindspore import dtype as mstype

from src.args import args


def do_keep_fp32(network, cell_types):
    """Cast cell to fp32 if cell in cell_types"""
    for _, cell in network.cells_and_names():
        if isinstance(cell, cell_types):
            cell.to_float(mstype.float32)


def cast_amp(net):
    """cast network amp_level"""
    if args.amp_level == "O1":
        print(f"=> using amp_level {args.amp_level}\n"
              f"=> change {args.arch} to fp16")
        net.to_float(mstype.float16)
        cell_types = (nn.GELU, nn.Softmax, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d, nn.LayerNorm)
        # cell_types = (nn.GELU, nn.Softmax, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d, nn.LayerNorm, nn.ReLU, nn.Dense)
        print(f"=> cast {cell_types} to fp32 back")
        do_keep_fp32(net, cell_types)
    elif args.amp_level == "O2":
        print(f"=> using amp_level {args.amp_level}\n"
              f"=> change {args.arch} to fp16")
        net.to_float(mstype.float16)
        cell_types = (nn.BatchNorm2d, nn.LayerNorm)
        print(f"=> cast {cell_types} to fp32 back")
        do_keep_fp32(net, cell_types)
    elif args.amp_level == "O3":
        print(f"=> using amp_level {args.amp_level}\n"
              f"=> change {args.arch} to fp16")
        net.to_float(mstype.float16)
    else:
        print(f"=> using amp_level {args.amp_level}")
        args.loss_scale = 1.
        args.is_dynamic_loss_scale = 0
        print(f"=> When amp_level is O0, using fixed loss_scale with {args.loss_scale}")


# 对模型进行混合精度设置
cast_amp(net)

⑦ 训练过程的信息打印

MindSpore 训练过程中的信息打印主要通过 callback 回调函数,在每个 epoch (或step,这看你自己设置) 结束时执行 Callback 类的 epoch_end 方法。

model = Model(net_with_loss, metrics={"acc", "loss"},
              eval_network=eval_network,
              eval_indexes=eval_indexes)

eval_cb = EvaluateCallBack(model, eval_dataset=data.val_dataset, src_url=train_dir,
                           train_url=os.path.join(args.train_url, "ckpt_" + str(rank)),
                           ave_freq=args.save_every)

model.train(int(args.epochs - args.start_epoch), data.train_dataset,
            callbacks=[time_cb, ckpoint_cb, loss_cb, eval_cb],
            dataset_sink_mode=args.dataset_sink_mode)

可以自定义 EvaluateCallback 类,继承 mindspore.train.callback.Callback 就可以。通过重写 epoch_end 方法,一般在里面调用 model.eval 进行模型验证,打印 loss、acc 等信息,进行模型保存和传输等。
import os
from mindspore import save_checkpoint
from mindspore.communication.management import get_rank
from mindspore.train.callback import Callback

from src.args import args


# 自定义eval callback
class EvaluateCallBack(Callback):
    """EvaluateCallBack"""

    def __init__(self, model, eval_dataset, src_url, train_url, save_freq=50):
        super(EvaluateCallBack, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.src_url = src_url
        self.train_url = train_url
        self.save_freq = save_freq
        self.best_acc = 0.
        self.rank = get_rank()

    def epoch_end(self, run_context):
        """
            Test when epoch end, save best model with best.ckpt.
        """
        cb_params = run_context.original_args()
        cur_epoch_num = cb_params.cur_epoch_num
        result = self.model.eval(self.eval_dataset)
        
        if result["acc"] > self.best_acc:
            self.best_acc = result["acc"]
            self.best_acc_path = os.path.join(self.src_url, "best.ckpt")
            save_checkpoint(cb_params.train_network, self.best_acc_path)  # 保存当前最佳模型

        print("ckpt_%s | epoch: %s acc: %s, best acc is %s" %
              (self.rank, cb_params.cur_epoch_num, result["acc"], self.best_acc), flush=True)

        if args.run_openi:
            try:
                import moxing as mox
                if cur_epoch_num % self.save_freq == 0:
                    mox.file.copy_parallel(src_url=self.src_url, dst_url=self.train_url)  # 与openi平台交互,保存模型到平台
            except Exception as e:
                print('moxing upload {} to {} failed: '.format(self.src_url, self.train_url) + str(e))   

⑧ 模型保存

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# 设置模型保存策略
# save_checkpoint_steps:每个多少个step保存一次
# keep_checkpoint_max:最大保存节点数
config_ck = CheckpointConfig(save_checkpoint_steps=data.train_dataset.get_dataset_size()*10,
                             keep_checkpoint_max=args.save_every)
ckpoint_cb = ModelCheckpoint(prefix=args.arch + str(rank), directory=train_dir,
                             config=config_ck)

# 设置完添加进model.train的callbacks就可以
# 要保存最佳模型可以自定义callback,上一小节的EvaluateCallBack就写了如何保存最佳模型
model.train(int(args.epochs - args.start_epoch), data.train_dataset,
            callbacks=[time_cb, ckpoint_cb, loss_cb, eval_cb],
            dataset_sink_mode=args.dataset_sink_mode)

⑨ openi 平台使用

openi 平台的使用可以参考:OpenI_Learning | 小白训练营课程

需要注意的是启智集群和智算集群的数据集路径和模型保存路径是不同的,加载数据集和保存模型的方式有所区别,可以参考上述 openi 小白训练营。

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值