Model Architecture and Weight Files in PyTorch
1. Saving and loading both the architecture and the weights
Saving the model structure together with its parameters means that, when the model is loaded later, you do not need to define the network architecture in advance. This approach is convenient to use, but the saved file takes more memory.
The concrete steps are as follows:
1) Saving:
import torch
import torchvision.models as models
# Assuming you have a model instance
model = models.resnet18()
# Save the entire model, including architecture and weights
model_path = 'path/to/your/model.pth'
torch.save(model, model_path)
2) Loading:
# On PyTorch >= 2.6, torch.load defaults to weights_only=True and will
# refuse a fully pickled model; pass weights_only=False in that case.
model = torch.load(model_path, map_location='cpu')
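The full round trip can be sketched as follows. This is a minimal, self-contained example: the tiny `nn.Sequential` model (a stand-in for `resnet18`, to keep it fast) and the temp-directory file path are illustrative, not part of the original text.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Tiny illustrative model (stand-in for resnet18)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()

path = os.path.join(tempfile.gettempdir(), 'full_model.pth')
# Pickles the whole object: architecture + weights
torch.save(model, path)

# weights_only=False is required on PyTorch >= 2.6, where torch.load
# refuses arbitrary pickled objects by default (kwarg exists since ~1.13)
loaded = torch.load(path, map_location='cpu', weights_only=False)
loaded.eval()

# The reloaded model reproduces the original's outputs exactly
x = torch.randn(1, 4)
with torch.no_grad():
    print(torch.allclose(model(x), loaded(x)))  # True
```

Note that loading a fully pickled model still requires the class definitions (here `torch.nn` modules) to be importable in the loading environment.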
2. Saving only the model parameters
Save only the model's parameters (its state dictionary). When loading later, you must first define and instantiate the same network architecture, then load the state dictionary into it. The concrete steps are as follows:
1) Saving:
import torch
import torchvision.models as models
# Assuming you have a model instance
model = models.resnet18()
# Save only the model's state dictionary (weights)
state_dict_path = 'path/to/your/model_state_dict.pth'
torch.save(model.state_dict(), state_dict_path)
2) Loading:
# Create an instance of the model
loaded_model = models.resnet18()
# Load the state dictionary into the model
state_dict_path = 'path/to/your/model_state_dict.pth'
loaded_model.load_state_dict(torch.load(state_dict_path))
# Display the architecture of the loaded model
# (summary comes from the third-party torchsummary package;
#  a plain print(loaded_model) also shows the module structure)
from torchsummary import summary
summary(loaded_model, input_size=(input_channels, input_height, input_width))
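The whole state-dict round trip can be sketched as below. The small `nn.Linear` model and temp-file path are illustrative; the key point is that `load_state_dict` reports whether every saved key matched the freshly built architecture.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Illustrative small model; any nn.Module works the same way
model = nn.Linear(3, 2)

path = os.path.join(tempfile.gettempdir(), 'state_dict_demo.pth')
torch.save(model.state_dict(), path)

# The architecture must be re-created with the same shapes before loading
restored = nn.Linear(3, 2)
result = restored.load_state_dict(torch.load(path, map_location='cpu'))
print(result)  # <All keys matched successfully>

# The weights are now identical to the original model's
print(torch.equal(model.weight, restored.weight))  # True
```

By default `load_state_dict` runs with `strict=True`, so a missing or unexpected key raises an error instead of being silently ignored.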
3. Inspecting the weights
# 1. Print all the trainable parameters (name and values)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)

# 2. Check only the first few parameter tensors, e.g. the first five
count = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.data)
        count += 1
        if count >= 5:
            break
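A lighter-weight way to inspect the same information is to iterate over the state dictionary, which maps parameter names to tensors; printing only the shapes avoids dumping full weight values. The `nn.Linear` model here is illustrative:

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 2)

# state_dict maps parameter names to tensors; printing the shapes
# gives a quick overview without flooding the console with values
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# weight (2, 3)
# bias (2,)
```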
4. Differences between the model file extensions
Three file extensions are commonly used for saved model files: .pt, .pth, and .pkl, but there is no functional difference between them. The following explanation comes from ChatGPT:
In PyTorch, the file extensions .pt, .pth, and .pkl are commonly used for saving and loading models and related objects. However, there isn’t a strict convention for these extensions, and their usage can depend on the developer’s preference.
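This can be checked directly: `torch.save` writes the same serialized format regardless of the file name, so a state dictionary saved under any of the three extensions loads back identically. The temp-file paths below are illustrative:

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(2, 2)
state = model.state_dict()

# torch.save produces the same file format whatever the extension;
# only the name differs
for ext in ('.pt', '.pth', '.pkl'):
    path = os.path.join(tempfile.gettempdir(), 'ext_demo' + ext)
    torch.save(state, path)
    reloaded = torch.load(path, map_location='cpu')
    assert torch.equal(state['weight'], reloaded['weight'])
    assert torch.equal(state['bias'], reloaded['bias'])

print('all three extensions load identically')
```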