2.6.2 ConvNeXt_trunc_normal_网络的代码注释

以下代码有几个函数的解释:

(1)高斯误差函数erf: erf(x)=\frac{2}{\sqrt{\pi }}\int_{0}^{x}e^{^{-t^{2}}}dt,该函数是奇函数。

(2)uniform(x,y) 方法将随机生成下一个实数,它在 [x,y] 范围内。

(3)erfinv是erf函数的反函数,是逆误差函数,是把[-1,1]的数值映射到[-\infty ,+\infty ]上。

(4)mul:乘法。

  (5)  add:加法。

  (6)  clamp(min=a, max=b):返回(min=a, max=b)之间的值,若大于b,则返回b,若小于a,则返回a。

代码解释如下:

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    def _no_grad_trunc_normal_(tensor, mean, std, a, b):
        def norm_cdf(x):
            #返回(1+x/根号2)/2
            return (1. + math.erf(x / math.sqrt(2.))) / 2.

        with torch.no_grad():

            l = norm_cdf((a - mean) / std)
            u = norm_cdf((b - mean) / std)

            #将tensor从均匀分布中抽样数值进行填充,填充的数值介于(2 * l - 1, 2 * u - 1)之间
            tensor.uniform_(2 * l - 1, 2 * u - 1)
            #逆误差函数erfinv_
            tensor.erfinv_()

            tensor.mul_(std * math.sqrt(2.))
            tensor.add_(mean)
            
            #tensor值返回(min=a, max=b)之间的值,若大于b,则返回b,若小于a,则返回a
            tensor.clamp_(min=a, max=b)
            return tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值