pytorch中模型参数个数不定时如何加载模型

pytorch模型保存的两种方法

假设实例化的模型为model,类为Model,Path表示模型的路径

  1. 如果模型中的参数随着程序的运行而变化可使用
	# 保存
	torch.save(model ,Path)

	# 加载
	# 这里model不用初始化为model = Model(),但是一定要首先引入Model类
	import Model # 如果没有这个类的化,下面的语句会报错!
	model = torch.load(Path)
  1. 如果模型中的参数数量随着程序运行不发生变化,则可使用:
	# 保存
	torch.save(model.state_dict(),Path)
	
	# 加载
	model = Model()
	# 这里注意加载的模型参数数量要与Model()初始化的模型参数数量要一置,不然会报错
	model.load_dict_state(torch.load(Path))

下面我举一个参数会发生变化的模型,其中1,2是相互配套的代码

import torch
import torch.nn as nn

class A(nn.Module):
	def __init__(self):
		nn.Module.__init__(self)
		self.register_buffer("aaa", torch.tensor([]))
	def forward(self,x):
		self.aaa.data = torch.cat((self.aaa,torch.tensor([x])),0)

b = A()																	# ----1
b.load_state_dict(torch.load('./fucc.pt'))	# ----1
# b = torch.load('./fucc.pth')						# ----2

for i in range (20):
	b(i)
torch.save(b.state_dict(),'fucc.pt')			# ----1
# torch.save(b,'fucc.pth')							# ----2

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值