离散动作的多智能体博弈问题讨论
用MADSPG做两个智能体的追逐出现问题
最近在尝试做多智能体博弈,就自己写了个简单的追逐环境。我的环境中,两个智能体都是观察自己周围两格的信息包括自身,也就是状态为1*25的张量,动作为上下左右四个离散动作。所以Actor输出四个动作的概率,进行选择动作,在选择动作时加入gumble噪声。并且加入SAC的最大熵思想,让每个动作的概率尽可能的大一点。
网络为常规的A-C网络,一个知道两个智能体信息的评论家Critic,两个智能体执行者Actor。具体训练代码如下:
#update net models
def learn(self,batch_size):
if len(self.experience) < batch_size*20:
return
self.time += 1
samples = random.sample(self.experience,batch_size)#get samples from the replay pool #采样
s0_total,a0_total,r_total,s1_total,done = zip(*samples)
s0_total = torch.tensor(s0_total,dtype = torch.float).cuda()
a0_total = torch.tensor(a0_total,dtype = torch.float).cuda()
r_total = torch.tensor(r_total,dtype = torch.float).cuda()
s1_total = torch.tensor(s1_total,dtype = torch.float).cuda()
done = torch.tensor(done,dtype = torch.float).cuda()
#update critic model
def update_critic():
#get the next action
a1_predator = self.t_actor_predator(s1_total[:,0:25]).detach()#get the next action of predator #追逐者的预测动作
a1_prey = self.t_actor_prey(s1_total[:,25:50]).detach()#get the next action of prey #逃亡者的预测动作
a1_total = torch.cat([a1_predator,a1_prey],1).detach() #将他们的动作连接起来
# 目标Q价值
q_target = r_total + self.gamma * self.critic(s1_total,a1_total).detach() * (1-done)
#当前预测的Q价值
q_current = self.critic(s0_total,a0_total)
loss_fn = nn.MSELoss()
loss = loss_fn(q_current,q_target)
self.optim_critic.zero_grad()
loss.backward()
self.optim_critic.step()
#训练演员网络
def update_actor():
a0_predaotr = self.actor_predator(s0_total[:,0:25]) #得到追逐者的动作
a0_prey = self.actor_prey(s0_total[:,25:50]) #逃亡者的动作
a0_t = torch.cat([a0_predaotr,a0_prey],1)
loss=-torch.mean(self.critic(s0_total,a0_t) ) #它们所得的平价价值,加 - 是为做梯度上升
self.optim_a_predator.zero_grad()
self.optim_a_prey.zero_grad()
loss.backward()
self.optim_a_predator.step()
self.optim_a_prey.step()
#软更新
def soft_update(t_net,net):
for t_param,param in zip(t_net.parameters(),net.parameters()):
t_param.data.copy_(t_param.data * self.tau + param.data)
update_critic()
update_actor()
# if self.time % 50 == 0:
soft_update(self.t_critic,self.critic)
soft_update(self.t_actor_predator,self.actor_predator)
soft_update(self.t_actor_prey,self.actor_prey)
但是达不到效果,请问是问题出在哪儿呢?