1. load()函数
torch.load()
是 PyTorch 中用于从文件中加载序列化对象的函数。它可以用于加载模型、张量、字典等 PyTorch 对象。torch.load()
的一般用法如下:
torch.load(f, map_location=None, pickle_module=<module 'pickle' from '...'>)
f
:要加载的文件的路径(字符串)或文件对象(通常是打开的文件)。map_location
:一个可选参数,用于指定在加载时将张量映射到的设备。默认为None
,表示张量将被加载到它们在保存时所在的设备上。pickle_module
:一个可选参数,用于指定要用于反序列化的 Pickle 模块。默认为 Python 内置的pickle
模块。
以下是一些示例用法:
import torch
# 示例1:加载模型
model = torch.load('model.pth')
# 示例2:加载模型并映射到CPU
model_cpu = torch.load('model.pth', map_location=torch.device('cpu'))
# 示例3:加载字典
dictionary = torch.load('dictionary.pkl')
# 示例4:加载张量
tensor = torch.load('tensor.pth')
请注意,torch.load()
函数的行为取决于文件中保存的内容。如果是模型,加载后的对象可以直接用于推理或继续训练。如果是张量或字典,你将获得相应的 Python 对象。
在使用 torch.load()
时,确保文件路径或对象正确,否则可能会导致加载失败。此外,确保在加载时设置的参数(例如设备映射)符合你的预期。
2. eval()函数
在 PyTorch 中,eval()
是一个模型对象(如 nn.Module
)上的方法,用于将模型切换到评估(evaluation)模式。这个方法主要用于在推理或验证阶段使用模型。
当调用 eval()
方法时,模型会切换为评估模式,这会影响一些层的行为。主要的变化包括:
-
Batch Normalization 和 Dropout 层的行为变化:
- 在训练过程中,Batch Normalization 层和 Dropout 层通常是启用的,以便进行模型训练时的归一化和随机失活。在评估模式中,这两个层通常会切换为禁用状态,以确保在推理时不引入不确定性。
-
梯度计算的关闭:
eval()
方法通常会关闭模型的梯度计算。这意味着在评估模式下,模型的参数将不再接受梯度更新,这有助于节省内存并提高推理速度。
示例用法如下:
import torch
import torchvision.models as models
# 创建一个模型(以ResNet为例)
model = models.resnet18()
# 切换为评估模式
model.eval()
# 使用模型进行推理或验证
with torch.no_grad():
output = model(torch.randn(1, 3, 224, 224))
在上述示例中,model.eval()
将模型切换为评估模式,然后可以使用 model
进行推理。在推理阶段,通常使用 torch.no_grad()
上下文管理器来禁用梯度计算,以减少内存占用。
总之,eval()
方法在模型切换到评估模式时会修改一些层的行为,使其适用于推理和验证任务。在训练和评估阶段交替进行时,使用 train()
方法切换回训练模式。