PyTorch 14. module类

module类

pytorch中其实一般没有特别明显的Layer和Module的区别,不管是自定义层、自定义块、自定义模型,都是通过继承Module类完成的,其实Sequential类也是继承Module类。

module类的定义:

class Module(object):
	def __init__(self):
	def forward(self, *input):
	
	def add_module(self,name,module):
	def cuda(self,device=None):
	def cpu(self):
	def __call__(self, *input, **kwargs):
	def parameters(self, recurse=True):
	def named_parameters(self, prefix='', recurse=True):
	def children(self):
	def named_children(self):
	def modules(self):
	def named_modules(self,memo=None,prefix=''):
	def train(self,mode=True):
	def eval(self):
	def zero_grad(self):
	def __repr__(self):
	def __dir__(self):

我们在定义自己的网络的时候,需要继承nn.Module类,并重新实现构造函数__init__构造函数和forward这两个方法。
注意技巧
(1)一般把网络中具有可学习参数的层(如全连接层,卷积层等)放在构造函数__init__()中,当然我也可以把不具有参数的层也放在里面
(2)一般把不具有可学习参数的层(如ReLU, dropout, BatchNormalnation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
(3)forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心

注意:只要在module中__init__定义的层,即使forward中没有使用到,模型保存时,也会保存该网络层,因此如果不需要的东西就不要放到__init__

torch.nn.Module类的多种实现

方法1:通过Sequential层来包装层,即将几个层包装在一起作为一个大的层

import torch.nn as nn
from collections import OrderedDict
class MyNet(nn.Module):
	def __init__(self):
		super(MyNet, self).__init__()
		self.conv_block = nn.Sequential(
			nn.Conv2d(3, 32, 3,1,1),
			nn.ReLU(),
			nn.MaxPool2d(2)
		)
		self.dense_block = nn.Sequential(
			nn.Linear(32*3*3, 128),
			nn.ReLU(),
			nn.Linear(128,10)
		)
	def forward(self,x):
		conv_out = self.conv_block(x)
		res = conv_out.view(conv_out.size(0),-1)
		out = self.dense_block(res)
		return out

model = MyNet()
print(model)

这里在每一个包装块里面,各个层是没有名称的,默认按照0,1,2,3,4排序。
方法2

import torch.nn as nn
from collections import OrderedDict
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv_block = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(3, 32, 3, 1, 1)),
                    ("relu1", nn.ReLU()),
                    ("pool", nn.MaxPool2d(2))
                ]
            ))
 
        self.dense_block = nn.Sequential(
            OrderedDict([
                ("dense1", nn.Linear(32 * 3 * 3, 128)),
                ("relu2", nn.ReLU()),
                ("dense2", nn.Linear(128, 10))
            ])
        )
 
    def forward(self, x):
        conv_out = self.conv_block(x)
        res = conv_out.view(conv_out.size(0), -1)
        out = self.dense_block(res)
        return out
 
model = MyNet()
print(model)

方法3

import torch.nn as nn
from collections import OrderedDict
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv_block=torch.nn.Sequential()
        self.conv_block.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
        self.conv_block.add_module("relu1",torch.nn.ReLU())
        self.conv_block.add_module("pool1",torch.nn.MaxPool2d(2))
 
        self.dense_block = torch.nn.Sequential()
        self.dense_block.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
        self.dense_block.add_module("relu2",torch.nn.ReLU())
        self.dense_block.add_module("dense2",torch.nn.Linear(128, 10))
 
    def forward(self, x):
        conv_out = self.conv_block(x)
        res = conv_out.view(conv_out.size(0), -1)
        out = self.dense_block(res)
        return out
 
model = MyNet()
print(model)

上面的方法2和方法3,在每一个包装块中,每个层都是有名称的。
特别注意:Sequential类虽然继承自Module类,二者有相似部分,但是也有很多不同的部分,集中体现在:

Sequential类实现了整数索引,故而可以使用model[index]这样的方式获取一个层,但是Module类并没有实现整数索引,不能够通过整数索引来获得层,但是它提供了几个主要方法,如下:
具体内容可以查看我的这一篇博客

children和modules之间的差异性

注意pytorch里面不管是模型、层、激活函数、损失函数都可以当成是Module的拓展,所以modules和named_modules会层层迭代,由浅入深,将每个自定义块block、然后block里面的每个层都当成是module来迭代,而children就比较直观,就表示所谓的“孩子”,所以没有层层深入

参考:
https://zhuanlan.zhihu.com/p/156127643

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值