nn.Module
用nn.Module实现全连接层
import torch as t
from torch import nn
from torch.autograd import Variable as V
class Linear(nn.Module):
    """Fully connected (affine) layer computing ``y = x @ w + b``.

    A hand-written counterpart of ``nn.Linear`` built directly on
    ``nn.Module``. Note that ``w`` is stored with shape
    ``(in_features, out_features)`` -- the transpose of ``nn.Linear.weight`` --
    so the forward pass multiplies ``x @ w`` rather than ``x @ w.T``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, in_features, out_features):
        # nn.Module.__init__ must run before any nn.Parameter is assigned,
        # otherwise the module's parameter registry does not exist yet.
        super().__init__()
        # Wrapping the tensors in nn.Parameter registers them as learnable
        # parameters, so they appear in parameters()/named_parameters().
        self.w = nn.Parameter(t.randn(in_features, out_features))
        self.b = nn.Parameter(t.randn(out_features))

    def forward(self, x):
        """Apply the affine map to ``x``.

        Args:
            x: tensor of shape ``(..., in_features)``; any number of leading
                batch dimensions (including none) is accepted.

        Returns:
            Tensor of shape ``(..., out_features)``.
        """
        # matmul (unlike mm) supports leading batch dimensions, and the bias
        # broadcasts over them, so no explicit expand_as is needed.
        return x.matmul(self.w) + self.b
# Demo: build a 4 -> 3 layer and run a batch of two samples through it.
layer = Linear(4, 3)
# torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4, so the
# tensor is used directly. The name `input` shadows the builtin but is kept
# in case later cells refer to it.
input = t.randn(2, 4)
print(input.shape)
output = layer(input)  # nn.Module.__call__ dispatches to forward()
print(output)
# named_parameters() yields (attribute name, Parameter) pairs, e.g. "w", "b".
for name, parameter in layer.named_parameters():
    print(name, parameter)
- 自定义层必须继承nn.Module,并在构造函数中调用nn.Module的构造函数,即`super(Linear, self).__init__()`,可利用前面自定义的层作为当前m