import torch.nn as nn
def init_weights(model):
for m in model.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight.data,
mode='fan_out',
nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
12-17
831
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
“相关推荐”对你有帮助么?
-
非常没帮助
-
没帮助
-
一般
-
有帮助
-
非常有帮助
提交