PPO Code Implementation

The overall flow of the PPO2 (clipped PPO) implementation is as follows.

To do:
- [ ] Training (a minimal sketch of the loop is given after the code below)
- [ ] Game environment; the plan is to run this on Super Mario later
The basic flow is as follows:

  • Build the actor-critic network classes
    • The actor network takes a state as input and outputs a probability distribution over actions
    • The critic network takes a state as input and outputs an estimate of the state value
  • The overall PPO class, which implements the clipped objective summarized right after this list. The full code follows below
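For reference, these are the two standard formulas the code below relies on (stated here only to make the code easier to follow): the PPO clipped surrogate objective, which `update()` optimizes, and the GAE recursion, which `compute_advantage()` computes.

$$
L^{CLIP}(\theta) = \mathbb{E}_t\!\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t\right)\right],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}
$$

$$
\hat{A}_t = \delta_t + \gamma\lambda\,\hat{A}_{t+1},
\qquad \delta_t = r_t + \gamma V(s_{t+1})(1 - d_t) - V(s_t)
$$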
import gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt


class policy_net(torch.nn.Module):
    """Actor network: maps a state to a probability distribution over actions."""
    def __init__(self, state_dim, action_dim, hidden_dim):
        super(policy_net, self).__init__()
        self.f1 = torch.nn.Linear(state_dim, hidden_dim)
        self.f2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.f1(x))
        return F.softmax(self.f2(x), dim=1)

class value_net(torch.nn.Module):
    """Critic network: maps a state to a scalar state-value estimate."""
    def __init__(self, state_dim, hidden_dim):
        super(value_net, self).__init__()
        self.f1 = torch.nn.Linear(state_dim, hidden_dim)
        self.f2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.f1(x))
        return self.f2(x)

class PPO:
    def __init__(self, state_dim, action_dim, hidden_dim, lr_p, lr_v, lmbda, epochs, eps, gamma, device):
        self.action_net = policy_net(state_dim, action_dim, hidden_dim).to(device)
        self.critic_net = value_net(state_dim, hidden_dim).to(device)
        self.actor_opt = torch.optim.Adam(self.action_net.parameters(), lr=lr_p)
        self.cri_opt = torch.optim.Adam(self.critic_net.parameters(), lr=lr_v)
        self.lr_a = lr_p
        self.lr_c = lr_v

        self.device = device
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs
        self.eps = eps  # clipping range for the importance-sampling ratio
        
    def take_action(self, state):
        """Sample an action from the current policy for a single state."""
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.action_net(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, tmp):
        """Run several epochs of PPO updates on one batch of transitions."""
        states = torch.tensor(tmp['states'], dtype=torch.float).to(self.device)
        rewards = torch.tensor(tmp['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        actions = torch.tensor(tmp['actions'], dtype=torch.int64).view(-1, 1).to(self.device)
        dones = torch.tensor(tmp['dones'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(tmp['next_states'], dtype=torch.float).to(self.device)

        # TD target and TD error for the critic
        td_target = rewards + self.gamma * self.critic_net(next_states) * (1 - dones)
        td_delta = td_target - self.critic_net(states)
        # Generalized advantage estimation (GAE)
        adv = self.compute_advantage(self.gamma, self.lmbda, td_delta).to(self.device)

        # Log-probabilities of the taken actions under the old (pre-update) policy
        old_log_probs = torch.log(self.action_net(states).gather(1, actions)).detach()

        for _ in range(self.epochs):
            log_probs = torch.log(self.action_net(states).gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * adv
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * adv  # clipped surrogate

            actor_loss = torch.mean(-torch.min(surr1, surr2))
            cri_loss = torch.mean(F.mse_loss(self.critic_net(states), td_target.detach()))

            self.actor_opt.zero_grad()
            self.cri_opt.zero_grad()
            actor_loss.backward()
            cri_loss.backward()
            self.actor_opt.step()
            self.cri_opt.step()

    def compute_advantage(self, gamma, lmbda, td_delta):
        """Compute GAE advantages from the per-step TD errors."""
        td_delta = td_delta.detach().cpu().numpy()
        advantage_list = []
        advantage = 0.0
        for delta in td_delta[::-1]:  # accumulate backwards in time
            advantage = gamma * lmbda * advantage + delta
            advantage_list.append(advantage)
        advantage_list.reverse()  # back to chronological order
        return torch.tensor(np.array(advantage_list), dtype=torch.float)
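
Since the training loop is still on the to-do list above, here is a minimal sketch of how the pieces could be wired together. It is an assumption on my part, not part of the original code: gym's CartPole-v1 stands in for the eventual Mario environment, all hyperparameter values are illustrative, and the snippet assumes the classic gym API (reset() returns only the observation, step() returns a 4-tuple); newer gym/gymnasium versions would need small adjustments.

# Minimal training-loop sketch (illustrative; CartPole-v1 stands in for Mario
# and the hyperparameter values below are assumed, not taken from the post).
if __name__ == '__main__':
    lr_p, lr_v = 1e-3, 1e-2
    hidden_dim = 128
    gamma, lmbda = 0.98, 0.95
    epochs, eps = 10, 0.2
    num_episodes = 500
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = PPO(state_dim, action_dim, hidden_dim, lr_p, lr_v, lmbda, epochs, eps, gamma, device)

    returns = []
    for episode in range(num_episodes):
        # Collect one full episode in the dict layout that update() expects
        tmp = {'states': [], 'actions': [], 'rewards': [], 'next_states': [], 'dones': []}
        state = env.reset()  # classic gym API: reset() returns only the observation
        done = False
        episode_return = 0.0
        while not done:
            action = agent.take_action(state)
            next_state, reward, done, _ = env.step(action)  # classic 4-tuple API
            tmp['states'].append(state)
            tmp['actions'].append(action)
            tmp['rewards'].append(reward)
            tmp['next_states'].append(next_state)
            tmp['dones'].append(done)
            state = next_state
            episode_return += reward
        returns.append(episode_return)
        agent.update(tmp)  # one PPO update per collected episode

    plt.plot(returns)
    plt.xlabel('episode')
    plt.ylabel('return')
    plt.show()

One update per episode keeps the sketch short; in practice one would usually collect several episodes (or fixed-length rollouts) per update.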