pytorch 初始化


前言

学习用,侵删

记录,如何使用pytorch初始化。


一、初始化整个模型

pytorch 官方resnet代码

for m in self.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

其他借鉴

import torch.nn as nn
def initialize_weights(model):
    # Initializes weights according to the DCGAN paper
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        # if you also want for linear layers ,add one more elif condition 
def init_all(model, init_func, *params, **kwargs):
    for p in model.parameters():
        init_func(p, *params, **kwargs)

model = UNet(3, 10)
init_all(model, torch.nn.init.normal_, mean=0., std=1) 
# or
init_all(model, torch.nn.init.constant_, 1.) 

参考:

  1. pytorch官方代码:
  2. stack overflow

二、利用apply

1. 判断实例对象类型,利用isinstance()

代码如下:

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)

2. 通过类名判断

代码如下:

# takes in a module and applies the specified weight initialization
    def weights_init_uniform(m):
        classname = m.__class__.__name__
        # for every Linear layer in a model..
        if classname.find('Linear') != -1:
            # apply a uniform distribution to the weights and a bias=0
            m.weight.data.uniform_(0.0, 1.0)
            m.bias.data.fill_(0)
     # custom weights initialization called on netG and netD
	def weights_init(m):
	    classname = m.__class__.__name__
	    if classname.find('Conv') != -1:
	        nn.init.normal_(m.weight.data, 0.0, 0.02)
	    elif classname.find('BatchNorm') != -1:
	        nn.init.normal_(m.weight.data, 1.0, 0.02)
	        nn.init.constant_(m.bias.data, 0)

    model_uniform = Net()
    model_uniform.apply(weights_init_uniform)

参考:

  1. 如何自定义参数初始化方式
  2. pytorch 官方文档

总结

如何对多层网络进行初始化,没有自己造过轮子,记录学习一下。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值