以下代码有几个函数的解释:
(1)高斯误差函数erf: ,该函数是奇函数。
(2)uniform(x,y) 方法将随机生成下一个实数,它在 [x,y] 范围内。
(3)erfinv是erf函数的反函数,是逆误差函数,是把[-1,1]的数值映射到上。
(4)mul:乘法。
(5) add:加法。
(6) clamp(min=a, max=b):返回(min=a, max=b)之间的值,若大于b,则返回b,若小于a,则返回a。
代码解释如下:
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): def _no_grad_trunc_normal_(tensor, mean, std, a, b): def norm_cdf(x): #返回(1+x/根号2)/2 return (1. + math.erf(x / math.sqrt(2.))) / 2. with torch.no_grad(): l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) #将tensor从均匀分布中抽样数值进行填充,填充的数值介于(2 * l - 1, 2 * u - 1)之间 tensor.uniform_(2 * l - 1, 2 * u - 1) #逆误差函数erfinv_ tensor.erfinv_() tensor.mul_(std * math.sqrt(2.)) tensor.add_(mean) #tensor值返回(min=a, max=b)之间的值,若大于b,则返回b,若小于a,则返回a tensor.clamp_(min=a, max=b) return tensor return _no_grad_trunc_normal_(tensor, mean, std, a, b)