训练PPO算法,出现Windows fatal exception: access violation问题。PPO算法是一个多线程算法,为了加速训练,使用GPU和显存加速计算。该问题不能稳定复现。
为了节约现存,PPO的公共参数
储存在内存中,worker数据存储在显存中,然后可能数据交换有问题,会出现多线程错误,将公共参数同样存放在显存中,问题消失了。
原:
@ray.remote
class ParameterServer:
def __init__(self):
self.params = ActorCritic(n_actions=3, device='cpu')
self.load_net()
def get_params(self):
return {k: v.cpu() for k, v in self.params.state_dict().items()}
现:
@ray.remote
class ParameterServer:
def __init__(self):
self.params = ActorCritic(n_actions=3, device='cuda:0')
self.load_net()
def get_params(self):
return {k: v for k, v in self.params.state_dict().items()}