用于电力系统储能控制的双智能体PPO算法python代码

        我是刚入门深度强化学习,对于代码结构不太清楚。我让gpt生成了一个双智能体的PPO代码(如下),想请教各位大佬这个代码结构上有什么问题吗?

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.distributions import Categorical
import gym

# 定义双智能体的 ActorCritic 网络
class ActorCritic(nn.Module):
    def __init__(self, input_size, action_size):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.actor = nn.Linear(64, action_size)
        self.critic = nn.Linear(64, 1)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        action_probs = F.softmax(self.actor(x), dim=-1)
        state_value = self.critic(x)
        return action_probs, state_value

# 定义 PPO 算法
class PPO:
    def __init__(self, input_size, action_size):
        self.net = ActorCritic(input_size, action_size)
        self.optimizer = optim.Adam(self.net.parameters(), lr=0.001)
        self.eps_clip = 0.2
        self.value_coef = 0.5
        self.entropy_coef = 0.01
    
    def get_action(self, state):
        state = torch.FloatTensor(state)
        action_probs, _ = self.net(state)
        dist = Categorical(action_probs)
        action = dist.sample()
        return action.item()
    
    def update(self, state_batch, action_batch, old_action_probs, state_value_batch, returns, advantages):
        # 计算新的动作概率、状态值和分布熵
        action_probs, state_values = self.net(state_batch)
        dist = Categorical(action_probs)
        entropy = dist.entropy().mean()
        
        # 计算新旧动作概率的比例
        new_action_probs = dist.log_prob(action_batch)
        old_action_probs = torch.FloatTensor(old_action_probs)
        action_ratios = torch.exp(new_action_probs - old_action_probs)
        
        # 计算 surrogate loss
        surr1 = action_ratios * advantages.view(-1, 1)
        surr2 = torch.clamp(action_ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages.view(-1, 1)
        actor_loss = -torch.min(surr1, surr2).mean()
        
        # 计算 critic loss
        critic_loss = F.mse_loss(state_values.view(-1, 1), state_value_batch)
        
        # 计算总的损失函数
        loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy
        
        # 执行优化
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# 定义训练函数
def train(env, total_epochs, max_timesteps):
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    ppo = PPO(state_size, action_size)
    
    for epoch in range(total_epochs):
        state = env.reset()
        episode_reward = 0
        done = False
        t = 0
        
        while not done and t < max_timesteps:
            action = ppo.get_action(state)
            next_state, reward, done, _ = env.step(action)
            
            action_probs, state_value = ppo.net(torch.FloatTensor(state))
            old_action_probs = action_probs[0][action].item()
            
            state_batch = torch.FloatTensor([state])
            action_batch = torch.LongTensor([action])
            old_action_probs_batch = [old_action_probs]
            state_value_batch = torch.FloatTensor([state_value.item()])
            
            return_batch = torch.FloatTensor([reward])
            advantage_batch = return_batch - state_value_batch
            
            ppo.update(state_batch, action_batch, old_action_probs_batch, state_value_batch, return_batch, advantage_batch)
            
            state = next_state
            episode_reward += reward
            t += 1
        
        print(f"Epoch {epoch}, Reward: {episode_reward}")
    
    env.close()

# 创建环境并训练
env = gym.make('CartPole-v1')
total_epochs = 1000
max_timesteps = 200
train(env, total_epochs, max_timesteps)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值