【Pytorch 学习笔记】def init_weights() 初始化参数

在 CNN 中,经常可以看见 init_weights() 函数,它是用来初始化网络参数的。

以下面代码为例:

class LeNet(nn.Module):

    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 120, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)

        self.apply(self.init_weights) # 调用初始化函数
        
    def init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.constant_(m.weight, 0.1)
        elif isinstance(m, nn.Linear):
            nn.init.constant_(m.weight, 0.2)
            nn.init.constant_(m.bias, 1)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = LeNet()

在生成网络 net 时,会指定 net 最初的权重,对于一些预训练好的模型权重,就可以放在这个部分进行加载。

我们打印 net 中的各层,如下:

print(net.modules)

---------------------------------------------------------------------

<bound method Module.modules of LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(16, 120, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=120, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=10, bias=True)
)>

init_weights(self, m) 中的 m 就是指 net 中的某一层。

isinstance() 函数来判断一个对象是否是一个已知的类型,isinstance(m, nn.Conv2d) 就是

isinstance(object, classinfo)

参数:

        object – 实例对象;

        classinfo – 可以是直接或间接类名、基本类型或者由它们组成的元组。

返回值:

        如果 object 与 classinfo 的类型相同则返回 True,否则返回 False

nn.init.constant_() 是 torch.nn 中的用于填充数值的函数,这里用于指定初始化值,还有许多其他函数可用于此。

并不是所有的层都能初始化权重的,比如 nn.MaxPool2d(),它是无法初始化的。

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值