"""Demonstrate custom weight initialization for the Linear layers of a small MLP."""

import torch
import torch.nn as nn
import torch.nn.functional as F


def initialize_weights(
    self: nn.Module, *, mean: float = 0.0, std: float = 0.01
) -> None:
    """Re-initialize every ``nn.Linear`` submodule of a model in place.

    Linear weights are drawn from a normal distribution with the given
    ``mean`` and ``std``, and Linear biases are set to zero. Submodules of
    any other type are left untouched.

    Args:
        self: The model whose submodules are visited. This is a free function,
            not a method; the parameter name is kept for backward compatibility.
        mean: Mean of the normal distribution used for the weights.
        std: Standard deviation of the normal distribution used for the weights.
    """
    # modules() yields the model itself plus all nested submodules.
    for m in self.modules():
        if isinstance(m, nn.Linear):
            # nn.init functions update the parameter in place without recording
            # autograd history, so going through ``.data`` is unnecessary.
            nn.init.normal_(m.weight, mean=mean, std=std)
            # Layers built with bias=False have ``bias is None``; skip them
            # instead of crashing.
            if m.bias is not None:
                nn.init.zeros_(m.bias)


class Net(nn.Module):
    """Three-layer fully connected network: 2 inputs -> 10 -> 10 -> 1 output."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(2, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` with last dimension 2 to an output with last dimension 1.

        ReLU follows the first two layers; the final layer has no activation.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Demo: build the network, apply the custom initialization, and print the
# resulting parameters of every Linear layer.
model = Net()
initialize_weights(model)
for layer in model.modules():
    if isinstance(layer, nn.Linear):
        print(f"weight = {layer.weight}")
        print(f"bias = {layer.bias}")
PyTorch 模型的权重初始化代码
最新推荐文章于 2024-06-09 10:36:42 发布