A2C: Principle and Code Implementation

Reference: Wang Shusen's *Deep Reinforcement Learning* (《深度强化学习》) course and book.


1. A2C principle:



Observe a transition: $(s_t, a_t, r_t, s_{t+1})$.

TD target:
$$y_t = r_t + \gamma \cdot v(s_{t+1}; \mathbf{w}).$$
TD error:
$$\delta_t = v(s_t; \mathbf{w}) - y_t.$$
Update the policy network (actor) by:
$$\boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \beta \cdot \delta_t \cdot \frac{\partial \ln \pi(a_t \mid s_t; \boldsymbol{\theta})}{\partial \boldsymbol{\theta}}.$$
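For intuition, take illustrative numbers $r_t = 1$, $\gamma = 0.99$, $v(s_{t+1}; \mathbf{w}) = 2.0$, $v(s_t; \mathbf{w}) = 2.5$. Then

$$y_t = 1 + 0.99 \cdot 2.0 = 2.98, \qquad \delta_t = 2.5 - 2.98 = -0.48.$$

A negative $\delta_t$ means the outcome was better than the critic expected, so the actor step $-\beta \cdot \delta_t \cdot \nabla_{\boldsymbol{\theta}} \ln \pi(a_t \mid s_t; \boldsymbol{\theta})$ points in the direction that increases $\ln \pi(a_t \mid s_t; \boldsymbol{\theta})$; a positive $\delta_t$ decreases it.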


In code (excerpted from the full implementation in Section 2), the TD target and the critic loss are computed batch-wise. Note that the target bootstraps from a separate target network V_target, which is kept close to V by the soft update shown in Section 2:

def compute_value_loss(self, bs, blogp_a, br, bd, bns):
    # TD target.
    with torch.no_grad():
        target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()
        # torch.logical_not negates the done flags, so terminal transitions are not bootstrapped.

    # Value loss: MSE between v(s_t; w) and the TD target.
    value_loss = F.mse_loss(self.V(bs).squeeze(), target_value)
    return value_loss

Update the value network (critic) by:
$$\mathbf{w} \leftarrow \mathbf{w} - \alpha \cdot \delta_t \cdot \frac{\partial v(s_t; \mathbf{w})}{\partial \mathbf{w}}.$$
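This critic update is one step of gradient descent on the squared TD error: treating $y_t$ as a constant,
$$\frac{\partial}{\partial \mathbf{w}}\,\frac{1}{2}\big(v(s_t; \mathbf{w}) - y_t\big)^2 = \delta_t \cdot \frac{\partial v(s_t; \mathbf{w})}{\partial \mathbf{w}},$$
which is why the code can simply minimize F.mse_loss(self.V(bs).squeeze(), target_value) with an optimizer instead of applying the rule by hand (the factor 1/2 and the batch averaging only rescale the learning rate).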


The policy-loss snippet below is a variant that weights each log-probability by the state value $v(s_t; \mathbf{w})$; the full A2C implementation in Section 2 uses the advantage $y_t - v(s_t; \mathbf{w}) = -\delta_t$ instead, which matches the actor update rule above. A vectorized alternative is sketched after the snippet.

def compute_policy_loss(self, bs, blogp_a, br, bd, bns):
    # Compare this with 08_a2c.py to see how the two versions differ.
    with torch.no_grad():
        value = self.V(bs).squeeze()

    policy_loss = 0
    for i, logp_a in enumerate(blogp_a):
        policy_loss += -logp_a * value[i]
    policy_loss = policy_loss.mean()
    return policy_loss
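The Python loop accumulates a scalar sum (the trailing .mean() on a 0-dim tensor is a no-op), so it can also be written in vectorized form. A minimal sketch, assuming blogp_a is the list of 0-dim log-probability tensors from the snippet above (each still attached to the policy network's graph):

# Hypothetical vectorized equivalent of the loop above.
# torch.stack joins the per-step log-probabilities into one 1-D tensor;
# .sum() reproduces the accumulated loss exactly (use .mean() to average over the episode).
policy_loss = -(torch.stack(blogp_a) * value).sum()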

2. Complete A2C implementation:

Adapted with modified comments; the original code is at https://github.com/wangshusen/DRL

"""8.3节A2C算法实现。"""
import argparse
import os
from collections import defaultdict
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class ValueNet(nn.Module):
    def __init__(self, dim_state):
        super().__init__()
        self.fc1 = nn.Linear(dim_state, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class PolicyNet(nn.Module):
    def __init__(self, dim_state, num_action):
        super().__init__()
        self.fc1 = nn.Linear(dim_state, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_action)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        prob = F.softmax(x, dim=-1)
        return prob


class A2C:
    def __init__(self, args):
        self.args = args
        self.V = ValueNet(args.dim_state)
        self.V_target = ValueNet(args.dim_state)
        self.pi = PolicyNet(args.dim_state, args.num_action)
        self.V_target.load_state_dict(self.V.state_dict())

    def get_action(self, state):
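        # Sample an action from the categorical policy pi(.|s; theta) and keep its
        # log-probability, which is needed later for the policy-gradient loss.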
        probs = self.pi(state)
        m = Categorical(probs)
        action = m.sample()
        logp_action = m.log_prob(action)
        return action, logp_action

    def compute_value_loss(self, bs, blogp_a, br, bd, bns):
        # TD target.
        with torch.no_grad():
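            # torch.logical_not(bd) zeroes out bootstrapping for terminal transitions.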
            target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()

        # Value loss: MSE between v(s_t; w) and the TD target.
        value_loss = F.mse_loss(self.V(bs).squeeze(), target_value)
        return value_loss

    def compute_policy_loss(self, bs, blogp_a, br, bd, bns):
        # TD target.
        with torch.no_grad():
            target_value = br + self.args.discount * torch.logical_not(bd) * self.V_target(bns).squeeze()

        # Policy loss.
        with torch.no_grad():
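            # advantage = y_t - v(s_t; w) = -delta_t, so minimizing -log_pi * advantage
            # implements the actor update rule from Section 1.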
            advantage = target_value - self.V(bs).squeeze()
        policy_loss = 0
        for i, logp_a in enumerate(blogp_a):
            policy_loss += -logp_a * advantage[i]
        policy_loss = policy_loss.mean()
        return policy_loss

    def soft_update(self, tau=0.01):
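        # Polyak averaging: V_target <- (1 - tau) * V_target + tau * V.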
        def soft_update_(target, source, tau_=0.01):
            for target_param, param in zip(target.parameters(), source.parameters()):
                target_param.data.copy_(target_param.data * (1.0 - tau_) + param.data * tau_)

        soft_update_(self.V_target, self.V, tau)


class Rollout:
    def __init__(self):
        self.state_lst = []
        self.action_lst = []
        self.logp_action_lst = []
        self.reward_lst = []
        self.done_lst = []
        self.next_state_lst = []

    def put(self, state, action, logp_action, reward, done, next_state):
        self.state_lst.append(state)
        self.action_lst.append(action)
        self.logp_action_lst.append(logp_action)
        self.reward_lst.append(reward)
        self.done_lst.append(done)
        self.next_state_lst.append(next_state)

    def tensor(self):
        bs = torch.as_tensor(self.state_lst).float()
        ba = torch.as_tensor(self.action_lst).float()
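        # Keep log-probabilities as a plain list of tensors so they retain the
        # policy network's computation graph for the later backward pass.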
        blogp_a = self.logp_action_lst
        br = torch.as_tensor(self.reward_lst).float()
        bd = torch.as_tensor(self.done_lst)
        bns = torch.as_tensor(self.next_state_lst).float()
        return bs, ba, blogp_a, br, bd, bns


class INFO:
    def __init__(self):
        self.log = defaultdict(list)
        self.episode_length = 0
        self.episode_reward = 0
        self.max_episode_reward = -float("inf")

    def put(self, done, reward):
        if done is True:
            self.episode_length += 1
            self.episode_reward += reward
            self.log["episode_length"].append(self.episode_length)
            self.log["episode_reward"].append(self.episode_reward)

            if self.episode_reward > self.max_episode_reward:
                self.max_episode_reward = self.episode_reward

            self.episode_length = 0
            self.episode_reward = 0

        else:
            self.episode_length += 1
            self.episode_reward += reward


def train(args, env, agent: A2C):
    V_optimizer = torch.optim.Adam(agent.V.parameters(), lr=3e-3)
    pi_optimizer = torch.optim.Adam(agent.pi.parameters(), lr=3e-3)
    info = INFO()

    rollout = Rollout()
    state, _ = env.reset()
    for step in range(args.max_steps):
        action, logp_action = agent.get_action(torch.tensor(state).float())
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated
        info.put(done, reward)

        rollout.put(
            state,
            action,
            logp_action,
            reward,
            done,
            next_state,
        )
        state = next_state

        if done is True:
            # Train the networks on the finished episode.
            bs, ba, blogp_a, br, bd, bns = rollout.tensor()

            value_loss = agent.compute_value_loss(bs, blogp_a, br, bd, bns)
            V_optimizer.zero_grad()
            value_loss.backward(retain_graph=True)
            V_optimizer.step()

            policy_loss = agent.compute_policy_loss(bs, blogp_a, br, bd, bns)
            pi_optimizer.zero_grad()
            policy_loss.backward()
            pi_optimizer.step()

            agent.soft_update()

            # Log and print progress.
            info.log["value_loss"].append(value_loss.item())
            info.log["policy_loss"].append(policy_loss.item())

            episode_reward = info.log["episode_reward"][-1]
            episode_length = info.log["episode_length"][-1]
            value_loss = info.log["value_loss"][-1]
            print(f"step={step}, reward={episode_reward:.0f}, length={episode_length}, max_reward={info.max_episode_reward}, value_loss={value_loss:.1e}")

            # Reset the environment.
            state, _ = env.reset()
            rollout = Rollout()

            # Save the policy network whenever a new best episode reward is reached.
            if episode_reward == info.max_episode_reward:
                save_path = os.path.join(args.output_dir, "model.bin")
                torch.save(agent.pi.state_dict(), save_path)

        if step % 10000 == 0:
            plt.plot(info.log["value_loss"], label="value loss")
            plt.legend()
            plt.savefig(f"{args.output_dir}/value_loss.png", bbox_inches="tight")
            plt.close()

            plt.plot(info.log["episode_reward"])
            plt.savefig(f"{args.output_dir}/episode_reward.png", bbox_inches="tight")
            plt.close()


def eval(args, env, agent):
    agent = A2C(args)
    model_path = os.path.join(args.output_dir, "model.bin")
    agent.pi.load_state_dict(torch.load(model_path))

    episode_length = 0
    episode_reward = 0
    state, _ = env.reset()
    for i in range(5000):
        episode_length += 1
        action, _ = agent.get_action(torch.from_numpy(state))
        next_state, reward, terminated, truncated, info = env.step(action.item())
        done = terminated or truncated
        episode_reward += reward

        state = next_state
        if done is True:
            print(f"episode reward={episode_reward}, length={episode_length}")
            state, _ = env.reset()
            episode_length = 0
            episode_reward = 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="CartPole-v1", type=str, help="Environment name.")
    parser.add_argument("--dim_state", default=4, type=int, help="Dimension of state.")
    parser.add_argument("--num_action", default=2, type=int, help="Number of action.")
    parser.add_argument("--output_dir", default="output", type=str, help="Output directory.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")

    parser.add_argument("--max_steps", default=100_000, type=int, help="Maximum steps for interaction.")
    parser.add_argument("--discount", default=0.99, type=float, help="Discount coefficient.")
    parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate.")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")

    parser.add_argument("--do_train", action="store_true", help="Train policy.")
    parser.add_argument("--do_eval", action="store_true", help="Evaluate policy.")
    args = parser.parse_args()

    env = gym.make(args.env)
    agent = A2C(args)

    if args.do_train:
        train(args, env, agent)

    if args.do_eval:
        eval(args, env, agent)
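For reference, training and evaluation are selected with the --do_train and --do_eval flags, e.g. `python a2c.py --do_train --do_eval` (a2c.py stands for whatever filename the listing above is saved under). Training writes the best policy to model.bin and saves value-loss and episode-reward plots under --output_dir; note that the output directory must already exist, since the script does not create it.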


3. torch.distributions.Categorical()

The snippet below is the score-function usage example from the PyTorch documentation (linked at the end of this section), with the comments translated:

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)  # construct a distribution from probs
action = m.sample()  # sample an action according to probs
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward  # log_prob(action) computes log(probs[action])
loss.backward()
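A minimal, self-contained check of what log_prob returns (the probabilities below are illustrative):

import torch
from torch.distributions import Categorical

probs = torch.tensor([0.2, 0.8])
m = Categorical(probs)
action = torch.tensor(1)
# log_prob(action) equals log(probs[action]); here log(0.8) ≈ -0.2231.
print(m.log_prob(action))        # tensor(-0.2231)
print(torch.log(probs[action]))  # tensor(-0.2231)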



[Probability distributions - torch.distributions — PyTorch 2.0 documentation](https://pytorch.org/docs/stable/distributions.html)

[【PyTorch】关于 log_prob(action) - 简书 (jianshu.com)](https://www.jianshu.com/p/06a5c47ee7c2)