今天下午自己在哪里倒腾着,不小心写了个Module出来,当然没有写具体的运算。吓写的,但可以从这个不能运行的程序中,学会写模型应该怎么写
class mylayer(torch.nn.Module):
def __init__(self, inn=3, out=4):
super(mylayer, self).__init__()
self.inn = inn
self.out = out
self.weight = torch.nn.Parameter(torch.Tensor(3, 4))
self.bias = torch.nn.Parameter(torch.Tensor(4))
def forward(self,x):
return 2*x
class tt(torch.nn.Module):
def __init__(self):
super(tt, self).__init__()
self.linear = torch.nn.Linear(3,4)
self.mylayer = mylayer(inn=4,out=5)
def forward(self,x):
out = self.linear(x)
out = self.mylayer(out)
return out
aa = tt()
print(aa)
tt(
(linear): Linear(in_features=3, out_features=4, bias=True)
(mylayer): mylayer()
)
aa.state_dict(