nn.Module
是在pytorch使用非常广泛的类,搭建网络基本都需要用到这个。
当我们搭建自己的网络时,可以继承官方写好的nn.Module
模块,为什么要用这个呢?好处如下:
nn.Module作用
1.可以提供一些现成的基本模块比如:
Linear、ReLU、Sigmoid、Conv2d、Dropout
不用自己一个一个的写这些函数了,这也是为什么我们用框架的原因之一吧。
2. 容器
比如我们经常用到的 nn.Sequential()
,顾名思义,将网络模块封装在一个容器中,可以方面网络搭建
如下面一个例子:
class TestNet(nn.Module):
def __init__(self):
super(TestNet, self).__init__()
self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),
nn.MaxPool2d(2, 2),
Flatten(),
nn.Linear(1*14*14, 10))
def forward(self, x):
return self.net(x)
3.参数管理
参数名字可以自动生成(想想如果自己去命名,百万参数的网络没法搭建),然后这些参数都可以传到优化器里面去优化
4. 所有modules的节点 孩子节点都是直系的
class BasicNet(nn.Module):
def __init__(self):
super(BasicNet, self).__init__(