深度强化学习方法(DQN)玩转Atari游戏(pong)

本文详细介绍了如何使用深度强化学习(DQN)算法玩转Atari游戏Pong,包括环境介绍、DQN网络结构、经验重放区、智能体DQNagent的实现以及训练过程。通过代码解析,展示了从ReplayMemory到训练器Trainer的完整流程,适合初学者了解DQN在游戏控制中的应用。
摘要由CSDN通过智能技术生成

Atari Pong

简介

        Pong是起源于1972年美国的一款模拟两个人打乒乓球的游戏,近几年常用于测试强化学习算法的性能。这篇文章主要记录如何用DQN实现玩Atari游戏中的Pong,希望对和我一样的小白有所帮助,文章最后附本文代码及参考代码。

环境介绍

torch = 1.8.0+cu111
Python = 3.8.5
环境配置见另一篇博客https://blog.csdn.net/libenfan/article/details/116396388?spm=1001.2014.3001.5502

代码详解

        代码主要包含四个部分:经验重放区ReplayMemory、DQN网络、DQNagent、训练器Trainer。

ReplayMemory

        用于DQN的经验重放,包含的采样、存储、计算经验重放区长度三个方法。

# 定义一个元组表征经验存储的格式
Transition = namedtuple('Transion', 
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0
     
    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity #移动指针,经验池满了之后从最开始的位置开始将最近的经验存进经验池
        
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)# 从经验池中随机采样
    
    def __len__(self):
        return len(self.memory)

DQN网络

        三层卷积层,两层线性连接层,(这里要注意卷积层输出的大小要能够与线性层的输入大小相匹配)。
        由于需要对pong环境进行重写,因此DQN网络的输入大小在后面介绍pong环境重写的时候会说明。

class DQN(nn.Module):
    def __init__(self, in_channels=4, n_actions=14):
        """
        Initialize Deep Q Network
        Args:
            in_channels (int): number of input channels
            n_actions (int): number of outputs
        """
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        # self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        # self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # self.bn3 = nn.BatchNorm2d(64)
        self.fc4 = nn.Linear(7 * 7 * 64, 512)
        self.head = nn.Linear(512, n_actions)
        
    def forward(self, x):
        x = x.
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

libenfan

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值