Pytorch参数初始化--默认初始化

导语 :使用pytorch搭好网络之后,没有显式地初始化网络参数,是否可以直接训练网络呢?

_ConvNd

class _ConvNd(Module):

    __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size']

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups, bias, padding_mode):
        super(_ConvNd, self).__init__()
        ... ... # 此处略去几十行代码
        self.reset_parameters()  # 核心代码

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

Remark

  1. class Convkd(_ConvNd), where k = {1, 2, 3}.
  2. class ConvTransposekd(_ConvTransposeMixin, _ConvNd), where k = {1, 2, 3}.

_BatchNorm

  • class _NormBase(Module):
        """Common base of _InstanceNorm and _BatchNorm"""
        _version = 2
        __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
                         'running_mean', 'running_var', 'num_batches_tracked',
                         'num_features', 'affine']
    
        def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                     track_running_stats=True):
            super(_NormBase, self).__init__()
            ... ... # 此处略去几十行代码
            self.reset_parameters()  # 核心代码
    
        def reset_running_stats(self):
            if self.track_running_stats:
                self.running_mean.zero_()
                self.running_var.fill_(1)
                self.num_batches_tracked.zero_()
    
        def reset_parameters(self):
            self.reset_running_stats()
            if self.affine:
                init.ones_(self.weight)
                init.zeros_(self.bias)
class _BatchNorm(_NormBase):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

Remarkclass BatchNormkd(_BatchNorm), where k = {1, 2, 3}.

Linear

class Linear(Module):
    
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        ... ... # 此处略去几十行代码
        self.reset_parameters() # 核心代码

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

结语 :使用Pytorch搭建网络后,如果不显式地指定网络参数的初始化方式,模型则采用默认的初始化方式

 

PyTorch 自定义初始化可以使得我们初始化参数时更加灵活和个性化。在使用 PyTorch 进行深度学习任务时,初始值设置是非常重要的。参数初始值一定程度上影响了算法的收敛速度和精度。因此,自定义初始化是非常有必要的。 PyTorch的torch.nn.init模块提供了一些常用的初始化方式,包括常见的随机初始化(uniform,normal等),常数初始化(zeros,ones等),以及一些比较有名的网络模型特定的初始化方式,如Xavier初始化,Kaiming初始化等。但有时候我们需要自定义的初始化方法,此时就需要自定义初始化。 我们可以使用register_parameter方法为模型中的每一个参数自定义初始化方法,如下所示: ``` class CustomModel(nn.Module): def __init__(self): super(CustomModel, self).__init__() self.weight = nn.Parameter(torch.Tensor(1, 100)) self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): stdv = 1. / math.sqrt(self.weight.size(1)) self.weight.data.uniform_(-stdv, stdv) model = CustomModel() ``` 在以上的代码中,我们可以看到,在模型内部通过register_parameter方法给bias参数设置值为None,表明bias参数不需要在初始化时使用模型默认初始化方式。然后在通过重载reset_parameters方法,我们自己进行参数初始化。 通过这种自定义初始化方式,我们可以方便地对网络模型中的参数进行初始化,从而达到优化模型的目的,提高算法的效果。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ReLuJie

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值