pytorch系列 ---9的番外, Xavier和kaiming是如何fan_in和fan_out的,_calculate_fan_in_and_fan_out解读 Conv2d

本文详细介绍了如何在PyTorch中计算线性层(Linear)和卷积层(Conv2d)的fan_in和fan_out。通过代码示例,解释了当权重张量维度为2时,对应线性层的fan_in等于输入通道数,fan_out等于输出通道数;而在卷积层中,fan_in和fan_out涉及卷积核大小,并通过计算权重张量特定维度的元素数量来确定。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文主要借助代码讲解Xavier和kaiming是如何借助_calculate_fan_in_and_fan_out函数来计算当前网络层的fan_in(输入神经元个数)和fan_out(输出神经元个数的),先针对Linear和Conv2d两种。

在这里插入图片描述

m_c = nn.Conv2d(16, 33, 3, stride=2)
m_l = nn.Linear(1, 10)
m_c.weight.size()
m_l.weight.size()

out:

torch.Size([33, 16, 3, 3])
torch.Size([10, 1])

注意看Linear weight的维度为2,而Conv2d的维度为4.
首先判断tensor的维度,如果是二维,则是Linear,

if dimensions == 2:  # Linear
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)

此时:fan_in=in_channelsfan\_in = in\_channels fan_in=in_channels
fan_out=out_channelsfan\_out = out\_channels fan_out=out

<think>嗯,用户遇到了一个ValueError,说是不支持hardswish这个非线性函数的初始化。我需要仔细看看他们的代码,找出问题所在。 首先,用户定义了一个MLP类,里面的_init_weights方法负责初始化权重。在检查每个模块时,如果是nn.Linear层,并且该层有一个act属性,那么会根据不同的激活函数类型应用不同的初始化方法。对于nn.Hardswish的情况,他们使用了nn.init.kaiming_normal_,并设置了nonlinearity=&#39;hardswish&#39;。但是PyTorchKaiming初始化可能不支持这个参数,导致错误。 我记得在PyTorch的版本更新中,某些激活函数的支持是逐步添加的。比如,可能在较早的版本中,Kaiming初始化并不支持hardswish这个非线性参数。需要确认用户使用的PyTorch版本是否过旧。假设用户用的是1.6之前的版本,那么确实会出现这个问题,因为Hardswish是在1.6版本之后才被引入的。 解决方法有几个方面。首先,建议用户升级PyTorch到1.6或更高版本,这样官方应该已经支持了hardswish的初始化。如果升级不可行,可能需要修改初始化方法,比如改用其他支持的nonlinearity参数,比如&#39;relu&#39;,或者参考PyTorch文档中支持的选项。另外,用户也可以手动计算gain值,使用xavier初始化代替,避免依赖Kaiming的nonlinearity参数。 另外,用户的代码里在初始化时使用了不同的方法,比如XavierKaiming,这取决于激活函数类型。对于自定义的GaussSwish,用了Xavier,这可能没问题。但Hardswish的初始化可能需要特定的处理。如果用户不能升级PyTorch,可能需要调整初始化策略,或者自定义初始化方式。 总结下来,问题出在旧版PyTorch不支持hardswish作为Kaiming初始化的参数。解决方案包括升级PyTorch、改用其他初始化方法,或者手动调整参数。需要将这些可能性清晰地解释给用户,并给出具体的代码修改建议。</think>这个错误是因为在PyTorch旧版本中,`kaiming_normal_`初始化方法不支持`hardswish`非线性参数。让我们分步骤分析并解决这个问题: --- ### **错误原因分析** 1. **版本兼容性问题**: - `nn.Hardswish`在PyTorch 1.6+ 才被正式支持 - `kaiming_normal_`的`nonlinearity`参数在旧版本中可能未包含`hardswish`选项 2. **代码逻辑问题**: - 当检测到`nn.Hardswish`激活函数时,尝试使用`nonlinearity=&#39;hardswish&#39;`会导致参数不匹配 --- ### **解决方案** #### **方法一:升级PyTorch(推荐)** ```bash # 升级到最新稳定版(需>=1.6) pip install torch --upgrade ``` #### **方法二:修改初始化策略** ```python elif isinstance(m.act, nn.Hardswish): # 改用PyTorch支持的初始化参数 nn.init.kaiming_normal_(m.weight, mode=&#39;fan_in&#39;, nonlinearity=&#39;relu&#39;) # 临时替代方案 # 或使用Xavier初始化兼容方案 nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain(&#39;relu&#39;) * 1.7) ``` --- ### **技术原理说明** 1. **初始化与激活函数的关系**: - Xavier初始化:适合Sigmoid/Tanh等饱激活函数,数学表达式:$W \sim \mathcal{U}(-\sqrt{\frac{6}{n_{in}+n_{out}}}, \sqrt{\frac{6}{n_{in}+n_{out}}})$ - Kaiming初始化:更适合ReLU族非饱激活函数,数学表达式:$W \sim \mathcal{N}(0, \sqrt{\frac{2}{n_{in}}})$ 2. **Hardswish特性**: - 定义:$f(x) = x \cdot \frac{\text{ReLU6}(x+3)}{6}$ - 与ReLU的梯度特性相似,因此可暂时借用ReLU的初始化参数 --- ### **改进后的完整代码** ```python class MLP(nn.Module): def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): if hasattr(m, &#39;act&#39;): if isinstance(m.act, nn.Tanh): nn.init.xavier_normal_(m.weight, gain=5/3) elif isinstance(m.act, nn.Hardswish): # 兼容旧版本方案 nn.init.kaiming_normal_(m.weight, mode=&#39;fan_in&#39;, nonlinearity=&#39;relu&#39;) # 或使用修正后的Xavier # nn.init.xavier_normal_(m.weight, gain=1.7) elif isinstance(m.act, GaussSwish): nn.init.xavier_normal_(m.weight, gain=1.676) else: nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) ``` --- ### **验证建议** 1. 打印PyTorch版本确认: ```python import torch print(torch.__version__) # 需要>=1.6.0 ``` 2. 测试初始化效果: ```python layer = nn.Linear(10, 10) layer.act = nn.Hardswish() MLP()._init_weights() # 应不再报错 ``` --- ### **扩展知识** 对于自定义激活函数(如`GaussSwish`),建议通过实验确定最佳增益值: 1. 增益计算公式:$gain = \sqrt{\frac{2}{\text{Var}(f&#39;(x))}}$ 2. 可通过蒙特卡洛采样近似计算: ```python x = torch.randn(10000) output = custom_act(x) gain = 1 / torch.std(output) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值