深度学习技术栈 —— Pytorch中保存与加载权重文件

一、权重文件的格式

权重文件是指训练好的模型参数文件,不同的深度学习框架和模型可能使用不同的权重文件格式。以下是一些常见的权重文件格式:
PyTorch的模型格式:.pt文件。
Darknet的模型格式:.weight文件。
TensorFlow的模型格式:.ckpt文件。

一、参考文章或视频链接
[1] Navigating Model Weight File Formats: .safetensors, .bin, .pt, HDF5, and Beyond

---------------------------------------------------------------------------------------------------------------------------------------------------------------------

二、保存权重文件

没有人想让自己的辛苦白费,计算好的权重文件就应该保存下来,不仅方便自己,也方便他人。

# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

# 保存权重的方法torch.save
torch.save(model.state_dict(), PATH)
二、参考文章或视频链接
[1] SAVING AND LOADING MODELS - Pytorch
[2] MODELS AND PRE-TRAINED WEIGHTS - Pytorch
[3] Introducing TorchVision’s New Multi-Weight Support API - Pytorch
[4] PyTorch Model Eval + Examples

三、加载权重文件

3.1 nn.module.eval()方法

在 PyTorch 中,nn.Module是一个非常重要的类,用于实现各种神经网络层和模型。在使用 nn.Module进行训练和推理时,有时需要将模型设置为评估模式,这可以通过调用eval()方法实现。在评估模式下,PyTorch 中的某些层和操作会发生一些变化,例如 Dropout 和 BatchNormalization 层会被禁用,因为它们在训练和推理时的行为是不同的。此外,在eval()下,模型不会进行梯度计算(这也是预训练的意义所在),这可以提高推理速度并减少内存使用。
使用eval()方法将模型设置为评估模式非常简单,只需要在模型实例上调用该方法即可,调用eval()后,既然不会进行梯度计算,那自然也不会更新权重了,要不然加载好了一个预训练好的权重模型,又被调整了那不是哭死

3.1 参考文章或视频链接
[1] 《model.eval 至关重要!!!!model.eval()是否开启 BN 和 Dropout 的不同》 - CSDN
[2] What’s the meaning of function eval() in torch.nn module - stackoverflow self.training=False
  • 25
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值