2020-12-03

离散动作的多智能体博弈问题讨论

用MADSPG做两个智能体的追逐出现问题

最近在尝试做多智能体博弈,就自己写了个简单的追逐环境。我的环境中,两个智能体都是观察自己周围两格的信息包括自身,也就是状态为1*25的张量,动作为上下左右四个离散动作。所以Actor输出四个动作的概率,进行选择动作,在选择动作时加入gumble噪声。并且加入SAC的最大熵思想,让每个动作的概率尽可能的大一点。

网络为常规的A-C网络,一个知道两个智能体信息的评论家Critic,两个智能体执行者Actor。具体训练代码如下:
#update net models
def learn(self,batch_size): 
    if len(self.experience) < batch_size*20:
        return
    self.time += 1
    samples = random.sample(self.experience,batch_size)#get samples from the replay pool   #采样
    s0_total,a0_total,r_total,s1_total,done = zip(*samples)
    
    s0_total = torch.tensor(s0_total,dtype = torch.float).cuda()
    a0_total = torch.tensor(a0_total,dtype = torch.float).cuda()
    r_total = torch.tensor(r_total,dtype = torch.float).cuda()
    s1_total = torch.tensor(s1_total,dtype = torch.float).cuda()
    done = torch.tensor(done,dtype = torch.float).cuda()

    #update critic model
    def update_critic():
        #get the next action
        a1_predator = self.t_actor_predator(s1_total[:,0:25]).detach()#get the next action of predator   #追逐者的预测动作
        a1_prey = self.t_actor_prey(s1_total[:,25:50]).detach()#get the next action of prey   #逃亡者的预测动作
        a1_total = torch.cat([a1_predator,a1_prey],1).detach()  #将他们的动作连接起来
       
        # 目标Q价值
        q_target = r_total + self.gamma * self.critic(s1_total,a1_total).detach() * (1-done)
        #当前预测的Q价值
        q_current = self.critic(s0_total,a0_total)

        loss_fn = nn.MSELoss()
        loss = loss_fn(q_current,q_target)
        self.optim_critic.zero_grad()
        loss.backward()
        self.optim_critic.step()

    #训练演员网络
    def update_actor():
        a0_predaotr = self.actor_predator(s0_total[:,0:25])  #得到追逐者的动作
        a0_prey = self.actor_prey(s0_total[:,25:50]) #逃亡者的动作
        a0_t = torch.cat([a0_predaotr,a0_prey],1)  
        loss=-torch.mean(self.critic(s0_total,a0_t)  )  #它们所得的平价价值,加 - 是为做梯度上升     
        self.optim_a_predator.zero_grad()
        self.optim_a_prey.zero_grad()
        loss.backward()
        self.optim_a_predator.step()
        self.optim_a_prey.step()

    #软更新
    def soft_update(t_net,net):
        for t_param,param in zip(t_net.parameters(),net.parameters()):
            t_param.data.copy_(t_param.data * self.tau + param.data)

    update_critic()
    update_actor()
    # if self.time % 50 == 0:
    soft_update(self.t_critic,self.critic)
    soft_update(self.t_actor_predator,self.actor_predator)
    soft_update(self.t_actor_prey,self.actor_prey)
    但是达不到效果,请问是问题出在哪儿呢?
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值