NNI speedup_model()代码笔记

前情提要:NNI是微软开发的调参工具,功能有很多,这里介绍其中的一个分支-模型压缩。


模型压缩流程:
1.模型prune
2.模型speedup
模型prune不多介绍,模型speedup就是根据掩码修改模型的结构,比如说通道剪枝,第N层的输出
通道数由10降到了5,那么第N+1层的输入通道数是不是要变成5呀。要保证剪完枝,网络各层还能衔接起来。代码基本就这三行,不过一般运行会出很多问题,除非speedup的模型非常简单且常规

  apply_compression_results(net, masks_file, device)
  m_speedup = ModelSpeedup(net, dummy_input, masks_file, device)
  m_speedup.speedup_model()

运行后报错:   

out_channel = out_shape[1]
IndexError: list index out of range

这样的报错,网上很难搜到解决方案,只能自己研究源代码了,m_speedup.speedup_model()的源代码如下:

    def speedup_model(self):
        """
        There are basically two steps: first, do mask/shape inference,
        second, replace modules.
        主要有两个步骤:首先,进行mask/形状推断,
        第二,替换模块。
        """
        _logger.info("start to speed up the model")
        self.initialize_speedup()
        training = self.bound_model.training
        # 设置到测试模式
        self.bound_model.train(False)
        # 假设在稀疏传播后修复冲突
        # which is more elegent 哪一个更优雅?
        fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)

        _logger.info("infer module masks...")
        self.infer_modules_masks()
        _logger.info('resolve the mask conflict')

        # 在更换模型之前加载原始权重(dict形式)
        self.bound_model.load_state_dict(self.ori_state_dict)
        _logger.info("replace compressed modules...")
        # mask冲突应该已经解决了
        self.replace_compressed_modules()
        self.bound_model.train(training)
        _logger.info("speedup done")

报错位置为:

fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)

参数解释:
self.masks: 就是模型的掩码,值由0和1构成,具体如下图所示

self.bound_model:网络模型
self.dummy_input
输入

继续往下研究代码,fix_mask_conflict的源代码如下

def fix_mask_conflict(masks, model, dummy_input, traced=None):
    """
    MaskConflict修复通道依赖项和组依赖项的掩码mask冲突。
    Parameters
    ----------
    masks : dict/str
        A dict object that stores the masks or the path of the mask file
    model : torch.nn.Module
        model to fix the mask conflict
    dummy_input : torch.Tensor/list of tensors/dict of tensors
        input example to trace the model
    traced : torch._C.torch.jit.TopLevelTracedModule
        the traced model of the target model, is this parameter is not None,
        目标模型的跟踪模型,该参数不是None,
        不使用模型和dummpy_input来获得跟踪图。
    """
    if isinstance(masks, str):
        # 如果mask是路径 则加载mask
        assert os.path.exists(masks)
        masks = torch.load(masks)
    assert len(masks) > 0, 'Mask tensor cannot be empty'
    #如果用户使用模型和伪_输入来跟踪模型,我们应该手动获取跟踪模型,这样,我们只跟踪一次模型, 
    #GroupMaskConflict和ChannelMaskConflict将重用此跟踪模型。
    if traced is None:
        assert model is not None and dummy_input is not None
        training = model.training
        # 需要跟踪eval mode
        model.eval()
        kw_args = {}
        if torch.__version__ >= '1.6.0':
            # 只有版本大于1.6.0的Pytork才有严格的选项strict 选项
            kw_args['strict'] = False
        traced = torch.jit.trace(model, dummy_input, **kw_args)
        model.train(training)
    #以下几行为修复组合通道mask冲突
    fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced)
    masks = fix_group_mask.fix_mask()
    fix_channel_mask = ChannelMaskConflict(masks, model, dummy_input, traced)
    masks = fix_channel_mask.fix_mask()
    return masks

报错位置为:

masks = fix_channel_mask.fix_mask()

1.strict(bool,可选):是否严格强制:attr:`state_dict`中的键与此模块返回的键匹配:meth:`torch.nn.Module.state_dict。默认值:`True``
2.torch.jit.trace 可以将现有模型或Python函数转换为TorchScript:class:`ScriptFunction`或:class:`ScriptModule`。您必须提供示例输入,然后我们运行函数,记录对所有张量执行的操作。

继续剥洋葱,fix_mask()源代码为:

 def fix_mask(self):
        """
        在对具有形状依赖关系的层进行mask推断之前,修复mask冲突。
        应在“加速”模块的mask推断之前调用此函数。
        仅支持结构化修剪mask。
        """
        if self.conv_prune_dim == 0:
            channel_depen = ChannelDependency(
                self.model, self.dummy_input, self.traced, self.channel_prune_type)

        else:
            channel_depen = InputChannelDependency(
                self.model, self.dummy_input, self.traced)
        ······
###########后面还有代码,仅列出一部分

报错位置为:

 channel_depen = ChannelDependency(
                self.model, self.dummy_input, self.traced, self.channel_prune_type)

继续深入:ChannelDependency是个类,继承自Dependency

class ChannelDependency(Dependency):
    def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
        """
        该模型分析模型中conv层之间的通道依赖关系。
        Parameters
        ----------
        model : torch.nn.Module
            要分析的模型
        data : torch.Tensor
            示例输入数据以跟踪网络架构。
        traced_model : torch._C.Graph
           如果我们已经有了目标模型的跟踪图,我们就不需要再跟踪模型了。
        prune_type: str
            此参数表示通道修剪类型:
            1)`Filter`修剪卷积层的过滤器以修剪相应的通道
            2)Batchnorm`:修剪Batchnorm层中的通道
        """
        self.prune_type = prune_type
        self.target_types = []
        if self.prune_type == 'Filter':
            self.target_types.extend(['Conv2d', 'Linear', 'ConvTranspose2d'])
        elif self.prune_type == 'Batchnorm':
            self.target_types.append('BatchNorm2d')

        super(ChannelDependency, self).__init__(
            model, dummy_input, traced_model)

报错位置为:

 super(ChannelDependency, self).__init__(
            model, dummy_input, traced_model)

继续查看父类Dependency的代码:

class Dependency:
    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
        为模型建立图
        """
        from nni.common.graph_utils import TorchModuleGraph

        # 检查输入是否合法
        if traced_model is None:
            # 用户应提供模型和虚拟输入以进行跟踪
            # 模型或已跟踪的模型
            assert model is not None and dummy_input is not None
        self.graph = TorchModuleGraph(model, dummy_input, traced_model)
        self.model = model
        self.dependency = dict()
        self.build_dependency()

    def build_dependency(self):
        raise NotImplementedError

    def export(self, filepath):
        raise NotImplementedError

报错位置为:

self.build_dependency()

看来是build_dependency()出错了,继续往下看build_dependency()代码:

    def build_dependency(self):
        """
        为模型中的conv层构建通道依赖关系。
        """
        # 在分析数据之前,手动解压缩元组/列表
        # 通道依赖性
        self.graph.unpack_manually()
        for node in self.graph.nodes_py.nodes_op:
            parent_layers = []
            # 找到包含 aten::add的节点
            # 或者 aten::cat 操作
            if node.op_type in ADD_TYPES:
                parent_layers = self._get_parent_layers(node)
            elif node.op_type == CAT_TYPE:
                #确定此cat操作是否会引入通道
                #依赖关系,我们需要cat的特定输入参数
                #操作。为了获得cat操作的输入参数,我们
                #需要遍历此NodePyGroup包含的所有cpp_节点,
                #因为,TorchModuleGraph合并了重要节点和相邻节点
                #不重要的节点(例如,以prim::attr开头的节点)进入
                #NodepyGroup。
                cat_dim = None
                for cnode in node.node_cpps:
                    if cnode.kind() == CAT_TYPE:
                        cat_dim = list(cnode.inputs())[1].toIValue()
                        break
                if cat_dim != 1:
                    parent_layers = self._get_parent_layers(node)
            dependency_set = set(parent_layers)
            #合并 dependencies
            for parent in parent_layers:
                if parent in self.dependency:
                    dependency_set.update(self.dependency[parent])
            # 保存 dependencies
            for _node in dependency_set:
                self.dependency[_node] = dependency_set

报错位置为:

parent_layers = self._get_parent_layers(node)

build_dependency()里涉及到了图,即self.graph。graph里的成员还是挺多的:

成员介绍:
input_to_node : dict
key: input name, value: a node that uses this input
output_to_node : dict 
key: output name, value: a node that generates this output

继续看 代码self._get_parent_layers

 def _get_parent_layers(self, node):
        """
        为目标节点查找最近的父conv层。
        Parameters
        ---------
        node : torch._C.Node
            target node.
        Returns
        -------
        parent_layers: list
            nearest father conv/linear layers for the target worknode.
        """

        parent_layers = []
        queue = []
        queue.append(node)
        while queue:
            curnode = queue.pop(0)
            if curnode.op_type in self.target_types:
                # 找到第一个相遇的conv
                parent_layers.append(curnode.name)
                continue
            elif curnode.op_type in RESHAPE_OPS:
                if reshape_break_channel_dependency(curnode):
                    continue
            parents = self.graph.find_predecessors(curnode.unique_name)
            parents = [self.graph.name_to_node[name] for name in parents]
            for parent in parents:
                queue.append(parent)

        return parent_layers

报错位置:

 if reshape_break_channel_dependency(curnode):

继续看reshape_break_channel_dependency函数代码

def reshape_break_channel_dependency(op_node):
    """
    重塑操作(reshape, view, flatten)可能会打破通道依赖性。我们需要检查这些重塑操作的输入参数, 
    以检查这个重塑节点是否会打破通道依赖性。然而,分析每个重塑函数的输入参数并推断它是否会打破通 
    道依赖性是很复杂的。所以目前,我们只是检查输入通道和输出通道是否相同,如果是,那么我们可以说 
    原始的重塑函数不想改变通道的数量,这意味着通道依赖性没有被打破。相比之下,原始的重塑操作想要更 
    改通道的数量,因此它打破了通道依赖性。
    Parameters
    ----------
    opnode: NodePyOP
        A Op node of the graph.
    Returns
    -------
    bool
        是否这个操作会打破通道依赖
    """
    in_shape = op_node.auxiliary['in_shape']
    out_shape = op_node.auxiliary['out_shape']
    in_channel = in_shape[1]
    out_channel = out_shape[1]
    return in_channel != out_channel

报错位置为

out_channel = out_shape[1]

根据代码所示,in_shape和out_shape至少有两个数据,打印in_shape
out:[[9], [9]]
打印out_shape:
out:[18]

函数解析:
判断重塑操作是否会打破通道依赖性,我这里弹出的错误是由cat操作引起的。
一般tensor的形状为[b,c,w,h],代码要判断输入输出通道是否相同,所以代码里比较的是第二维度。我这里的cat的对象不是标准tensor形状,而是一个一维数据,所以肯定没有索引为1的数据。
不过还有个问题,即使是标准的tensor[b,c,w,h]执行cat操作:

代码 out_shape[1] 依然会报错,因为out_shape只有一个成员,奇怪了?
解决方式:
一维数据不采用torch.cat拼接!并将reshape_break_channel_dependency中的

 in_channel = in_shape[1]
 out_channel = out_shape[1]

改为:

    in_channel = in_shape[0][1]
    out_channel = out_shape[1]

ok,这个问题暂时解决,但是  in_channel = in_shape[0][1]又报
'int' object is not subscriptable
这个节点里的内容一会一个样,迎合了A,B又不行了······
这次是view操作导致:

 view前,形状[18],view之后[1,18,1,1]
解决方式:
不同的重塑操作,采用不同的提取方式!

总结:
speedup出现的问题基本上都是由网络模型中采取的某种操作导致的,因为speedup编写的通用代码不可能适用于所有情况。总之,调试起来还是有点麻烦的。

后记:
后来在
 self.infer_modules_masks()时,又出现了100个错误,实在是解决不了了,换代码!

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值