深度学习技术栈 —— 如何保存权重文件?
一、权重文件的格式
权重文件是指训练好的模型参数文件,不同的深度学习框架和模型可能使用不同的权重文件格式。以下是一些常见的权重文件格式:
PyTorch
的模型格式:.pt
文件。
Darknet
的模型格式:.weight
文件。
TensorFlow
的模型格式:.ckpt
文件。
一、参考文章或视频链接 |
---|
[1] Navigating Model Weight File Formats: .safetensors, .bin, .pt, HDF5, and Beyond |
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
二、保存权重文件
没有人想让自己的辛苦白费,计算好的权重文件就应该保存下来,不仅方便自己,也方便他人。
# Define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
# 保存权重的方法torch.save
torch.save(model.state_dict(), PATH)
二、参考文章或视频链接 |
---|
[1] SAVING AND LOADING MODELS - Pytorch |
[2] MODELS AND PRE-TRAINED WEIGHTS - Pytorch |
[3] Introducing TorchVision’s New Multi-Weight Support API - Pytorch |
[4] PyTorch Model Eval + Examples |
三、加载权重文件
3.1 nn.module.eval()方法
在 PyTorch 中,nn.Module是一个非常重要的类,用于实现各种神经网络层和模型。在使用 nn.Module进行训练和推理时,有时需要将模型设置为评估模式,这可以通过调用eval()方法实现。在评估模式下,PyTorch 中的某些层和操作会发生一些变化,例如 Dropout 和 BatchNormalization 层会被禁用,因为它们在训练和推理时的行为是不同的。此外,在eval()下,模型不会进行梯度计算(这也是预训练的意义所在),这可以提高推理速度并减少内存使用。
使用eval()方法将模型设置为评估模式非常简单,只需要在模型实例上调用该方法即可,调用eval()后,既然不会进行梯度计算,那自然也不会更新权重了,要不然加载好了一个预训练好的权重模型,又被调整了那不是哭死。
3.1 参考文章或视频链接 |
---|
[1] 《model.eval 至关重要!!!!model.eval()是否开启 BN 和 Dropout 的不同》 - CSDN |
[2] What’s the meaning of function eval() in torch.nn module - stackoverflow self.training=False |