莫烦强化学习笔记整理(十)Asynchronous Advantage Actor-Critic (A3C)

莫烦强化学习笔记整理(十)Asynchronous Advantage Actor-Critic (A3C)


链接: A3C代码.

1、A3C 要点

A3C是Google DeepMind 提出的一种解决 Actor-Critic 不收敛问题的算法。它会创建多个并行的环境, 让多个拥有副结构的 agent 同时在这些并行环境上更新主结构中的参数,并行中的 agent 们互不干扰, 而主结构的参数更新受到副结构提交更新的不连续性干扰, 所以更新的相关性被降低, 收敛性提高。

2、A3C 算法

actor与critic网络

# 这个 class 可以被调用生成一个 global net.
# 也能被调用生成一个 worker 的 net, 因为他们的结构是一样的,
# 所以这个 class 可以被重复利用.
class ACNet(object):
    def __init__(self, globalAC=None):
        # 当创建 worker 网络的时候, 我们传入之前创建的 globalAC 给这个 worker
        if 这是 global:   # 判断当下建立的网络是 local 还是 global
            with tf.variable_scope('Global_Net'):
                self._build_net()
        else:
            with tf.variable_scope('worker'):
                self._build_net()

            # 接着计算 critic loss 和 actor loss
            # 用这两个 loss 计算要推送的 gradients

            with tf.name_scope('sync'):  # 同步
                with tf.name_scope('pull'):
                    # 更新去 global
                with tf.name_scope('push'):
                    # 获取 global 参数

    def _build_net(self):
        # 在这里搭建 Actor 和 Critic 的网络
        return 均值, 方差, state_value

    def update_global(self, feed_dict):
        # 进行 push 操作

    def pull_global(self):
        # 进行 pull 操作

    def choose_action(self, s):
        # 根据 s 选动作

单个worker

class Worker(object):
    def __init__(self, name, globalAC):
        self.env = gym.make(GAME).unwrapped # 创建自己的环境
        self.name = name    # 自己的名字
        self.AC = ACNet(name, globalAC) # 自己的 local net, 并绑定上 globalAC

    def work(self):
        # s, a, r 的缓存, 用于 n_steps 更新
        buffer_s, buffer_a, buffer_r = [], [], []
        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
            s = self.env.reset()

            for ep_t in range(MAX_EP_STEP):
                a = self.AC.choose_action(s)
                s_, r, done, info = self.env.step(a)

                buffer_s.append(s)  # 添加各种缓存
                buffer_a.append(a)
                buffer_r.append(r)

                # 每 UPDATE_GLOBAL_ITER 步 或者回合完了, 进行 sync 操作
                if total_step % UPDATE_GLOBAL_ITER == 0 or done:
                    # 获得用于计算 TD error 的 下一 state 的 value
                    if done:
                        v_s_ = 0   # terminal
                    else:
                        v_s_ = SESS.run(self.AC.v, {self.AC.s: s_[np.newaxis, :]})[0, 0]

                    buffer_v_target = []    # 下 state value 的缓存, 用于算 TD
                    for r in buffer_r[::-1]:    # 进行 n_steps forward view
                        v_s_ = r + GAMMA * v_s_
                        buffer_v_target.append(v_s_)
                    buffer_v_target.reverse()

                    buffer_s, buffer_a, buffer_v_target = np.vstack(buffer_s), np.vstack(buffer_a), np.vstack(buffer_v_target)

                    feed_dict = {
                        self.AC.s: buffer_s,
                        self.AC.a_his: buffer_a,
                        self.AC.v_target: buffer_v_target,
                    }

                    self.AC.update_global(feed_dict)    # 推送更新去 globalAC
                    buffer_s, buffer_a, buffer_r = [], [], []   # 清空缓存
                    self.AC.pull_global()   # 获取 globalAC 的最新参数

                s = s_
                if done:
                    GLOBAL_EP += 1  # 加一回合
                    break   # 结束这回合

worker并行工作

在这里插入图片描述

with tf.device("/cpu:0"):
    GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE)  # 建立 Global AC
    workers = []
    for i in range(N_WORKERS):  # 创建 worker, 之后在并行
        workers.append(Worker(GLOBAL_AC))   # 每个 worker 都有共享这个 global AC

COORD = tf.train.Coordinator()  # Tensorflow 用于并行的工具

worker_threads = []
for worker in workers:
    job = lambda: worker.work()
    t = threading.Thread(target=job)    # 添加一个工作线程
    t.start()
    worker_threads.append(t)
COORD.join(worker_threads)  # tf 的线程调度
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值