pytorch中的线性模块的实现如下,在init函数中定义weight值和bias值。
class Linear(Module):
__constants__ = ['bias', 'in_features', 'out_features']
def __init__(self, in_features, out_features, bias=True):
super(Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def forward(self, input):
return F.linear(input, self.weight, self.bias)
所以若要对linear子模块的参数进行初始化,利用如下策略可以对单个linear子模块进行参数初始化。
import torch.nn as nn
from torch.nn import init
from collections import OrderedDict
net = nn.Sequential(OrderedDict([
('linear', nn.Linear(num_inputs, 1))

最低0.47元/天 解锁文章
2619

被折叠的 条评论
为什么被折叠?



