mxnet问题整理(二)

在mxnet中需要对conv进行修改,所以遇到了一些问题,选择难理解的问题记下来。

1. 修改完conv层函数之后,出现输出结果是null的问题

按照以下的方式来就好了

class new_conv(nn.Conv2D):
    def __init__(self, channels, kernel_size, **kwargs):
        # if isinstance(kernel_size, base.numeric_types):
        #     kernel_size = (kernel_size,)*2
        # assert len(kernel_size) == 2, "kernel_size must be a number or a list of 2 ints"
        super(new_conv, self).__init__(channels, kernel_size, **kwargs)

    def forward(self, x, *args):
        self.ctx = x.context
        self.set_params()
        return super(new_conv, self).forward(x, *args)

2. 网络模型参数获取

for key, val in self.params.items():
    # key = self.params.keys()
    if 'weight' in key:
        data = val.data()
        mask = netMask[key]

3. 网络参数的赋值(1)

网络参数的普通的赋值可以使用set_data(),但是网络参数的处理一般都会在构建的网络图之中。所以,最好的采取一种不会再autograd.record()出现错误的处理方式。

def assign_params(self, key, multiplier):
    """Sets this parameter's value on all contexts."""
    # self.shape = multiplier.shape
    out = []
    for arr in self.params[key].list_data():
        out.append((nd.multiply(arr, multiplier)).as_in_context(self.ctx))
    if len(out) == 1:
        self.params[key]._data = out #这里需要格外的注意,必须是list形式。我在这儿捯饬了好大一会
    else:
        self.params[key]._data = [nd.stack(*out, axis=1)]
    print('o')

4. 对网络参数处理的说明

对网络参数的处理,可能还会遇到其他各种各样的问题,所以渔具是啥呢?

在/usr/local/lib/python2.7/dist-packages/mxnet/gluon/parameter.py中,可以找到很多对网络参数处理的环节。比如读取都会指向这么一个函数 def _check_and_get(self, arr_list, ctx),上面提到的赋值时需要的list格式,就是无论如何都是以list的形式出现的。这就有data和ctxlist配合的问题。

def _check_and_get(self, arr_list, ctx):
    if arr_list is not None:
        if ctx is list:
            return arr_list
        if ctx is None:
            if len(arr_list) == 1:
                return arr_list[0]
            else:
                ctx = context.current_context()
        ctx_list = self._ctx_map[ctx.device_typeid&1]
        if ctx.device_id < len(ctx_list):
            idx = ctx_list[ctx.device_id]
            if idx is not None:
                return arr_list[idx]##重点看这儿,arr_list是数据,所以外面要加[],变成list形式,不然就被取多维数据第一维度的第一个index。

5. 对网络参数的赋值(2)

如果以在(1)中的方式对网络参数赋值,可以达到不至于不能构建计算图的效果。但是,系统认定是节点的权重的一个子集(也包括它本身)参与构建计算图。所以在计算梯度的时候,需要特别指定ignore_stale_grad——即忽略不在计算图中的梯度。

如果不想这样的话,还是有办法的。

netMask={}
class new_conv(nn.Conv2D):
    def __init__(self, channels, kernel_size, **kwargs):
        # if isinstance(kernel_size, base.numeric_types):
        #     kernel_size = (kernel_size,)*2
        # assert len(kernel_size) == 2, "kernel_size must be a number or a list of 2 ints"
        super(new_conv, self).__init__(channels, kernel_size, **kwargs)

    # def  forward(self, x, *args):
    #     return super(new_conv, self).forward(x, *args)
    
    def hybrid_forward(self, F, x, weight, bias=None): #使用重写hybrid_forward的方法,该方法是conv2D所继承的类中的函数,显示的调用了weight
        keys = self.params.keys()
        self.assign_mask(keys)
        for key in keys:
            if 'weight' in key:
                wmask = netMask[key]
            else:
                bmask = netMask[key]
        return super(new_conv, self).hybrid_forward(F, x, weight * wmask, bias * bmask)

6. python的类的相关

这更是关于python编程相关的,目前用到了,所以将他们写出来。

在使用mxnet的dataloader的时候,可以指定的一个参数是sampler,就是指定一个epoch下的可用样本。在不显式指定该参数的情况下,程序也会自己创建一个sampler。但是需要将给变量指定为一个可迭代的量。所以在编程的时候,有一点点小技巧,可以让这个更灵活。

class SubsetRandomSampler():
    """
    exclude is a instance of SubsetRandomSampler
    """

    def __init__(self, length, subs, exclude=None):
        self._length = length
        self._subs = subs
        self.exclude = exclude #这是一个SubsetRandomSampler的对象,用到它是为了形成互补的训练集和验证集


    def __iter__(self):
        random.seed(random.randint(1, 100))
        indices = range(self._length)
        if self.exclude:
            try:
                indices = set(indices) - set(self.exclude.out)
            except Exception, e:
                pass
            self.exclude = None
        self.out = random.sample(indices, self._subs)
        return iter(self.out) #注意这儿,返回的就是一个迭代量,所以变得更灵活一些。

    def __len__(self):
        return self._subs
在调用的时候,就可以如下使用,并确保二者为互补数据集:
train_sampler = SubsetRandomSampler(num_train, num_ratio)
valid_sampler = SubsetRandomSampler(num_train, num_train - num_ratio, train_sampler)

另外,如果类是一个可直接调用的,比如当使用该对象的时候,函数会自动跳转到forward之类的方式等这种实现的时候,就可以直接当作参数传递,程序会变的更灵活。比如这样:

def __init__(self, num_classes=10574, verbose=False, my_fun=nn.Conv2D, **kwargs):


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值