深度强化学习方法(DQN)玩转Atari游戏(pong)

Atari Pong

简介

        Pong是起源于1972年美国的一款模拟两个人打乒乓球的游戏,近几年常用于测试强化学习算法的性能。这篇文章主要记录如何用DQN实现玩Atari游戏中的Pong,希望对和我一样的小白有所帮助,文章最后附本文代码及参考代码。

环境介绍

torch = 1.8.0+cu111
Python = 3.8.5
环境配置见另一篇博客https://blog.csdn.net/libenfan/article/details/116396388?spm=1001.2014.3001.5502

代码详解

        代码主要包含四个部分:经验重放区ReplayMemory、DQN网络、DQNagent、训练器Trainer。

ReplayMemory

        用于DQN的经验重放,包含的采样、存储、计算经验重放区长度三个方法。

# 定义一个元组表征经验存储的格式
Transition = namedtuple('Transion', 
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0
     
    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity #移动指针,经验池满了之后从最开始的位置开始将最近的经验存进经验池
        
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)# 从经验池中随机采样
    
    def __len__(self):
        return len(self.memory)

DQN网络

        三层卷积层,两层线性连接层,(这里要注意卷积层输出的大小要能够与线性层的输入大小相匹配)。
        由于需要对pong环境进行重写,因此DQN网络的输入大小在后面介绍pong环境重写的时候会说明。

class DQN(nn.Module):
    def __init__(self, in_channels=4, n_actions=14):
        """
        Initialize Deep Q Network
        Args:
            in_channels (int): number of input channels
            n_actions (int): number of outputs
        """
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        # self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        # self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # self.bn3 = nn.BatchNorm2d(64)
        self.fc4 = nn.Linear(7 * 7 * 64, 512)
        self.head = nn.Linear(512, n_actions)
        
    def forward(self, x):
        x = x.
  • 9
    点赞
  • 86
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 12
    评论
PyTorch是一个开源的Python机器学习库,它提供了强大的工具来进行深度学习强化学习。在这篇文章中,我们将使用PyTorch来构建一个深度强化学习模型,让AIAtari游戏Atari游戏是一系列经典的电子游戏,如Pong、Space Invaders和Breakout。这些游戏简单易懂,但是对于人类家来说仍然有挑战性。我们将使用Atari游戏作为我们的强化学习环境,以训练我们的AI代理。 我们将使用Deep Q-Networks(DQN)算法来训练我们的AI代理。DQN是一种基于深度学习强化学习算法,它将神经网络与Q学习相结合,使得AI代理可以学习如何最大化其预期回报。 首先,我们需要安装PyTorch和OpenAI Gym。OpenAI Gym是一个用于开发和比较强化学习算法的工具包。您可以在这里找到有关安装方法的说明:https://pytorch.org/get-started/locally/ 和 https://gym.openai.com/docs/#installation。 在安装完成后,我们可以开始编写我们的代码。 首先,我们需要导入必要的库: ```python import random import math import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import numpy as np import gym ``` 接下来,我们定义我们的Agent类。Agent类负责与环境交互并学习如何游戏。 ```python class Agent: def __init__(self, env, gamma, epsilon, lr): self.env = env self.gamma = gamma self.epsilon = epsilon self.lr = lr self.memory = [] self.model = Net(env.observation_space.shape[0], env.action_space.n) self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) def act(self, state): if random.random() < self.epsilon: return self.env.action_space.sample() else: state = torch.FloatTensor(state).unsqueeze(0) q_values = self.model(state) return q_values.max(1)[1].item() def remember(self, state, action, next_state, reward): self.memory.append((state, action, next_state, reward)) def learn(self, batch_size): if len(self.memory) < batch_size: return transitions = random.sample(self.memory, batch_size) batch = Transition(*zip(*transitions)) state_batch = torch.FloatTensor(batch.state) action_batch = torch.LongTensor(batch.action) reward_batch = torch.FloatTensor(batch.reward) next_state_batch = torch.FloatTensor(batch.next_state) q_values = self.model(state_batch).gather(1, action_batch.unsqueeze(1)) next_q_values = self.model(next_state_batch).max(1)[0].detach() expected_q_values = (next_q_values * self.gamma) + reward_batch loss = F.smooth_l1_loss(q_values, expected_q_values.unsqueeze(1)) self.optimizer.zero_grad() loss.backward() self.optimizer.step() ``` 我们的Agent类具有几个方法: 1. `__init__`方法初始化代理。我们传递的参数包括环境,折扣因子(gamma),ε贪心策略中的ε值和学习率(lr)。我们还创建了一个神经网络模型和Adam优化器。 2. `act`方法根据当前状态选择一个动作。我们使用ε贪心策略,在一定概率下随机选择动作,否则选择当前状态下具有最高Q值的动作。 3. `remember`方法将经验元组(state,action,next_state,reward)添加到内存中。 4. `learn`方法从内存中随机选择一批经验元组,然后使用这些经验元组进行训练。我们计算当前状态下的Q值和下一个状态下的最大Q值,然后使用这些值计算预期Q值。我们使用平滑L1损失函数计算损失,并使用Adam优化器更新我们的模型。 接下来,我们定义我们的神经网络模型。 ```python class Net(nn.Module): def __init__(self, input_size, output_size): super(Net, self).__init__() self.fc1 = nn.Linear(input_size, 128) self.fc2 = nn.Linear(128, 128) self.fc3 = nn.Linear(128, output_size) def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x ``` 我们的模型是一个简单的前馈神经网络,具有三个全连接层。我们使用ReLU激活函数,并且输出层的大小等于动作空间的大小。 最后,我们定义我们的主函数,用于实际运行我们的代理。 ```python if __name__ == '__main__': env = gym.make('Breakout-v0') agent = Agent(env, gamma=0.99, epsilon=1.0, lr=1e-4) batch_size = 32 num_episodes = 1000 for i_episode in range(num_episodes): state = env.reset() total_reward = 0 done = False while not done: action = agent.act(state) next_state, reward, done, _ = env.step(action) agent.remember(state, action, next_state, reward) agent.learn(batch_size) total_reward += reward state = next_state agent.epsilon = max(0.01, agent.epsilon * 0.995) print("Episode: {}, total reward: {}, epsilon: {}".format(i_episode, total_reward, agent.epsilon)) ``` 我们使用OpenAI Gym中的Breakout游戏来测试我们的代理。在每个训练周期中,我们重置环境并运行一个周期,直到游戏结束。我们将每个状态、动作、下一个状态和奖励作为经验元组传递给我们的Agent,并使用这些经验元组进行训练。我们使用逐步减小的ε值来平衡探索和利用。我们打印出每个训练周期的总奖励以及当前的ε值。 现在我们已经编写了我们的代码,我们可以开始训练我们的代理。运行主函数,我们将看到我们的代理在游戏中逐渐变得更加熟练。我们可以尝试调整参数来进一步优化我们的代理的性能。 总结: 在本文中,我们使用PyTorch和OpenAI Gym构建了一个深度强化学习代理,让它Atari游戏。我们使用Deep Q-Networks算法和ε贪心策略来训练我们的代理,并逐步减小ε值来平衡探索和利用。我们的代理在游戏中逐渐变得更加熟练,展示了PyTorch在深度强化学习中的强大功能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

libenfan

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值