该网络的 forward 除输入 x 外，还需要 LSTM 的隐藏状态 h 和细胞状态 c 作为输入（h、c 是 forward 的输入，不是可训练参数）：
class Q_net(nn.Module):
    """Recurrent Q-network: Linear -> ReLU -> single-layer LSTM -> Linear.

    The LSTM is built with ``batch_first=True``, so batched input has shape
    ``(batch, seq_len, state_space)`` and the recurrent states ``h``/``c`` have
    shape ``(1, batch, hidden_space)`` (the layout ``init_hidden_state`` returns).
    The output head produces ``action_space`` values per time step.

    Args:
        state_space: Size of the input's last (feature) dimension.
        action_space: Number of outputs per time step (one per action).

    Raises:
        AssertionError: If ``state_space`` or ``action_space`` is ``None``.
    """

    def __init__(self, state_space=None, action_space=None):
        super().__init__()
        # space size check
        assert state_space is not None, "None state_space input: state_space should be selected."
        assert action_space is not None, "None action_space input: action_space should be selected."
        self.hidden_space = 64
        self.state_space = state_space
        self.action_space = action_space
        self.Linear1 = nn.Linear(self.state_space, self.hidden_space)
        self.lstm = nn.LSTM(self.hidden_space, self.hidden_space, batch_first=True)
        self.Linear2 = nn.Linear(self.hidden_space, self.action_space)

    def forward(self, x, h=None, c=None):
        """Run the network over an input sequence.

        Args:
            x: Input, batched as ``(batch, seq_len, state_space)``.
            h: Initial hidden state ``(1, batch, hidden_space)``, or ``None``.
            c: Initial cell state ``(1, batch, hidden_space)``, or ``None``.
                If both ``h`` and ``c`` are omitted, the LSTM starts from a zero
                state (the same values ``init_hidden_state`` returns). A
                single-input call such as ``model(x)`` therefore no longer fails
                with ``forward() missing 2 required positional arguments``.

        Returns:
            Tuple ``(out, new_h, new_c)``:
            - ``out``: per-step outputs with last dimension ``action_space``;
            - ``new_h``, ``new_c``: LSTM states after the last step, to feed
              back in on the next call.

        Raises:
            ValueError: If exactly one of ``h`` and ``c`` is given.
        """
        if (h is None) != (c is None):
            raise ValueError("h and c must be given together (or both omitted)")
        x = F.relu(self.Linear1(x))
        # hx=None makes nn.LSTM create zero states of the right batch size and device.
        hx = None if h is None else (h, c)
        x, (new_h, new_c) = self.lstm(x, hx)
        x = self.Linear2(x)
        return x, new_h, new_c

    def init_hidden_state(self, batch_size, training=None):
        """Return a zero ``(h, c)`` pair for the LSTM.

        Args:
            batch_size: Batch dimension used when ``training is True``.
            training: Must be given. If ``training is True``, the states have
                shape ``(1, batch_size, hidden_space)``. Any other non-None value
                gives ``(1, 1, hidden_space)`` and ignores ``batch_size``.

        Returns:
            Tuple ``(h, c)`` of zero tensors. They are created on the same device
            and with the same dtype as the model's parameters, so they can be
            passed straight to ``forward`` after ``model.to(device)`` or
            ``model.double()``.

        Raises:
            AssertionError: If ``training`` is ``None``.
        """
        assert training is not None, "training step parameter should be dtermined"
        batch = batch_size if training is True else 1
        ref = next(self.parameters())  # device/dtype source for the new states
        shape = (1, batch, self.hidden_space)
        h = torch.zeros(shape, device=ref.device, dtype=ref.dtype)
        c = torch.zeros(shape, device=ref.device, dtype=ref.dtype)
        return h, c
# Initial LSTM hidden state h and cell state c: zero tensors of shape
# (1, batch_size, hidden_space), created by Q_net.init_hidden_state.
h, c = q_net.init_hidden_state(batch_size=batch_size, training=True)
将 .pth 转换为 .onnx 时报错：TypeError: forward() missing 2 required positional arguments: 'h' and 'c'。报错代码如下：
# Failing first attempt, kept for reference.
# NOTE(review): model1 is not created in this snippet; presumably it is a Q_net
# instance built earlier -- confirm.
dummy_input = torch.randn(64,2)
checkpoing = torch.load('./DRQN_POMDP_Random_SEED_1.pth', 'cpu')
model1.load_state_dict(checkpoing)
# Problems with the export call below:
#   1. It exports `checkpoing` (the state dict passed to load_state_dict above),
#      not the nn.Module `model1`.
#   2. Only dummy_input is supplied as the model input, while the network was
#      defined as forward(x, h, c); that mismatch is the
#      "forward() missing 2 required positional arguments" TypeError.
# NOTE(review): dummy_input is 2-D. Per the nn.LSTM docs, a 2-D input is treated
# as an unbatched sequence and needs 2-D h/c, so the 3-D states from
# init_hidden_state would not match it either.
torch.onnx.export(checkpoing, dummy_input, "model_best.onnx", export_params=True, verbose=True) # Save the model in .onnx format
最终的解决办法：导出模型对象 model1（而不是 state_dict），并把 h、c 与输入一起作为 args 元组 (dummy_input, h, c) 传给 torch.onnx.export：
# Working export: give torch.onnx.export the model object itself, and pass the
# initial LSTM states together with the input, matching Q_net.forward(x, h, c).
checkpoing = torch.load('./DRQN_POMDP_Random_SEED_1.pth', map_location='cpu')  # load the saved parameters (state dict)
# Build the network from the checkpoint's own layer shapes, because the visible
# code never defines `Q`. Linear1.weight is (hidden_space, state_space) and
# Linear2.weight is (action_space, hidden_space).
model1 = Q_net(state_space=checkpoing['Linear1.weight'].shape[1],
               action_space=checkpoing['Linear2.weight'].shape[0])
model1.load_state_dict(checkpoing)  # copy the saved parameters into the model
# Why must the input be 3-D? The LSTM is created with batch_first=True, and
# init_hidden_state returns 3-D states (1, batch, hidden_space), so the input has
# to be batched: (batch, seq_len, state_space). nn.LSTM treats a 2-D input as a
# single unbatched sequence, which would require 2-D h/c instead.
dummy_input = torch.randn(1, 64, model1.state_space)
# training=False yields batch-1 states, matching dummy_input's batch dimension
# (batch_size is ignored on that branch).
h, c = model1.init_hidden_state(batch_size=1, training=False)
torch.onnx.export(model1, (dummy_input, h, c), "model_best.onnx", export_params=True, verbose=True)  # save the model in .onnx format