持续更新中......
本文是在学习过程中的多篇文章整理,原文连接如下:
1、使用nn.Module模块的好处的原文链接:https://medium.com/dejunhuang/learning-day-22-what-is-nn-module-in-pytorch-ecf8400f411a
2、Pytorch官网关于nn.Module模块的链接:Module — PyTorch 2.3 documentation
一、nn.Module 是什么
nn.Module 类的源码链接:torch.nn.modules.module — PyTorch 2.3 documentation
在PyTorch中,`nn.Module`是一个非常核心的类,它是构建神经网络模型的基础模块,所有自定义的神经网络组件都应该继承自`nn.Module`类,在这个类中提供了很多好的方法,可以拿来直接使用。这个类提供了以下几个关键功能:
1. 组织结构:允许你以层次化的方式组织神经网络层(例如全连接层`Linear`、卷积层`Conv2d`、激活函数等)。这意味着你可以定义包含多个子模块的复杂模型,并轻松管理它们之间的连接。
2. 参数管理:自动跟踪和管理模型中的所有可学习的参数(比如权重和偏置)。当你调用`.parameters()`方法时,它会收集模型及其子模块的所有参数,这对于梯度更新和优化器非常重要。
3. 前向传播:你需要重写`forward`方法来定义模型的前向传播行为。在这个方法中,你可以使用PyTorch的tensor操作和已定义的子模块来处理输入数据并产生输出。
4. 设备分配:`nn.Module`使得在CPU或GPU上运行模型变得简单。当你将模型或其输入移动到特定设备(如`.to(device)`)时,模型中的所有参数和缓冲区也会被移动到该设备上。
5. 序列化与加载模型:`nn.Module`支持模型的保存与加载功能,使得模型的持久化和迁移变得容易。
示例代码展示了一个简单的`nn.Module`子类定义:
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self, input_size, output_size):
super(SimpleNet, self).__init__()
self.fc = nn.Linear(input_size, output_size)
def forward(self, x):
return self.fc(x)
# 实例化并使用模型
model = SimpleNet(100, 10)
output = model(torch.randn(1, 100))
在这个例子中,`SimpleNet`类继承自`nn.Module`,定义了一个具有单个全连接层的简单神经网络。通过重写`__init__`和`forward`方法,我们可以自定义模型的结构和前向传播逻辑。
二、使用 nn.Module 的好处
1、nn.Module 可以作为基础类被其他模型类继承
import torch
import torch.nn as nn
class BasicNet(nn.Module):
def __init__(self):
super(BasicNet, self).__init__()
self.net = nn.Linear(4, 3)
def forward(self, x):
return self.net(x)
2、每一层实际上都是 nn.Module(nn.Linear、nn.BatchNorm2d、nn.Conv2d、......)
3、可以嵌套,例如 nn.Module 可以位于另一个 nn.Module 内
# BasicNet()为上面定义的模型类,该类继承nn.Module
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = nn.Sequential(BasicNet(), nn.ReLU(), nn.Linear(3, 2))
def forward(self, x):
return self.net(x)
4、嵌入层,如 Linear、ReLU、sigmoid、Conv2d、Dropout……
5、model.parameters()/model.named_parameters() 包含所有权重和偏差(以及名称)。前者可以传递给优化器进行反向传播
net = Net()
print(dict(net.named_parameters()).items())
6、可以轻松转移到 GPU
device = torch.device('cuda')
net = Net()
net.to(device) # return net, so it is in-place
7、通过保存权重和偏差(称为状态)轻松保存和加载模型
net.load_state_dict(torch.load('ckpt.mdl'))
torch.save(net.state_dict(), 'ckpt.mdl')
8、轻松在训练和评估模式之间切换
net.train()
net.eval()
9、实现将 forward() 包含在 nn.Sequential() 中的函数,得到更简化的实现
# updated: self implement here but Pytorch 1.8.1 included it already
# if not, the below nn.Sequential() will need to break into two parts before and after flattening the layer by other methods
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
return input.reshape(input.shape[0], -1)
class TestNet(nn.Module):
def __init__(self):
super(TestNet, self).__init__()
self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),
nn.MaxPool2d(2,2),
Flatten(),
nn.Linear(1*14,14, 10))
def forward(self, x):
return self.net(x)
10、使用 nn.Parameter(tensor) 加载参数,以便将 tensor 包含在 model.parameters() 中,并输入到优化器
class MyLinear(nn.Module):
def __init__(self, input, output):
super(MyLinear, self).__init__()
# use nn.Parameter() instead of directly creating tensors
self.w = nn.Parameter(torch.rand(output, input))
self.b = nn.Parameter(torch.randn(output))
def forward(self, x):
x = x @ self.w.t() + self.b
return x
三、nn.Module类中的常用方法
1、Module.forward()
forward(*input)
定义每次调用时执行的计算。应被所有子类覆盖。
尽管需要在此函数中定义前向传递的配方,但是之后应该调用 Module 实例而不是这个,因为前者负责运行已注册的钩子,而后者会默默地忽略它们。