torch.nn.Module
是 PyTorch 中一个重要的基类,用于构建神经网络模型。它提供了一种方便的方式来组织和管理模型参数、定义前向传播等功能。继承自 torch.nn.Module
的类可以被视为一个可训练的参数集合,可以包含其他模块,从而形成层次化的模型结构。
文章目录
一、关键功能和属性
1.1 参数管理
torch.nn.Module
可以追踪并管理所有注册的参数。通过 parameters()
方法,可以方便地获取模型中的所有参数。
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
model = MyModel()
print(list(model.parameters()))
1.2 子模块管理
通过将其他 torch.nn.Module
的实例注册为当前模块的属性,可以形成层次化的模型结构。这使得模型可以以更模块化的方式进行定义。
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
model = MyModel()
1.3 前向传播定义
在继承 torch.nn.Module
的子类中,需要实现 forward
方法来定义模型的前向传播过程。
import torch.nn as nn
import torch
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = MyModel()
input_data = torch.randn(3, 10)
output = model(input_data)
1.4 模型保存和加载
模型可以方便地保存到文件并在需要时加载。这是通过 torch.save
和 torch.load
函数来实现的。
torch.save(model.state_dict(), 'my_model.pth')
loaded_model = MyModel()
loaded_model.load_state_dict(torch.load('my_model.pth'))
1.5 模型训练
由于继承了 torch.nn.Module
,模型可以使用 PyTorch 的优化器进行训练。
import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
# 在训练循环中使用 optimizer 和 criterion
二、使用案例
2.1 定义模型
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
2.2 定义不更新的参数
self.register_buffer
定义一组参数,参数不会变化,只可人为地改变它们的值),但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。
import torch
import torch.nn as nn
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# (1)常见定义模型时的操作
self.param_nn = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(1, 1, 3, bias=False)),
('fc', nn.Linear(1, 2, bias=False))
]))
# (2)使用register_buffer()定义一组参数
self.register_buffer('param_buf', torch.randn(1, 2))
# (3)使用形式类似的register_parameter()定义一组参数
self.register_parameter('param_reg', nn.Parameter(torch.randn(1, 2)))
# (4)按照类的属性形式定义一组变量
self.param_attr = torch.randn(1, 2)
def forward(self, x):
return x
net = Model()
三、内置函数
3.1 add_module
将子模块添加到当前模块。
3.2 apply
对当前模块及其所有子模块递归地应用函数 fn
。
@torch.no_grad()
def init_weights(m):
print(m)
if type(m) == nn.Linear:
m.weight.fill_(1.0)
print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
3.3 bfloat16
将所有浮点型参数和缓冲区(buffers)的数据类型转换为bfloat16。
3.4 **buffers
返回一个模块缓冲区的迭代器。
- 参数:
recurse
(可选)表示是否递归地遍历所有子模块的缓冲区。 - 用法:
model.buffers()
返回模块及其所有子模块的缓冲区。
3.5 children
返回一个模块的直接子模块的迭代器。
3.6 compile
使用 torch.compile()
编译模块的前向传播。使用 TorchDynamo
和指定的后端来优化给定的模型或函数。
3.7 cpu
将模型的所有参数和缓冲区移动到 CPU 上。
3.8 cuda
将模型的所有参数和缓冲区移动到 GPU 上。
3.9 double
将所有浮点数参数和缓冲区转换为双精度数据类型(double)。
3.10 eval
使用 eval()
将模块设置为评估(evaluation)模式。在评估模式下,模型中的某些层(例如,Dropout)的行为可能会有所不同,通常用于模型推断阶段。