Pytorch 模型参数管理
1. 参数初始化
- 使用
self.modules
,声明网络时初始化并加载权重
pytorch
模型应是 nn.Module
的子类
self.modules
: nn.Module
类中的一个方法, 返回该网络中的所有 modules
可以利用 self.modules
来对网络进行初始化。
class Network(nn.Module):
def __init__(self):
supe().__init__()
self.Conv2d = nn.Conv2d(3, 10)
sefl.bn = nn.BatchNorm2d(10)
self.relu = nn.ReLU()
self._init_weight() #在初始化网络时, 会执行该函数,然后初始化网络中的每个module
def forward(self, x):
x = self.Conv2d(x)
x = self.bn(x)
return self.relu(x)
def _init_weight(self):
for m in self.modules() #继承nn.Module的方法
if isinstance(m, nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
- 先定义网络,后加载权重, 使用
net.apply()
import torch.nn a