(1) Flatten class
import torch
import torch.nn as nn

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
    def forward(self, input):
        # keep the batch dimension, flatten all remaining dimensions
        return input.view(input.size(0), -1)
Flatten() performs the flatten operation and is used right before entering the fully-connected layers.
Only classes (nn.Module subclasses) can be placed inside nn.Sequential, which is why Flatten is written as a class (see the sketch below).
nn.ReLU ---- class (usable inside nn.Sequential)
F.relu ---- function (called inside forward)
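A minimal sketch of the custom Flatten inside nn.Sequential; the conv/linear sizes here are made up for illustration (recent PyTorch versions also ship a built-in nn.Flatten):

net = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),                     # class form, so it fits in nn.Sequential
    Flatten(),                     # our custom class from above
    nn.Linear(16 * 28 * 28, 10),   # 16 channels * 28x28 spatial positions
)
x = torch.randn(4, 1, 28, 28)      # batch of 4 single-channel 28x28 images
out = net(x)                       # shape: [4, 10]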
(2) Custom Linear layer
class MyLinear(nn.Module):
    def __init__(self, inp, outp):
        super(MyLinear, self).__init__()
        # nn.Parameter registers w and b as trainable module parameters
        self.w = nn.Parameter(torch.randn(outp, inp))
        self.b = nn.Parameter(torch.randn(outp))
    def forward(self, x):
        x = x @ self.w.t() + self.b
        return x
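A quick usage check under arbitrary shapes (10 input features, 5 output features, batch of 4):

layer = MyLinear(10, 5)
x = torch.randn(4, 10)
out = layer(x)
print(out.shape)        # torch.Size([4, 5])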
A plain torch.tensor assigned as a module attribute is NOT added to the module's parameters.
nn.Parameter wraps a tensor so that it is automatically registered in the module's parameters and thus optimized; nn.Parameter also sets requires_grad to True by default.
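A small sketch of the difference; the module and shape names are illustrative:

# wrapped in nn.Parameter -> registered and trainable
for name, p in MyLinear(10, 5).named_parameters():
    print(name, p.shape, p.requires_grad)
# w torch.Size([5, 10]) True
# b torch.Size([5]) True

# a plain tensor attribute, by contrast, is invisible to parameters()
class NotRegistered(nn.Module):
    def __init__(self):
        super(NotRegistered, self).__init__()
        self.w = torch.randn(5, 10)   # plain tensor, not registered

print(list(NotRegistered().parameters()))   # []

# only registered parameters are what the optimizer updates
import torch.optim as optim
optimizer = optim.SGD(MyLinear(10, 5).parameters(), lr=1e-2)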