报错信息:原网站上的代码有一点问题,报错一下信息
Traceback (most recent call last):
File "", line 201, in <module>
bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])
File "", line 155, in learn
log_probs = torch.log(self.policy(states).gather(1, actions))
RuntimeError: gather(): Expected dtype int64 for index
经过不断查资料,发现gather只能索引int64类型的数据,所以将原来的actions = torch.tensor(actions).view(-1, 1).to(device)
即actions的类型转换为int64 经过输出发现原来的actions为int32类型的
当出现RuntimeError这个错误信息的时候大概率是因为索引类型不对,追根溯源找到应该的索引类型,再找到错误源进行更改即可