Pytorch搭建神经网络组件以及动态搭建详解

网络结构搭建组建

搭建参数层:class torch.nn.Module
定义的是一个类,通常用来定义具有参数的layer,比如卷积层,全连接层,会自动提取可学习参数nn.Parameter。但是dropout层和batch_norm层也用此模块来实现,因为可以通过方法model.eval来判别是训练模式还是测试模式。在这里定义无参数层也是可以的,但是会增加微小的调用开销,但是可以使得网络的层次更清晰。
搭建无参数层:torch.nn.functional
定义了一个函数,用来实现特定的功能。常用来定义无学习参数的layers,比如激励函数,pooling层等。会发现这里面其实也会有卷积函数, 上面用类torch.nn.Conv2d定义的卷积层实质就是调用的nn.functional的卷积函数。

搭建block:class torch.nn.Sequential( * args)
一个时序容器。nn.Modules 会以他们传入的顺序被添加到容器中。当然,也可以传入一个OrderedDict,相当于搭建一个block。有以下3种方法搭建nn.Sequential

1)nn.Sequential()对象.add_module(层名,层class的实例)
net1 = nn.Sequential()
net1.add_module('conv', nn.Conv2d(3, 3, 3))
net1.add_module('batchnorm', nn.BatchNorm2d(3))
net1.add_module('activation_layer', nn.ReLU())
--------------------- ------------2)nn.Sequential(*多个层class的实例)
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )
--------------------------------------------3)nn.Sequential(OrderedDict([*多个(层名,层class的实例)]))
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

动态搭建DNN

模块列表class torch.nn.ModuleList(modules=None)
实质将多个nn.Modules保存到一个list中,ModuleList中包含的所有modules必须已经被pytorch正确的注册。ModuleList可以像一般的Python list一样被int下标索引,且具有append和extend方法。
此外,ModuleList可以通过迭代依次的模块,从而相当于搭建了一个DNN。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.dnn = nn.ModuleList(
            [nn.Linear(10, 10) for i in range(2)])

    def forward(self, x):
        for layer in self.dnn:
            x = layer(x)
        return x

net = MyModule()
net.dnn.append(nn.ReLU(inplace=False)) # 动态修改
print(net.dnn)
print(net(torch.Tensor(np.arange(10))))

输出:

ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): ReLU()
)
tensor([ 0.7821,  0.0000,  0.0000,  0.8238,  1.9208,  0.0000,  3.8621,
         4.2936,  4.8779,  0.0000])

修改中间的模块:

net = MyModule()
net.dnn = net.dnn[0: 2].append\
    (nn.ReLU(inplace=False)).extend(net.dnn[2: ])
print(net.dnn)

输出:

ModuleList(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): Linear(in_features=10, out_features=10, bias=True)
  (2): ReLU()
  (3): Linear(in_features=10, out_features=10, bias=True)
  (4): Linear(in_features=10, out_features=10, bias=True)
)

可以插入nn.Sequential对象

block = nn.Sequential(nn.ReLU())
net = MyModule()
net.dnn = net.dnn[0: 2].append(block).extend(net.dnn[2: ])
已标记关键词 清除标记
相关推荐
©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页