import torch
import torch.nn as nn
import torch.nn.functional as F
# Define model
class TheModelClass(nn.Module):
    """Small two-layer perceptron: Linear -> ReLU -> Linear.

    With the default sizes it maps inputs whose last dimension is 10 to
    outputs whose last dimension is 2, through a hidden layer of width 5.

    Args:
        in_features: size of the input's last dimension (default 10).
        hidden_features: width of the hidden layer (default 5).
        out_features: size of the output's last dimension (default 2).
    """

    def __init__(self, in_features=10, hidden_features=5, out_features=2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply fc1, a ReLU, then fc2 to ``x`` and return the result."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    # Backward-compatible alias for the original misspelled name. nn.Module's
    # __call__ only dispatches to ``forward``; with just ``farward`` defined,
    # ``model(x)`` raised NotImplementedError.
    farward = forward
# Build the network, then an SGD optimizer (with momentum) over its parameters.
model = TheModelClass()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-4,
    momentum=0.9,
)
# Print every entry of the model's state_dict with its tensor shape.
# state_dict() builds a fresh dict on each call, so call it once and iterate
# .items() rather than re-building it for every key.
print("Model's state_dict:")
for param_tensor, tensor in model.state_dict().items():
    print(param_tensor, "\t", tensor.size())
# Print the optimizer's state_dict (for torch.optim optimizers the keys are
# 'state' and 'param_groups').
print("optimizer's state_dict:")
for var_name, value in optimizer.state_dict().items():
    print(var_name, "\t", value)
# --- Save / load via state_dict (parameters only) ---
PATH = './model.pt'  # this file is created by torch.save below
torch.save(model.state_dict(), PATH)
model2 = TheModelClass()
# A state_dict holds only tensors, so weights_only=True is sufficient. Passing
# it explicitly restricts unpickling to tensors/plain containers and behaves the
# same on every PyTorch version (it only became the default in 2.6).
model2.load_state_dict(torch.load(PATH, weights_only=True))
# Switch to evaluation mode before inference.
model2.eval()
# --- Save / load the entire model (pickles the whole nn.Module object) ---
PATH = './enmodel.pt'
torch.save(model, PATH)
# Unpickling a full module requires weights_only=False. Since PyTorch 2.6,
# torch.load defaults to weights_only=True, which rejects arbitrary classes.
# Only do this with files you trust: unpickling can execute arbitrary code.
# TheModelClass must be defined/importable at load time.
model3 = torch.load(PATH, weights_only=False)
model3.eval()
# PyTorch: saving and loading models
# (Scraped page footer: "Latest recommended article published 2024-04-28 19:15:00")