【torch杂记】torch.nn.init.kaiming_normal_

torch.nn.init.kaiming_normal_

参考
源码
  • 这个函数就是实现这个公式

    • std = gain fan_mode \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} std=fan_mode gain
  • def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
        r"""kaiming正态分布
        """
        fan = _calculate_correct_fan(tensor, mode)
        gain = calculate_gain(nonlinearity, a)
        std = gain / math.sqrt(fan)
        with torch.no_grad():
          	# 这句是返回指定区间内随机生成的正太分布的值的 
            return tensor.normal_(0, std)
    
  • _calculate_correct_fan(tensor, mode)是算出input和output feature map的元素总数,源码为:

    • def _calculate_correct_fan(tensor, mode):
          mode = mode.lower()
          valid_modes = ['fan_in', 'fan_out']
          if mode not in valid_modes:
              raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
      		# 这里是fmap的大小
          fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
          # 根据mode选择返回数据
          return fan_in if mode == 'fan_in' else fan_out
      
    • def _calculate_fan_in_and_fan_out(tensor):
          dimensions = tensor.dim()
          if dimensions < 2:
              raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
      		# 这里相当于输出了前两维的size
          num_input_fmaps = tensor.size(1)
          num_output_fmaps = tensor.size(0)
          
          # 这里相当于计算了后两维的元素总和
          receptive_field_size = 1
          if tensor.dim() > 2:
            	# numel()的作用就是计算元素的个数
              receptive_field_size = tensor[0][0].numel()
              
          # 然后算出in/out的fmap的大小
          fan_in = num_input_fmaps * receptive_field_size
          fan_out = num_output_fmaps * receptive_field_size
      
          return fan_in, fan_out
      
    • 上面源码可以用下列例子解释:

      • 比如有tensor.size()=[3,48,11,11],前两者分布是output_channel和input_channel
      • fan_in =48*11*11=5808
      • fan_out=3*11*11=363
      • 然后根据mode匹配决定return哪个
  • 感谢评论区大佬指出错误,num_input_fmaps是用的size(1),num_output_fmaps用的size(0)

  • calculate_gain(nonlinearity, a)如果选的是relu,那么return math.sqrt(2.0),即根号2,下面是源码,其中注释给出了详细的gain值

    • def calculate_gain(nonlinearity, param=None):
          r"""Return the recommended gain value for the given nonlinearity function.
          The values are as follows:
      
          ================= ====================================================
          nonlinearity      gain
          ================= ====================================================
          Linear / Identity :math:`1`
          Conv{1,2,3}D      :math:`1`
          Sigmoid           :math:`1`
          Tanh              :math:`\frac{5}{3}`
          ReLU              :math:`\sqrt{2}`
          Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
          SELU              :math:`\frac{3}{4}`
          ================= ====================================================
      
          Args:
              nonlinearity: the non-linear function (`nn.functional` name)
              param: optional parameter for the non-linear function
      
          Examples:
              >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
          """
          linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
          if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
              return 1
          elif nonlinearity == 'tanh':
              return 5.0 / 3
          elif nonlinearity == 'relu':
              return math.sqrt(2.0)
          elif nonlinearity == 'leaky_relu':
              if param is None:
                  negative_slope = 0.01
              elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
                  # True/False are instances of int, hence check above
                  negative_slope = param
              else:
                  raise ValueError("negative_slope {} not a valid number".format(param))
              return math.sqrt(2.0 / (1 + negative_slope ** 2))
          elif nonlinearity == 'selu':
              return 3.0 / 4  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
          else:
              raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
      
  • tensor.normal_(0, std)

    • 大意是返回一个张量,张量里面的随机数是从相互独立的正态分布中随机生成的。
    • 0为均值,std为标准差
  • 10
    点赞
  • 41
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

椰子奶糖

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值