目录
前提
本次对模型数据的保存和读取进行了分离,分别配属于不同的文件,便于管理和使用,相关文件自行下载参考
模型数据:
class CNN(nn.Module):
def __init__(self): # 输入大小 (3, 256, 256)
super(CNN, self).__init__()
self.conv1 = nn.Sequential( #将多个层组合成一起。
nn.Conv2d( #2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据
in_channels=3, # 图像通道个数,1表示灰度图(确定了卷积核 组中的个数),
out_channels=16,# 要得到几多少个特征图,卷积核的个数
kernel_size=5, # 卷积核大小,5*5
stride=1, # 步长
padding=2, # 一般希望卷积核处理后的结果大小与处理前的数据大小相同,效果会比较好。那padding改如何设计呢?建议stride为1,kernel_size = 2*padding+1
), # 输出的特征图为 (16, 256, 256)
nn.ReLU(), # relu层
nn.MaxPool2d(kernel_size=2), # 进行池化操作(2x2 区域), 输出结果为: (16, 128, 128)
)
self.conv2 = nn.Sequential( #输入 (16, 128, 128)
nn.Conv2d(16, 32, 5, 1, 2), # 输出 (32, 128, 128)
nn.ReLU(), # relu层
nn.Conv2d(32, 32, 5, 1, 2), # 输出 (32, 128, 128)
nn.ReLU(),
nn.MaxPool2d(2), # 输出 (32, 64, 64)
)
self.conv3 = nn.Sequential( #输入 (32, 64, 64)
nn.Conv2d(32, 64, 5, 1, 2),
nn.ReLU(), # 输出 (64, 64, 64)
)
self.out = nn.Linear(64 * 64 * 64, 20) # 全连接层得到的结果
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)# 输出 (64,64, 32, 32)
x = x.view(x.size(0), -1) # flatten操作,结果为:(batch_size, 64 * 32 * 32)
output = self.out(x)
return output
一、为什么保存模型的数据?
在许多应用场景中,往往需要在不同的时间、地点和环境中使用同一个模型。如果模型参数存储在硬盘上,就可以随时加载并使用,非常方便。而且重新训练一个深度学习模型可能需要花费数小时,甚至数天的时间,这取决于模型的复杂性和数据集的大小。因此,保存模型可以避免在每个项目或每次需要使用模型时重新训练。
在运行深度学习模型时,往往需要大量的内存(RAM)。如果模型的参数存储在硬盘上而不是内存中,可以大大减少内存的使用,这在内存有限的系统上尤其有用。在深度学习的迁移学习中,一个已经在一个任务上训练好的模型被用作新任务的基础模型。对新任务只需要训练部分特定的层,这样可以大大节省计算资源并提高效率。保存模型使得这种训练方式变得可能。
二、模型数据保存
1、保存模型参数
torch.save(model.state_dict(), path)
保存模型的参数,model.state_dict()
返回一个包含模型所有参数的字典对象,包括卷积层的权重和偏置,全连接层的权重和偏置等。这个字典对象可以被加载回模型以进行后续的训练或者评估。
2、保存完整模型
torch.save(model, 'best.pt')
将完整的模型数据保存到‘best.pt’文件中,实现模型的存储。
三、模型数据读取
在生成保存文件后,如果想要读取该数据,但又不想在原文件中进行读取,可以使用下述方法
import torch
model = torch.load('best.pt') #加载模型数据文件
model.eval() #固定模型数据和参数,防止后面被修改
print(model)
使用 'torch.load('best.pt')' 加载了之前保存的‘best.pt’模型'。model.eval()' 是将模型设置为评估模式,这在模型训练完毕并加载之后非常常见,它主要用于模型的预测或者评估阶段。最后,'print(model)' 是打印出模型的结构或者参数等信息。
注:此代码可单独运行,但必须与所加载的文件处于同一目录下
读取结果展示:
四、总结
torch.load
和torch.save
是PyTorch库中的两个用于模型加载和保存的函数。torch.save
函数用于将模型、张量或其他PyTorch对象保存到文件或路径中,可以选择不同的序列化方式。torch.load
函数用于从文件中或路径中加载已保存的模型、张量或其他PyTorch对象。可以在不同文件中实现对模型数据的保存和加载。