前提
SAVING AND LOADING MODELS
当提到保存和加载模型时,有三个核心功能需要熟悉:
1.torch.save:将序列化的对象保存到disk。这个函数使用Python的pickle实用程序进行序列化。使用这个函数可以保存各种对象的模型、张量和字典。
2.torch.load:使用pickle unpickle工具将pickle的对象文件反序列化为内存。
3.torch.nn.Module.load_state_dict:使用反序列化状态字典加载model’s参数字典。
一:WHAT IS A STATE_DICT
在PyTorch中,torch.nn.Module的可学习参数(即权重和偏差),模块模型包含在model's参数中(通过model.parameters()访问)。state_dict是个简单的Python dictionary对象,它将每个层映射到它的参数张量。
注意,只有具有可学习参数的层(卷积层、线性层等)才有model's state_dict中的条目。优化器对象(connector .optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。
Example:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass,self).__init__()
self.conv1=nn.Conv2d(3,6,5)