PyTorch模型格式及使用

模型是神经网络训练优化后得到的成果,包含了神经网络骨架及学习得到的参数。

 一、常用的模型的格式

模型保存的格式多种多样,主要取决于所使用的深度学习框架和具体需求。以下是一些常见的模型保存格式:

1.1. TensorFlow 相关格式

  • .ckpt:TensorFlow的模型保存格式,通常包含四个文件,包括.meta(模型的结构信息)、.ckpt(权重和偏置等变量的值)、.index.data(二进制格式,保存具体的权重数据)。
  • SavedModel:TensorFlow的一种序列化格式,包括模型的结构、权重和配置信息,通常保存在一个目录下,包含多个文件。这种格式便于部署和在不同环境中使用。
  • .pb (Protocol Buffers):Google的数据序列化协议,有时也用于保存机器学习模型。

1.2. Keras 相关格式

  • .h5 或 .hdf5:Keras(一个基于TensorFlow的高级神经网络API)常用的模型保存格式,是一个单一的文件,包含了模型的权重、结构和配置信息。使用model.save('model.h5')即可保存整个模型。

1.3. PyTorch 相关格式

  • .pt 或 .pth:PyTorch框架使用这种格式来保存模型的权重。此外,PyTorch也支持保存整个模型为一个文件,格式为.pt。使用torch.save(model.state_dict(), 'model_weights.pth')可以保存模型的权重,而torch.save(model, 'model.pth')则可以保存整个模型。pt 文件保存整个 PyTorch 模型,而 .pth 文件只保存模型的参数
  • TorchScript 是 PyTorch 中的一个特性,它允许你将 PyTorch 模型转换为一种中间表示(Intermediate Representation, IR),这种表示可以被优化并独立于 Python 运行。TorchScript 旨在提高 PyTorch 模型的部署性能、可移植性和可扩展性,尤其是在生产环境中。

1.4. ONNX 格式

  • .onnx:ONNX(Open Neural Network Exchange)是一种用于表示神经网络模型的开放式标准格式,它允许在不同的深度学习框架之间交换模型。可以使用专门的库(如tf2onnxonnx-tensorflow)将模型转换为ONNX格式。

1.5. 其他格式

  • .json:虽然不常用于直接保存模型权重,但JSON格式可以用于保存模型的配置或结构信息,因为它易于阅读和写入,并且可以在不同的编程语言之间轻松共享。
  • .t7 或 .pkl:在某些情况下,特别是在早期版本的PyTorch或torch7中,可能会使用.t7文件来读取模型权重,而.pkl则是Python中常用的序列化格式,也可以用于保存模型。

1.6选择模型保存格式的考虑因素

在选择模型保存格式时,应考虑以下因素:

  • 跨平台兼容性:确保所选格式能够在不同的操作系统和环境中轻松加载和使用。
  • 可移植性:考虑是否需要将模型迁移到不同的深度学习框架中。
  • 性能:不同格式在加载和保存时可能有不同的性能表现。
  • 文件大小:对于大型模型,文件大小可能是一个重要的考虑因素。
  • 安全性:在某些情况下,可能需要考虑模型数据的安全性,选择加密或受保护的格式。

综上所述,模型保存的格式多种多样,应根据具体需求和所使用的深度学习框架来选择最合适的格式。

1.7 为什么要约定格式?

根据上一段的思路,我们知道模型重现的关键是模型结构/模型参数/数据集,那么我们提供或者希望别人提供这些信息,需要一个交流的规范,这样才不会1000个人给出1000种格式,而 .pt .pth .bin 以及 .onnx 就是约定的格式

torch.save: Saves a serialized object to disk. This function uses Python’s  pickle utility for serialization. Models,  tensors, and dictionaries of all kinds of objects can be saved using this function.

不同的后缀只是用于提示我们文件可能包含的内容,但是具体的内容需要看模型提供者编写的README.md才知道。而在使用torch.load()方法加载模型信息的时候,并不是根据文件的后缀进行的读取,而是根据文件的实际内容自动识别的,因此对于torch.load()方法而言,不管你把后缀改成是什么,只要文件是对的都可以读取

torch.load: Uses  pickle’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into (see  Saving & Loading Model Across Devices).

顺便提一下,“一切皆文件”的思维才是正确打开计算机世界的思维方式,文件后缀只作为提示作用,在Windows系统中也会用于提示系统默认如何打开或执行文件,除此之外,文件后缀不应该成为我们认识和了解文件阻碍。

1.8 格式汇总

下面是一个整理了 .pt.pth.bin、ONNX 和 TorchScript 等 PyTorch 模型文件格式的表格:

格式解释适用场景可对应的后缀
.pt 或 .pthPyTorch 的默认模型文件格式,用于保存和加载完整的 PyTorch 模型,包含模型的结构和参数等信息。需要保存和加载完整的 PyTorch 模型的场景,例如在训练中保存最佳的模型或在部署中加载训练好的模型。.pt 或 .pth
.bin一种通用的二进制格式,可以用于保存和加载各种类型的模型和数据。需要将 PyTorch 模型转换为通用的二进制格式的场景。.bin
ONNX一种通用的模型交换格式,可以用于将模型从一个深度学习框架转换到另一个深度学习框架或硬件平台。在 PyTorch 中,可以使用 torch.onnx.export 函数将 PyTorch 模型转换为 ONNX 格式。需要将 PyTorch 模型转换为其他深度学习框架或硬件平台可用的格式的场景。.onnx
TorchScriptPyTorch 提供的一种序列化和优化模型的方法,可以将 PyTorch 模型转换为一个序列化的程序,并使用 JIT 编译器对模型进行优化。在 PyTorch 中,可以使用 torch.jit.trace 或 torch.jit.script 函数将 PyTorch 模型转换为 TorchScript 格式。需要将 PyTorch 模型序列化和优化,并在没有 Python 环境的情况下运行模型的场景。.pt 或 .pth

二、状态字典(state_dict)

2.1 什么是状态字典

状态字典(state_dict)是深度学习框架(如PyTorch)中用于保存和加载模型参数的一种数据结构。它本质上是一个Python字典对象,将模型中的每一层(特别是那些具有可学习参数的层,如卷积层、线性层等)映射到其对应的参数张量(即权重和偏差)。

在深度学习框架(如PyTorch)中,状态字典(state_dict)主要包含了模型的参数(parameters),这些参数是模型在训练过程中学习到的权重(weights)和偏差(biases)。具体来说,state_dict是一个从字符串到张量(Tensor)的映射(即Python字典),其中字符串是参数的名称(这些名称通常是层次结构的,以反映它们在模型中的位置),而张量则是参数的实际值。

除了模型的参数之外,虽然state_dict本身不直接包含优化器的状态或训练时的其他信息(如当前迭代次数、学习率等),但你可以将这些信息也保存在一个单独的字典中,并使用torch.save()函数将它们与模型的state_dict一起保存,或者分别保存。

然而,在PyTorch中,优化器对象(如torch.optim.Optimizer的子类)也拥有它们自己的state_dict,这个state_dict包含了优化器在训练过程中需要的所有状态信息,如动量(对于基于动量的优化器)、学习率衰减参数等。因此,在保存模型时,通常也会保存优化器的state_dict,以便在加载模型时能够恢复训练状态。

总结一下,state_dict通常包含以下与模型相关的内容:

  1. 模型参数:包括模型的权重和偏差,这些是模型学习到的关键信息。

  2. (可选)优化器状态:虽然这不是模型state_dict直接包含的内容,但通常会与模型参数一起保存,以便恢复训练过程。

  3. (可选)其他训练信息:如训练时的epoch数、损失值等,这些信息可以单独保存,或者包含在模型的检查点(checkpoint)中,但通常不是state_dict的直接组成部分。

PyTorch 中,一个模型(torch.nn.Module)的可学习参数(也就是权重和偏置值)是包含在模型参数(model.parameters())中的,一个状态字典就是一个简单的 Python 的字典,其键值对是每个网络层和其对应的参数张量。模型的状态字典只包含带有可学习参数的网络层(比如卷积层、全连接层等)和注册的缓存(batchnorm的 running_mean)。优化器对象(torch.optim)同样也是有一个状态字典,包含的优化器状态的信息以及使用的超参数

由于状态字典也是 Python 的字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都很容易实现。

lass CNN_Net(nn.Module):
    def __init__(self):
        super(CNN_Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建模型
net = CNN_Net().to(device)

print(net)

#定义优化器和损失函数
criterion = nn.CrossEntropyLoss() # 交叉式损失函数
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 优化器

# Print model's state_dict
print("Model's state_dict:")
# print(net.state_dict())
for param_tensor in net.state_dict():
    print(param_tensor, "\t", net.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
print(optimizer.state_dict())
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

输出结果:
CNN_Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

在PyTorch中,state_dict是一个非常关键的概念,它本质上是一个从每一层到其参数张量映射的字典对象。PyTorch中的模型(如神经网络)是由多个层(layers)组成的,而每一层都包含了一些需要学习的参数(如权重和偏置)。这些参数在训练过程中会被不断更新以优化模型性能。

state_dict是一个Python字典对象,它将每一层映射到其参数张量。注意,这里的“层”是广义的,它不仅仅指的是神经网络中的线性层或卷积层,还可以是优化器(optimizers)、批量归一化层(BatchNorm layers)等任何具有可学习参数的PyTorch模块。

2.2 主要用途

  1. 模型保存与加载:通过保存和加载state_dict,我们可以轻松地保存和加载模型的参数,而不需要重新训练模型。这对于模型部署、模型迁移和结果复现等场景非常有用。

  2. 模型微调(Fine-tuning):在迁移学习场景中,我们可能会加载一个预训练的模型,并修改其最后一层(或几层)以适应新的任务。这时,我们可以只加载预训练模型的state_dict,并更新或替换我们想要修改层的参数。

2.3 状态字典的保存和加载

在PyTorch中,通常使用torch.save()函数来保存state_dict,而使用load_state_dict()函数来加载state_dict。保存时,可以直接将state_dict作为参数传递给torch.save()函数,并指定保存的路径。加载时,则需要先加载保存的state_dict文件,并将其传递给模型的load_state_dict()函数。

假设我们有一个训练好的模型model,我们可以使用以下代码来保存其state_dict:

torch.save(model.state_dict(), 'model_weights.pth')

 然后,在需要加载这个模型参数的时候,我们可以使用以下代码:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。采用 torch.save() 来保存模型的状态字典的做法可以更方便加载模型,这也是推荐这种做法的原因。

通常会用 .pt 或者 .pth 后缀来保存模型。

2.4  保存模型参数转换为保存整个模型:

import torch
import torchvision.models as models

# 创建模型并加载模型参数
model = models.resnet18()
model.load_state_dict(torch.load('model_params.pth'))

# 保存整个模型
torch.save(model, 'whole_model_from_params.pth')

2.5 注意事项

  • 在进行预测之前,必须调用 model.eval() 方法来将 dropout 和 batch normalization 层设置为验证模型。否则,只会生成前后不一致的预测结果。
  • load_state_dict() 方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用 torch.load() ,而不是直接 model.load_state_dict(PATH)
  • 在加载state_dict之前,需要确保模型的架构与保存state_dict时使用的架构完全一致。否则,可能会因为参数不匹配而导致加载失败。
  • 如果只需要加载模型的部分参数(例如,在迁移学习中),可以在加载前对state_dict进行修改,只保留需要的参数部分。
  • state_dict只包含模型的参数(即权重和偏置),而不包含模型的架构信息。因此,在加载state_dict之前,必须确保已经定义了相应的模型架构。
  • 当你想要保存整个模型(包括架构和参数)时,可以使用torch.save(model, 'model.pth'),但这通常不推荐用于模型迁移,因为它可能包含与特定机器或PyTorch版本相关的信息。
  • 在处理多GPU训练时,你可能需要先将state_dict从特定的GPU设备转移到CPU上,然后再进行保存或加载。这可以通过调用.cpu()方法来实现。

建议使用保存模型参数,而不是保存整个模型

三、TorchScript格式

3.1 什么是 TorchScript?

TorchScript 是 Pytorch 中的一个功能,它允许我们将 Pytorch 模型编译成一个运行时环境无关的中间表示(Intermediate Representation)。通过将模型编译成 TorchScript 格式,我们可以在无需依赖 Python 的环境中使用模型进行推断,从而提高模型的性能和部署效率。

3.2 目的

pytorch模型转化成torchscript目的就是为了可以在c++环境中调用pytorch模型。

3.3 主要特性和优势

  1. 独立于 Python 运行:TorchScript 生成的模型可以脱离 Python 环境运行,这意味着它们可以在不支持 Python 的环境中部署,比如 C++、移动设备或 Web。

  2. 性能优化:通过 TorchScript,PyTorch 可以对模型进行静态分析,并应用多种优化策略,如常量折叠、死码消除和算子融合等,以提高模型的执行速度。

  3. 可移植性:TorchScript 支持将模型导出为多种格式,如 TorchScript 自身的 .pt 或 .ts 文件,以及 ONNX(Open Neural Network Exchange)格式,后者是一种开放的神经网络模型交换格式,支持多种框架和平台。

  4. 易于使用:TorchScript 提供了两种模式来转换和运行 PyTorch 模型:追踪(Tracing)模式和脚本(Scripting)模式。追踪模式适用于动态图执行路径相对固定的模型,而脚本模式则提供了更灵活的控制,允许开发者显式地指定哪些 Python 代码应该被转换成 TorchScript。

3.4 使用方法

共有两种方法将pytorch模型转成torch script ,一种是trace,另一种是script。一版在模型内部没有控制流存在的话(if,for循环),直接用trace方法就可以了。如果模型内部存在控制流,那就需要用到script方法了。

追踪模式(trace
通过 torch.jit.trace 函数来追踪模型的前向传播过程,并生成 TorchScript 模型。这种方式简单快捷,但可能无法处理控制流(如 if 语句)和动态输入的情况

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule,self).__init__()
        self.conv1 = nn.Conv2d(1,3,3)

    def forward(self,x):
        x = self.conv1(x)
        return x

model = MyModule()  # 实例化模型
trace_module = torch.jit.trace(model,torch.rand(1,1,224,224)) 
print(trace_module.code)  # 查看模型结构
output = trace_module (torch.ones(1, 3, 224, 224)) # 测试
print(output)
trace_modult('model.pt') # 模型保存

脚本模式(script:通过 @torch.jit.script 装饰器或 torch.jit.script 函数来显式地将 Python 代码转换为 TorchScript。这种方式提供了更高级的控制,可以处理复杂的控制流和动态输入。 

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule,self).__init__()
        self.conv1 = nn.Conv2d(1,3,3)
        self.conv2 = nn.Conv2d(2,3,3)

    def forward(self,x):
        b,c,h,w = x.shape
        if c ==1:
            x = self.conv1(x)
        else:
            x = self.conv2(x)
        return x

model = MyModule()

# 这样写会报错,因为有控制流
# trace_module = torch.jit.trace(model,torch.rand(1,1,224,224)) 

# 此时应该用script方法
script_module = torch.jit.script(model) 
print(script_module.code)
output = script_module(torch.rand(1,1,224,224))

3.5  生成步骤

步骤一:创建自定义 Pytorch 模型

首先,我们需要创建一个自定义的 Pytorch 模型。以下是一个简单的示例代码:

import torch
import torch.nn as nn

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

上述示例代码创建了一个名为 CustomModel 的自定义模型,包含了一个线性层。请根据实际需求定义并实现自己的 Pytorch 模型。

步骤二:保存 Pytorch 模型为.pth文件

在我们将自定义 Pytorch 模型转换为 TorchScript 格式之前,首先需要将模型保存为.pth文件。我们可以使用 Pytorch 提供的 torch.save 函数来保存模型。以下是一个示例代码:

model = CustomModel()

# 训练模型...

# 保存模型为.pth文件
torch.save(model.state_dict(), 'custom_model.pth')

上述示例代码将训练好的模型保存为 custom_model.pth 文件。

步骤三:加载.pth文件并转换为 TorchScript 格式

在我们进行模型转换之前,需要加载保存的.pth文件。以下是一个示例代码:

# 加载.pth文件
state_dict = torch.load('custom_model.pth')

model = CustomModel()
model.load_state_dict(state_dict)
model.eval()

上述示例代码首先加载了保存的.pth文件,然后创建了自定义模型的实例,并加载了模型参数。最后,通过调用 model.eval() 方法,将模型设置为评估模式

现在,我们已经加载了.pth文件并初始化了模型,接下来是将模型转换为 TorchScript 的步骤。以下是一个示例代码:

# 转换为 TorchScript 格式
script_model = torch.jit.trace(model, torch.randn(1, 10))

上述示例代码使用了 torch.jit.trace 函数将模型转换为 TorchScript 格式。其中,torch.randn(1, 10) 是输入示例,用于生成 TorchScript 图

步骤四:保存 TorchScript 模型为.pt文件

最后,我们可以使用 script_model 对象将 TorchScript 模型保存为.pt文件。以下是一个示例代码:

# 保存 TorchScript 模型为.pt文件
script_model.save('custom_model.pt')

上述示例代码将 TorchScript 模型保存为 custom_model.pt 文件。

示例代码总结

以下是上述示例代码的完整代码总结:

import torch
import torch.nn as nn

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = CustomModel()

# 保存模型为.pth文件
torch.save(model.state_dict(), 'custom_model.pth')

# 加载.pth文件
state_dict = torch.load('custom_model.pth')

model = CustomModel()
model.load_state_dict(state_dict)
model.eval()

# 转换为 TorchScript 格式
script_model = torch.jit.trace(model, torch.randn(1, 10))

# 保存 TorchScript 模型为.pt文件
script_model.save('custom_model.pt')

以上示例代码将自定义 Pytorch 模型保存为.pth文件,并将其转换为 TorchScript 格式,并最终保存为.pt文件。

3.4 总结

本文介绍了将自定义的 Pytorch 模型转换为 TorchScript 格式的过程。通过将模型保存为.pth文件,加载.pth文件并转换为 TorchScript 格式,最终可以将模型保存为.pt文件。使用 TorchScript 可以提高模型的性能和部署效率,使我们能够在无需依赖 Python 的环境中使用模型进行推断。

四、通用的检查点:Checkpoint

保存和加载用于推理或恢复训练的通用检查点模型有助于从上次中断的地方继续训练。保存通用检查点时,您不仅要保存模型的 state_dict,还必须保存优化器的 state_dict,因为这包含在模型训练过程中更新的缓冲区和参数。您可能想要保存的其他项目包括您中断的时期、最新记录的训练损失、外部 torch.nn.Embedding 层等,具体取决于您自己的算法。

4.1 简介

要保存多个检查点,您必须将它们组织在字典中,并使用 torch.save() 序列化字典。PyTorch 的常见约定是使用 .tar 文件扩展名保存这些检查点。要加载项目,首先初始化模型和优化器,然后使用 torch.load() 在本地加载字典。从这里,您可以像预期的那样通过简单地查询字典轻松访问已保存的项目。在本指南中,我们将探讨如何保存和加载多个检查点。

4.2 生成步骤

  • 导入加载数据所需的所有库
import torch
import torch.nn as nn
import torch.optim as optim
  • 定义并初始化神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)
  • 初始化优化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  • 保存常规检查点

PyTorch本身并不直接提供一个“保存整个模型”的函数,因为模型的定义(即类的结构和层的关系)是Python代码,不是可以被直接保存的数据。但你可以通过保存模型参数、优化器状态、以及其他相关状态来实现类似的效果。

一种常见的做法是将模型参数、优化器状态以及其他信息(如当前epoch数)存储在一个字典中,然后使用torch.save()函数保存这个字典。这通常被称为“检查点”(checkpoint)。

保存的模型文件通常是以 .tar 作为后缀名。

# Additional information
EPOCH = 5
PATH = "your_checkpoint.pth.tar"
LOSS = 0.4

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)
  • 加载常规检查点
PATH = "your_checkpoint.pth.tar"

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

在运行推理之前,您必须调用 model.eval() 将 dropout 和批量规范化层设置为评估模式。如果不这样做,将产生不一致的推理结果。如果您希望恢复训练,请调用 model.train() 以确保这些层处于训练模式。恭喜!您已成功保存并加载了 PyTorch 中推理和/或恢复训练的常规检查点。 

当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅仅是 state_dict ,比如说优化器的 state_dict 也是非常重要的,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括 epoch,即中断训练的批次,最后一次的训练 loss,额外的 torch.nn.Embedding 层等等。

上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save 方法,一般保存的文件后缀名是 .tar 

加载代码也如上述代码所示,首先需要初始化模型和优化器,然后加载模型时分别调用 torch.load 加载对应的 state_dict 。然后通过不同的键来获取对应的数值。

加载完后,根据后续步骤,调用 model.eval() 用于预测,model.train() 用于恢复训练。

4.3 注意事项

  • 在加载模型之前,你需要确保你已经定义了与保存时完全相同的模型类。
  • 如果你改变了模型结构(例如,添加了新的层或修改了层的参数),那么你可能无法直接加载旧的state_dict,因为它与新的模型结构不匹配。在这种情况下,你可能需要手动修改state_dict或重新训练模型。
  • 优化器的状态也需要在与保存时相同的优化器设置下加载,包括学习率、动量等参数。
  • 如果你的模型或优化器使用了GPU,并且你在加载时也想在GPU上运行它们,那么你需要在加载之前将模型和优化器移动到GPU上(使用.to(device)方法)。

五、深度学习中的.pt和.pth文件的区别

5.1 什么是.pt和.pth文件?

首先,.pt和.pth文件都是用于保存PyTorch模型的文件格式。它们可以用于保存模型的权重、结构以及其他相关信息。

.5.2 pt文件

.pt文件是PyTorch模型保存的通用格式。通常,.pt文件可以保存以下内容:

  • 模型的权重(weights):这是最常见的用法,只保存模型的参数。
  • 完整的模型(full model):包括模型的架构和权重。在这种情况下,可以通过简单的一行代码来加载整个模型。

保存模型权重代码

import torch

# 假设有一个训练好的模型 model
torch.save(model.state_dict(), 'model_weights.pt')

保存完整模型代码

import torch

# 假设有一个训练好的模型 model
torch.save(model, 'full_model.pt')

加载模型权重模型

import torch

# 假设有一个模型类 ModelClass
model = ModelClass()
model.load_state_dict(torch.load('model_weights.pt'))

加载完整模型代码

import torch

model = torch.load('full_model.pt')
model.eval()

5.3 .pth文件


.pth文件在功能上与.pt文件类似,但在实际使用中通常有一些惯例上的区别。具体来说,.pth文件通常用于以下两种情况:

  • 保存和加载模型权重:与.pt文件类似,但在社区中更常用于保存训练好的模型权重。
  • 保存和加载优化器的状态:在训练过程中,我们可能希望保存优化器的状态,以便在训练中断后能够从中断点继续。

保存模型权重代码:

import torch

# 假设有一个训练好的模型 model
torch.save(model.state_dict(), 'model_weights.pth')

保存优化器状态代码:

import torch

# 假设有一个训练好的优化器 optimizer
torch.save(optimizer.state_dict(), 'optimizer_state.pth')

加载模型权重的代码:

import torch

# 假设有一个模型类 ModelClass
model = ModelClass()
model.load_state_dict(torch.load('model_weights.pth'))

加载优化器状态的代码:

import torch

# 假设有一个优化器 optimizer
optimizer.load_state_dict(torch.load('optimizer_state.pth'))

六、ONNX

ONNX(Open Neural Network Exchange)是一个开放格式,用于表示深度学习模型。它使得不同的框架(如TensorFlow, PyTorch, MXNet, Caffe2等)之间可以交换模型,从而促进了模型在不同平台上的部署和优化。ONNX旨在促进机器学习模型的互操作性,让开发者能够更容易地将训练好的模型部署到各种设备上,包括服务器、移动设备、嵌入式设备等。

6.1 ONNX的主要特点:

  1. 开放性和标准化:ONNX是一个开放的格式,由多个组织和公司共同维护,确保了模型的互操作性和标准化。

  2. 跨平台部署:通过ONNX,开发者可以将模型从一种框架转换到另一种框架,甚至在不同的硬件上部署,如CPU、GPU、FPGA等。

  3. 模型优化:ONNX提供了模型优化的工具,可以帮助开发者在不影响模型精度的前提下,减小模型大小、提高推理速度。

  4. 社区支持:ONNX有一个活跃的社区,提供了大量的工具和库来支持模型的转换、优化和部署。

6.2 ONNX的工作流程:

  1. 模型训练:首先,在深度学习框架(如PyTorch或TensorFlow)中训练模型。

  2. 模型导出:将训练好的模型导出为ONNX格式。大多数主流框架都提供了将模型导出为ONNX格式的API。

  3. 模型转换(可选):如果需要,可以使用ONNX的转换工具将ONNX模型转换为其他框架或硬件特定的格式。

  4. 模型优化:使用ONNX的优化工具对模型进行优化,以提高性能。

  5. 模型部署:将优化后的模型部署到目标设备上,进行推理或预测。

6.3 pt模型转onnx 

 步骤一:首先你要安装 依赖库:onnx 和 onnxruntime

pip install onnx
pip install onnxruntime 进行安装

步骤二:pytorch模型转换到onnx模型
pytorch 转 onnx 仅仅需要一个函数 torch.onnx.export 

torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
参数说明:

model——需要导出的pytorch模型
args——模型的输入参数,满足输入层的shape正确即可。
path——输出的onnx模型的位置。例如‘yolov5.onnx’。
export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。
verbose——是否打印模型转换信息。default=False。
input_names——输入节点名称。default=None。
output_names——输出节点名称。default=None。
do_constant_folding——是否使用常量折叠,默认即可。default=True。
dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道。

三、示例

import torch
import torch.nn as nn
from model import AlexNet
import netron

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 创建模型实例
model = AlexNet(num_classes=5)
model = model.to(device)

# 加载.pth文件
state_dict = torch.load('AlexNet.pt')

model.load_state_dict(state_dict)
model.eval()

input = torch.ones((1, 3, 224, 224)).to(device)


torch.onnx.export(model, input, f='AlexNet.onnx')  # 导出 .onnx 文件
netron.start('AlexNet.onnx')  # 展示结构图

6.4 ONNX的应用场景:

  • 模型部署:将训练好的模型部署到生产环境中,进行实时推理或批量处理。
  • 跨平台迁移:将模型从一种框架迁移到另一种框架,以适应不同的开发环境或硬件需求。
  • 模型优化:在不影响模型精度的前提下,通过优化减小模型大小、提高推理速度。
  • 模型服务化:将模型封装成服务,供其他系统或应用调用。

总之,ONNX为深度学习模型的互操作性和部署提供了强有力的支持,使得开发者能够更加方便地将模型应用到实际场景中。

七、保存的保存与加载

7.1 核心函数

torch.save() 

保存一个序列化(serialized)的目标到磁盘。函数使用了Python的pickle程序用于序列化。模型(models),张量(tensors)和文件夹(dictionaries)都是可以用这个函数保存的目标类型。

torch.save(obj, f, pickle_module=<module '...'>, pickle_protocol=2)

示例:

保存整个模型:

torch.save(model,'save.pt')

只保存训练好的权重:

torch.save(model.state_dict(), 'save.pt')

torch.load() 

用来加载模型。torch.load() 使用 Python 的 解压工具(unpickling)来反序列化 pickled object 到对应存储设备上。首先在 CPU 上对压缩对象进行反序列化并且移动到它们保存的存储设备上,如果失败了(如:由于系统中没有相应的存储设备),就会抛出一个异常。用户可以通过 register_package 进行扩展,使用自己定义的标记和反序列化方法。

torch.load(f, map_location=None, pickle_module=<module 'pickle' from '...'>)

示例:

torch.load('tensors.pt')
 
# Load all tensors onto the CPU
torch.load('tensors.pt', map_location=torch.device('cpu'))
 
# Load all tensors onto the CPU, using a function
torch.load('tensors.pt', map_location=lambda storage, loc: storage)
 
# Load all tensors onto GPU 1
torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
 
# Map tensors from GPU 1 to GPU 0
torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
 
# Load tensor from io.BytesIO object
with open('tensor.pt') as f:
    buffer = io.BytesIO(f.read())
torch.load(buffer)

torch.nn.Module.load_state_dict(state_dict) [source]

使用 state_dict 反序列化模型参数字典。用来加载模型参数。将 state_dict 中的 parameters 和 buffers 复制到此 module 及其子节点中。

torch.nn.Module.load_state_dict(state_dict, strict=True)

示例:

torch.save(model,'save.pt')
model.load_state_dict(torch.load("save.pt"))  #model.load_state_dict()函数把加载的权重复制

7.2 状态字典的保存和加载

在PyTorch中,通常使用torch.save()函数来保存state_dict,而使用load_state_dict()函数来加载state_dict。保存时,可以直接将state_dict作为参数传递给torch.save()函数,并指定保存的路径。加载时,则需要先加载保存的state_dict文件,并将其传递给模型的load_state_dict()函数。

示例

假设我们有一个训练好的模型model,我们可以使用以下代码来保存其state_dict:

torch.save(model.state_dict(), 'model_weights.pth')


 然后,在需要加载这个模型参数的时候,我们可以使用以下代码:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。采用 torch.save() 来保存模型的状态字典的做法可以更方便加载模型,这也是推荐这种做法的原因。

通常会用 .pt 或者 .pth 后缀来保存模型。

注意事项

  • 在进行预测之前,必须调用 model.eval() 方法来将 dropout 和 batch normalization 层设置为验证模型。否则,只会生成前后不一致的预测结果。
  • load_state_dict() 方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用 torch.load() ,而不是直接 model.load_state_dict(PATH)
  • 在加载state_dict之前,需要确保模型的架构与保存state_dict时使用的架构完全一致。否则,可能会因为参数不匹配而导致加载失败。
  • 如果只需要加载模型的部分参数(例如,在迁移学习中),可以在加载前对state_dict进行修改,只保留需要的参数部分。

保存模型参数转换为保存整个模型:

import torch
import torchvision.models as models

# 创建模型并加载模型参数
model = models.resnet18()
model.load_state_dict(torch.load('model_params.pth'))

# 保存整个模型
torch.save(model, 'whole_model_from_params.pth')

建议使用保存模型参数,而不是保存整个模型

Pytorch 保存和加载模型后缀:.pt 和.pth

pt 文件保存整个 PyTorch 模型,而 .pth 文件只保存模型的参数

7.3 加载/保存整个模型

示例

保存

torch.save(model, PATH)

加载

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现。这种实现保存模型的做法将是采用 Python 的 pickle 模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle 并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。

保存整个模型转换为保存模型参数

import torch

# 加载整个模型
loaded_model = torch.load('whole_model.pth')

# 保存模型参数
torch.save(loaded_model.state_dict(), 'model_params.pth')

建议使用保存模型参数,而不是保存整个模型

 7.4 加载和保存一个通用的检查点(Checkpoint)

在PyTorch中,整个模型的保存与加载不仅仅是关于模型参数(即state_dict)的保存与加载,但确实,模型参数是其中最重要的部分。此外,你可能还想保存优化器的状态、调度器(如果有的话)的状态、当前的训练轮次(epoch)等信息,以便能够完全恢复训练过程。

一个完整的Pytorch模型文件,包含了如下参数

  • model_state_dict:模型参数
  • optimizer_state_dict:优化器的状态
  • epoch:当前的训练轮数
  • loss:当前的损失值

保存整个模型

然而,PyTorch本身并不直接提供一个“保存整个模型”的函数,因为模型的定义(即类的结构和层的关系)是Python代码,不是可以被直接保存的数据。但你可以通过保存模型参数、优化器状态、以及其他相关状态来实现类似的效果。

一种常见的做法是将模型参数、优化器状态以及其他信息(如当前epoch数)存储在一个字典中,然后使用torch.save()函数保存这个字典。这通常被称为“检查点”(checkpoint)。

import torch
import torch.nn as nn

# 定义一个简单的模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Net()

# 保存模型
torch.save({
            'epoch': 10,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, PATH)

保存的模型文件通常是以 .tar 作为后缀名。 

加载整个模型

加载时,你需要先加载检查点文件,并从中提取模型参数和优化器状态。然后,你可以使用这些参数来恢复模型和优化器的状态。

import torch
import torch.nn as nn

# 定义同样的模型结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# 加载模型
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()

当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅仅是 state_dict ,比如说优化器的 state_dict 也是非常重要的,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括 epoch,即中断训练的批次,最后一次的训练 loss,额外的 torch.nn.Embedding 层等等。

上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save 方法,一般保存的文件后缀名是 .tar 。

加载代码也如上述代码所示,首先需要初始化模型和优化器,然后加载模型时分别调用 torch.load 加载对应的 state_dict 。然后通过不同的键来获取对应的数值。

加载完后,根据后续步骤,调用 model.eval() 用于预测,model.train() 用于恢复训练。

注意事项

  • 在加载模型之前,你需要确保你已经定义了与保存时完全相同的模型类。
  • 如果你改变了模型结构(例如,添加了新的层或修改了层的参数),那么你可能无法直接加载旧的state_dict,因为它与新的模型结构不匹配。在这种情况下,你可能需要手动修改state_dict或重新训练模型。
  • 优化器的状态也需要在与保存时相同的优化器设置下加载,包括学习率、动量等参数。
  • 如果你的模型或优化器使用了GPU,并且你在加载时也想在GPU上运行它们,那么你需要在加载之前将模型和优化器移动到GPU上(使用.to(device)方法)。

7.5. 在同一个文件保存多个模型

保存模型的示例代码

torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

加载模型的示例代码

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

当我们希望保存的是一个包含多个网络模型 torch.nn.Modules 的时候,比如 GAN、一个序列化模型,或者多个模型融合,实现的方法其实和保存一个通用的检查点的做法是一样的,同样采用一个字典来保持模型的 state_dict 和对应优化器的 state_dict 。除此之外,还可以继续保存其他相同的信息。

加载模型的示例代码如上述所示,和加载一个通用的检查点也是一样的,同样需要先初始化对应的模型和优化器。同样,保存的模型文件通常是以 .tar 作为后缀名。

7.6 不同设备下保存和加载模型

在GPU上保存模型,在 CPU 上加载模型

保存模型的示例代码

torch.save(model.state_dict(), PATH)

加载模型的示例代码

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

在 CPU 上加载在 GPU 上训练的模型,必须在调用 torch.load() 的时候,设置参数 map_location ,指定采用的设备是 torch.device('cpu'),这个做法会将张量都重新映射到 CPU 上。

在GPU上保存模型,在 GPU 上加载模型

保存模型的示例代码

torch.save(model.state_dict(), PATH)

加载模型的示例代码

device = torch.device('cuda')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH)
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model                     

在 GPU 上训练和加载模型,调用 torch.load() 加载模型后,还需要采用 model.to(torch.device('cuda')),将模型调用到 GPU 上,并且后续输入的张量都需要确保是在 GPU 上使用的,即也需要采用 my_tensor.to(device)

在CPU上保存,在GPU上加载模型

保存模型的示例代码

torch.save(model.state_dict(), PATH)

加载模型的示例代码

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

这次是 CPU 上训练模型,但在 GPU 上加载模型使用,那么就需要通过参数 map_location 指定设备。然后继续记得调用 model.to(torch.device('cuda'))

7.7 保存 torch.nn.DataParallel 模型

保存模型的示例代码

torch.save(model.module.state_dict(), PATH)

torch.nn.DataParallel 是用于实现多 GPU 并行的操作,保存模型的时候,是采用 model.module.state_dict()

加载模型的代码也是一样的,采用 torch.load() ,并可以放到指定的 GPU 显卡上。

7.8 预训练模型

对于计算机视觉的任务,包括物体检测,我们通常很难拿到很大的数据集,在这种情况下重新训练一个新的模型是比较复杂的,并且不容易调整,因此,Fine-tune(微调)是一个常用的选择。所谓Fine-tune是指利用别人在一些数据集上训练好的预训练模型,在自己的数据集上训练自己的模型。

在具体使用时,通常有两种情况,第一种是直接利用torchvision.models中自带的预训练模型,只需要在使用时赋予pretrained参数为True即可。

    >>> from torch import nn
    >>> from torchvision import models
    # 通过torchvision.model直接调用VGG16的网络结构
    >>> vgg = models.vgg16(pretrained=True)

pytorch自带模型网址:
https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-models/

官方预训练模型调用代码:
https://github.com/pytorch/vision/tree/master/torchvision/models

 第二种是如果想要使用自己的本地预训练模型,或者之前训练过的模型,则可以通过model.load_state_dict()函数操作,下面以squeezenet为例。

import torch
import torchvision.models as models
 
# pretrained=True就可以使用预训练的模型
net = models.squeezenet1_1(pretrained=False)
pthfile = r'E:\anaconda\app\envs\luo\Lib\site-packages\torchvision\models\squeezenet1_1.pth'
net.load_state_dict(torch.load(pthfile))
print(net)

主要是把pretrain设成false,然后直接把路径指定到模型所在处,用load_state_dict程序进行加载。

model=models.squeezenet1_0(pretrained=True)#读取预训练模型参数
for param in model.parameters():#冻结所有参数,不被更新
    param.requires_grad = False
model.classifier[1] = nn.Conv2d(512,CL,kernel_size=(1,1),stride=(1,1))#把分类器输出改为分类个数CL
model=model.cuda()#使用GPU加速
print("params to update:")
params_to_update=[]#用来保存需要更新的参数
for name,param in model.named_parameters():
    if param.requires_grad==True:
        params_to_update.append(param)
        print("\t",name)
optimizer=t.optim.Adam(params_to_update,lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()#使用GPU加速
vis=visdom.Visdom(env=u'window')

首先读取预训练模型的所有参数,然后将requires_grad设置为False,不让其更新,然后修改模型的classifier中的[1]层,把Conv2d的输出修改为分类个数。此时model.classifier[1]的参数的requires_grad自动为True,可以被更新。也即是说,squeezenet只训练model.classifier[1]中的参数,其他参数不变。

 但是报错:RuntimeError: shape '[25, 1000]' is invalid for input of size 50

这里的1000仍然是squeezenet最初的分类数,虽然修改了model.classifier[1]的分类数,但其内部的参数self.num_classes还是1000,比较特殊,之前使用resnet18没有遇到这个问题。

 因此还必须修改内部的参数self.num_classes。

model=models.squeezenet1_0(pretrained=True)#读取预训练模型参数
for param in model.parameters():#冻结所有参数,不被更新
    param.requires_grad = False
model.classifier[1] = nn.Conv2d(512,CL,kernel_size=(1,1),stride=(1,1))#把分类器输出改为分类个数CL
model.num_classes = CL #修改self.num_classes
model=model.cuda()#使用GPU加速
print("params to update:")
params_to_update=[]#用来保存需要更新的参数
for name,param in model.named_parameters():
    if param.requires_grad==True:
        params_to_update.append(param)
        print("\t",name)
optimizer=t.optim.Adam(params_to_update,lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()#使用GPU加速
vis=visdom.Visdom(env=u'window')

7.9 .bin格式

.bin文件是一个二进制文件,可以保存Pytorch模型的参数和持久化缓存。.bin文件的大小较小,加载速度较快,因此在生产环境中使用较多。

下面是一个.bin文件的保存和加载示例(注意:也可以使用 .pt .pth 后缀)

保存模型

import torch
import torch.nn as nn

# 定义一个简单的模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Net()
# 保存参数到.bin文件
torch.save(model.state_dict(), PATH)

加载模型

import torch
import torch.nn as nn

# 定义相同的模型结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# 加载.bin文件
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()

八、模型的保存与加载到底在做什么?

我们在使用pytorch构建模型并且训练完成后,下一步要做的就是把这个模型放到实际场景中应用,或者是分享给其他人学习、研究、使用。因此,我们开始思考一个问题,提供哪些模型信息,能够让对方能够完全复现我们的模型?

  • 模型代码
    • (1)包含了我们如何定义模型的结构,包括模型有多少层/每层有多少神经元等等信息;
    • (2)包含了我们如何定义的训练过程,包括epoch batch_size等参数;
    • (3)包含了我们如何加载数据和使用;
    • (4)包含了我们如何测试评估模型
  • 模型参数:提供了模型代码之后,对方确实能够复现模型,但是运行的参数需要重新训练才能得到,而没有办法在我们的模型参数基础上继续训练,因此对方还希望我们能够把模型的参数也保存下来给对方。
    • (1)包含model.state_dict(),这是模型每一层可学习的节点的参数,比如weight/bias;
    • (2)包含optimizer.state_dict(),这是模型的优化器中的参数;
    • (3)包含我们其他参数信息,如epoch/batch_size/loss等。
  • 数据集
    • (1)包含了我们训练模型使用的所有数据;
    • (2)可以提示对方如何去准备同样格式的数据来训练模型。
  • 使用文档
    • (1)根据使用文档的步骤,每个人都可以重现模型;
    • (2)包含了模型的使用细节和我们相关参数的设置依据等信息。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值