Pytorch学习(3) —— nn.Parameter nn.ParameterList nn.ParameterDict 源码解析

为了更好理解Pytorch基本类的实现方法,我这里给出了关于参数方面的3个类的源码详解。

此部分可以更好的了解实现逻辑结构,有助于后续代码理解,学pytorch的话这个不是必须掌握的,看不懂也没关系。

1 Parameter 参数类源码

此部分参考《pytorch源码阅读系列之Parameter类》《通俗的讲解Python中的__new__()方法》

因为Parameter继承于torch.Tensor,没有新的变量和添加函数,只是对一些辅助函数进行了定义

Parameter作为Module类的参数,可以自动的添加到Module类的参数列表中,并且可以使用Module.parameters()提供的迭代器获取到,所以这个类是一切网络结构数据的核心。

class Parameter(torch.Tensor):
    # 这个方法比__init__方法更先执行,这里就理解为一种初始化方法
    # 详细参考《通俗的讲解Python中的__new__()方法》
    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.Tensor()
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    # 为了方便实用deepcopy方法,对当前数据进行深拷贝,正常的copy方法只拷贝一层,
    # 简单的来说list的list,最好用深拷贝。
    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
            memo[id(self)] = result
            return result

    # 一种可视化方法,给print使用
    def __repr__(self):
        return 'Parameter containing:\n' + super(Parameter, self).__repr__()

    # 用于替代reduce方法
    def __reduce_ex__(self, proto):
        return (
            torch._utils._rebuild_parameter,
            (self.data, self.requires_grad, OrderedDict())
        )

2 ParameterList 参数列表类源码

这个类实际上是将一个Parameter的List转为ParameterList,如下例所示[nn.Parameter(torch.randn(10, 10)) for i in range(10)]类型是List,List的每个元素是Parameter,然后这个List作为参数传入这个类构造ParameterList类型。

ParameterList输入一定是一个Parameter的List,其他类型会报错,在注册时候就会提示元素不是Parameter类型。

parms = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

下面是对应的源码。

class ParameterList(Module):
    def __init__(self, parameters=None): # parameters 是一个python list 类型
        super(ParameterList, self).__init__()

        # 这里的+=运算是经过重载的,__iadd__定义,可以看出,实际上是调用了extend方法,
        # 将parameters 注册到_parameters中
        if parameters is not None:
            self += parameters 

    # 针对非slice的index,判断是否满足取值的条件,并返回对应角标字符串
    def _get_abs_string_index(self, idx):
        idx = operator.index(idx) # 判断输入角标的位置是否为整数
        
        # 这里重载的__len__,返回_parameters的个数
        if not (-len(self) <= idx < len(self)): # 判断是否在参数范围内
            raise IndexError('index {} is out of range'.format(idx))
        if idx < 0: #对于负号的问题,就选择倒数的元素
            idx += len(self)
        return str(idx) # 返回对应的角标字符串
    
    # 使得这个类可以通过角标访问,比如P[i]这种
    def __getitem__(self, idx):
        if isinstance(idx, slice): # 判断这个角标是否为切片就是P[i:j]这种
            # _parameters是OrderedDict类型,返回值转换为list,
            # __class__表示转换为当前类型,所以,通过切片返回的List仍然是ParameterList类型
            return self.__class__(list(self._parameters.values())[idx])
        else:
            idx = self._get_abs_string_index(idx) # 检验角标正确性
            return self._parameters[str(idx)] # 返回一个数据,数据类型为Parameter

    # 使得这个类可以通过角标访问,比如P[i] = Q这种,这里面不支持切片复制
    def __setitem__(self, idx, param):
        idx = self._get_abs_string_index(idx)
        return self.register_parameter(str(idx), param)

    # 重载len用法,可以使用len(P)统计list个数
    def __len__(self):
        return len(self._parameters)

    # 重载迭代器算法,可以用于 for i in P这种
    def __iter__(self):
        return iter(self._parameters.values())

    # 重载自加算法,比如P += Q,等价于 P.extend(Q)
    def __iadd__(self, parameters):
        return self.extend(parameters)
    
    # 列出这个类所有的属性,重载后可以使用dir(P)
    def __dir__(self):
        keys = super(ParameterList, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

   # 在list末尾添加一个Parameter
   def append(self, parameter):
        self.register_parameter(str(len(self)), parameter)
        return self

   # 在list末尾添加一个Parameter的list,也可以是ParameterList类型
   def extend(self, parameters):
        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError("ParameterList.extend should be called with an "
                            "iterable, but got " + type(parameters).__name__)
        offset = len(self)
        for i, param in enumerate(parameters):
            self.register_parameter(str(offset + i), param)
        return self

    # 可以理解为ParameterList可视化方法,下面给一个调用的例子
    # (0): Parameter containing: [torch.FloatTensor of size 10x10]
    # (1): Parameter containing: [torch.FloatTensor of size 10x10]
    # (2): Parameter containing: [torch.FloatTensor of size 10x10]
    def extra_repr(self):
        child_lines = []
        for k, p in self._parameters.items():
            size_str = 'x'.join(str(size) for size in p.size())
            device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
            parastr = 'Parameter containing: [{} of size {}{}]'.format(
                torch.typename(p.data), size_str, device_str)
            child_lines.append('  (' + str(k) + '): ' + parastr)
        tmpstr = '\n'.join(child_lines)
        return tmpstr

3 ParameterDict 参数字典类源码

ParameterDict 是一个字典类源码,与python的字典非常相似,下面就是字典的一个例子,输入参数是个普通字典,然后转换为ParameterDict类型。

params = nn.ParameterDict({ 'left': nn.Parameter(torch.randn(5, 10)), 'right': nn.Parameter(torch.randn(5, 10))})

下面给出这个类的源码,并对其进行详细分析理解。

class ParameterDict(Module):
    def __init__(self, parameters=None):
        super(ParameterDict, self).__init__()
        if parameters is not None:
            self.update(parameters) # 更新字典

    def __getitem__(self, key): # 同上一节,可以使用键访问值
        return self._parameters[key]

    def __setitem__(self, key, parameter): # 同上一节,可以使用键设置值
        self.register_parameter(key, parameter)

    def __delitem__(self, key): # 删除某个键,可使用del删除
        del self._parameters[key]
    
    def __len__(self): # 返回字典个数
        return len(self._parameters)

    def __iter__(self): # 同上一节,可以得到迭代器,迭代器用键表示
        return iter(self._parameters.keys())

    def __contains__(self, key): # 判断当前key是否在字典中,重载关键字in, key in dict
        return key in self._parameters

    def clear(self): # 清空字典
        self._parameters.clear()
    
    def pop(self, key): # 删除某个键,并返回其值。
        v = self[key]
        del self[key]
        return v

    def keys(self): # 返回所有的键的名称
        return self._parameters.keys()

    def items(self): # 同字典的item用法
        return self._parameters.items()

    def values(self): # 返回所有的值
        r"""Return an iterable of the ParameterDict values.
        """
        return self._parameters.values()

    def update(self, parameters): # 输入新的字典,更新当前的参数字典
        if not isinstance(parameters, container_abcs.Iterable): # 保证输入一定是个字典
            raise TypeError("ParametersDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(parameters).__name__)

        if isinstance(parameters, container_abcs.Mapping): # 判断是不是一个Mapping类型
            if isinstance(parameters, (OrderedDict, ParameterDict)): #判断是不是已知类型
                for key, parameter in parameters.items():
                    self[key] = parameter
            else:
                for key, parameter in sorted(parameters.items()):
                    self[key] = parameter
        else:
            # 感觉这里是为了适应其他的字典类,毕竟有可能用户自己也写个字典类
            for j, p in enumerate(parameters): 
                if not isinstance(p, container_abcs.Iterable):
                    raise TypeError("ParameterDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is" +
                                    type(p).__name__)
                if not len(p) == 2:
                    raise ValueError("ParameterDict update sequence element "
                                     "#" + str(j) + " has length " + str(len(p)) +
                                     "; 2 is required")
                self[p[0]] = p[1]

    # 字典可视化
    # (left): Parameter containing: [torch.FloatTensor of size 5x10]
    # (right): Parameter containing: [torch.FloatTensor of size 5x10]
    def extra_repr(self):
        child_lines = []
        for k, p in self._parameters.items():
            size_str = 'x'.join(str(size) for size in p.size())
            device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
            parastr = 'Parameter containing: [{} of size {}{}]'.format(
                torch.typename(p.data), size_str, device_str)
            child_lines.append('  (' + k + '): ' + parastr)
        tmpstr = '\n'.join(child_lines)
        return tmpstr

总结

关于参数的三个类的分析就到这里了,其实感觉跟正常的python用法也没啥区别,为了方便用户使用pytorch,官方重载了大量的函数,方便用户使用,很大程度上降低了使用难度。后续,我再对模型的几个类比如Sequential,ModuleList,ModuleDict进行分析,Module这个类我估计不会进行分析了,将近1000行,实现了太多太多功能,我觉得太底层了,就不分析了,如果有人感兴趣的话,欢迎一起讨论研究。

  • 16
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值