Atari Pong
简介
Pong是起源于1972年美国的一款模拟两个人打乒乓球的游戏,近几年常用于测试强化学习算法的性能。这篇文章主要记录如何用DQN实现玩Atari游戏中的Pong,希望对和我一样的小白有所帮助,文章最后附本文代码及参考代码。
环境介绍
torch = 1.8.0+cu111
Python = 3.8.5
环境配置见另一篇博客https://blog.csdn.net/libenfan/article/details/116396388?spm=1001.2014.3001.5502
代码详解
代码主要包含四个部分:经验重放区ReplayMemory、DQN网络、DQNagent、训练器Trainer。
ReplayMemory
用于DQN的经验重放,包含的采样、存储、计算经验重放区长度三个方法。
# 定义一个元组表征经验存储的格式
Transition = namedtuple('Transion',
('state', 'action', 'next_state', 'reward'))
class ReplayMemory(object):
def __init__(self, capacity):
self.capacity = capacity
self.memory = []
self.position = 0
def push(self, *args):
if len(self.memory) < self.capacity:
self.memory.append(None)
self.memory[self.position] = Transition(*args)
self.position = (self.position + 1) % self.capacity #移动指针,经验池满了之后从最开始的位置开始将最近的经验存进经验池
def sample(self, batch_size):
return random.sample(self.memory, batch_size)# 从经验池中随机采样
def __len__(self):
return len(self.memory)
DQN网络
三层卷积层,两层线性连接层,(这里要注意卷积层输出的大小要能够与线性层的输入大小相匹配)。
由于需要对pong环境进行重写,因此DQN网络的输入大小在后面介绍pong环境重写的时候会说明。
class DQN(nn.Module):
def __init__(self, in_channels=4, n_actions=14):
"""
Initialize Deep Q Network
Args:
in_channels (int): number of input channels
n_actions (int): number of outputs
"""
super(DQN, self).__init__()
self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
# self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
# self.bn2 = nn.BatchNorm2d(64)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
# self.bn3 = nn.BatchNorm2d(64)
self.fc4 = nn.Linear(7 * 7 * 64, 512)
self.head = nn.Linear(512, n_actions)
def forward(self, x):
x = x.