1 用法
Truncated Normal Distribution : 截断正太分布, 根据设置的截断范围, 只从截断范围内生成对应的正态分布数据
2 参数介绍
def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
tensor : 输入的 Tensor, n-dimensional Tensor
mean : 正太分布的均值
std : 正态分布的方差
a : 截取区域的最小值
b : 截取区域的最大值
3 示例
import torch
a = torch.ones(2, 3)
print(a)
b = torch.nn.init.trunc_normal_(a)
print(b)
c = torch.nn.init.uniform_(a, a = -1, b = 1)
print(c)
>>> tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[ 0.3913, 0.7215, 1.6364],
[ 0.7812, 0.1816, -0.8096]])
tensor([[ 0.4666, -0.9647, 0.8345],
[-0.6525, 0.0482, -0.6055]])
4 补充
此外还有类似的函数:
torch.nn.init.constant_(tensor, val)