0. 导入用到的库
import torch
import torch.nn as nn
from collections import OrderedDict
1. 继承Module构造和访问模型
Module
类是
torch.nn
模块里提供的⼀个模型构造类,是所有神经⽹络模块的基类,我们可以继承它来定义
我们想要的模型。
只需重写 forward方法
class MLP(nn.Module):
def __init__(self, **kwargs):
super(MLP,self).__init__(**kwargs)
self.hidden = nn.Linear(784, 256)
self.act = nn.ReLU()
self.output = nn.Linear(256, 10)
def forward(self, x):
a = self.act(self.hidden(x))
return self.output
查看模型
X = torch.rand(2, 784)
net = MLP()
print(net)
net(X)
输出