以下为我个人的代码
critic_model = agent.critic.state_dict()
actor_model = agent.actor.state_dict()
agent.actor.load_state_dict(torch.load(args.actor_save_path))
agent.critic.load_state_dict(torch.load(args.critic_save_path))
在读取torch储存的模型时,遇到了AttributeError: '****' object has no attribute 'copy'的问题 ,原因是我没有将整个模型保存并加载。
处理的方法为:
critic_model = agent.critic
actor_model = agent.actor
agent.actor = torch.load(args.actor_save_path)
agent.critic = torch.load(args.critic_save_path)