项目背景
超绝无敌废话前言
经过嵌入式专栏的构建,现在也逐渐开始扩建其他领域的项目分享了,(当然,嵌入式的相关项目也会继续分享)我一下就想到了AI游戏这一类,我个人觉得这类项目还挺有意思的,大家也可以当作娱乐的项目去做做实验。之后呢,我就会在这个专栏分享一些我收集到的项目或者是我自己改造过的项目。
话不多说,先介绍一个大部分计算机专业,尤其是接触强化学习这方面同学多少熟悉的一个项目——AI飞鸟。当然啦,不管什么项目都没办法一口吃成一个胖子,所以我们就慢慢地建立一个完整的项目。
背景简介
现在随着人工智能技术的不断发展呢,强化学习在游戏领域的应用越来越广泛。这个项目的目的是利用Q-learning算法,结合Python编程语言,开发一个类似于FlappyBird的简单游戏,并通过Q-Learning智能体的自主学习来提高游戏得分。
项目包括利用使用Python的Pygame库或其他游戏开发框架,构建类似于FlappyBird的游戏界面和游戏逻辑。游戏场景包括管道、小鸟、背景、对应动作的音效等元素。然后进行Q-learning算法实现:根据Q-learning算法原理,编写智能体的学习过程,包括状态空间和动作空间的定义、Q函数的建立、经验回放的实现、探索策略的选择等。之后再进行智能体的训练,将Q-learning算法应用于游戏智能体的训练中(可以切换使用基本智能体还是带有贪心策略的智能体),使智能体通过不断与游戏环境的交互,学习最优策略,提高游戏得分。最后在此基础上,进行性能分析:收集智能体在不同训练轮次下的游戏得分数据,分析Q-learning在游戏开发中的应用效果。对比不同参数设置对智能体性能的影响,优化算法参数。
Q-learning算法
Q-learning是一种无模型的强化学习算法,它旨在学习一个策略,该策略能够告诉一个代理(agent)在给定状态下应该采取哪个动作以最大化其预期的累积奖励。Q-learning不需要模型就可以学习,这意味着它不需要知道环境的动态(即状态转移概率)。
Q-learning的核心思想是学习一个动作价值函数Q(s, a),它估计在状态s下采取动作a并遵循最优策略所能获得的预期回报。算法的目标是找到使Q值最大化的策略。
以下是Q-learning算法的基本步骤:
-
初始化:为每个状态-动作对(s, a)初始化Q值,通常初始化为0。
-
观察:代理观察当前状态s。
-
选择和执行动作:代理根据当前的Q值表选择一个动作a。通常,这涉及到使用ε-贪心策略,即以ε的概率随机选择一个动作,以1-ε的概率选择当前Q值最高的动作。
-
奖励和下一状态:代理执行动作a,环境给予奖励r和下一个状态s’。
-
Q值更新:代理根据奖励和下一个状态更新Q值。更新公式如下:
其中:
α是学习率,决定了新信息覆盖旧信息的速度。
γ是折扣因子,决定了未来奖励的重要性。
是在新状态s’下所有可能动作的最大Q值。
-
重复:代理将新状态s’作为当前状态,并重复步骤3-5,直到达到终止状态或满足终止条件。
Q-learning算法的几个关键点:
- 探索与利用:算法需要在探索(尝试新动作以发现更好的策略)和利用(使用已知的最佳策略)之间找到平衡。
- 收敛性:在某些条件下,Q-learning可以收敛到最优策略,即使在非有限或非确定性的环境中也是如此。
- 离线学习:Q-learning是一种离线学习算法,因为它可以从未被执行的动作中学习,只要这些动作的Q值被更新。
Q-learning是强化学习领域中非常基础且广泛使用的算法之一,而且还为更复杂的算法如深度Q网络(DQN)等奠定了基础。
项目结构
一个完整的项目都会有一个总体的项目框架
项目根目录
│
├── checkpoints/ # 可能会用于存放训练模型的检查点,如果是机器学习项目
│
├── demonstration/ # 可能包含演示视频或演示版本的游戏,用于展示项目成果
│
├── modules/ # 这个文件夹可能包含项目的主要模块或组件,具体功能需要查看里面的内容
│
├── resources/ # 通常用于存放项目资源文件,比如图像、音频、配置文件等
│
├── cfg.py # 配置文件,可能包含游戏或应用的配置参数
│
├── flappybird.py # 这个文件可能是游戏的主要逻辑文件,或者是游戏的入口文件
│
├── FlappyBird游戏开发.doc # 文档文件,可能包含游戏开发的说明、设计文档或教程
│
└── requirements.txt # 项目依赖文件,列出了所有需要安装的Python包或其他依赖项
当然,这项目结构没啥固定标准,就看具体要求是啥了,但是重要的是,得让自己和用户清楚地了解到什么地方该存储什么文件。
这个就是游戏运行结果,所以,最重要的准备工作就是这个界面里大家所能看到的所有图片以及一些音效都要准备好,我的工程里都已经有了,大家直接下载就好,而且这个项目不是稀奇项目吗,所以资源还是很好找的。
场景和智能体的搭建
这个应该是这类项目比较关键的地方。说实话我不知道该怎么介绍这些代码,先试试吧,希望大家多多给我提建议。
游戏的相关设置
作为一个游戏来说,当然必不可少的就是相关的设置,不然就是乱七八糟的。
'''config file'''
import os
# FPS
FPS = 45
#屏幕大小
SCREENWIDTH = 288
SCREENHEIGHT = 512
# 管道之间的间隙
PIPE_GAP_SIZE = 100
# 游戏图片路径
NUMBER_IMAGE_PATHS = {
'0': os.path.join(os.getcwd(), 'resources/images/0.png'),
'1': os.path.join(os.getcwd(), 'resources/images/1.png'),
'2': os.path.join(os.getcwd(), 'resources/images/2.png'),
'3': os.path.join(os.getcwd(), 'resources/images/3.png'),
'4': os.path.join(os.getcwd(), 'resources/images/4.png'),
'5': os.path.join(os.getcwd(), 'resources/images/5.png'),
'6': os.path.join(os.getcwd(), 'resources/images/6.png'),
'7': os.path.join(os.getcwd(), 'resources/images/7.png'),
'8': os.path.join(os.getcwd(), 'resources/images/8.png'),
'9': os.path.join(os.getcwd(), 'resources/images/9.png')
}
BIRD_IMAGE_PATHS = {
'red': {'up': os.path.join(os.getcwd(), 'resources/images/redbird-upflap.png'),
'mid': os.path.join(os.getcwd(), 'resources/images/redbird-midflap.png'),
'down': os.path.join(os.getcwd(), 'resources/images/redbird-downflap.png')},
'blue': {'up': os.path.join(os.getcwd(), 'resources/images/bluebird-upflap.png'),
'mid': os.path.join(os.getcwd(), 'resources/images/bluebird-midflap.png'),
'down': os.path.join(os.getcwd(), 'resources/images/bluebird-downflap.png')},
'yellow': {'up': os.path.join(os.getcwd(), 'resources/images/yellowbird-upflap.png'),
'mid': os.path.join(os.getcwd(), 'resources/images/yellowbird-midflap.png'),
'down': os.path.join(os.getcwd(), 'resources/images/yellowbird-downflap.png')}
}
BACKGROUND_IMAGE_PATHS = {
'day': os.path.join(os.getcwd(), 'resources/images/background-day.png'),
'night': os.path.join(os.getcwd(), 'resources/images/background-night.png')
}
PIPE_IMAGE_PATHS = {
'green': os.path.join(os.getcwd(), 'resources/images/pipe-green.png'),
'red': os.path.join(os.getcwd(), 'resources/images/pipe-red.png')
}
OTHER_IMAGE_PATHS = {
'gameover': os.path.join(os.getcwd(), 'resources/images/gameover.png'),
'message': os.path.join(os.getcwd(), 'resources/images/message.png'),
'base': os.path.join(os.getcwd(), 'resources/images/base.png')
}
#媒体环境
AUDIO_PATHS = {
'die': os.path.join(os.getcwd(), 'resources/audios/die.wav'),
'hit': os.path.join(os.getcwd(), 'resources/audios/hit.wav'),
'point': os.path.join(os.getcwd(), 'resources/audios/point.wav'),
'swoosh': os.path.join(os.getcwd(), 'resources/audios/swoosh.wav'),
'wing': os.path.join(os.getcwd(), 'resources/audios/wing.wav')
}
这段代码是一个Python配置文件,用于设置一个游戏的各种参数和资源路径。
-
FPS (Frames Per Second): 设置游戏的帧率为45,这意味着游戏每秒会渲染45帧。
-
屏幕大小: 定义了游戏屏幕的宽度(288像素)和高度(512像素)。
-
管道间隙: 设置了游戏中管道之间的间隙大小为100像素。
-
数字图片路径: 定义了一个字典,包含了数字0到9的图片路径,这些图片可能用于显示分数。
-
鸟类图片路径: 定义了一个字典,包含了不同颜色鸟类(红色、蓝色、黄色)的三种状态(向上、中间、向下)的图片路径。
-
背景图片路径: 定义了一个字典,包含了白天和夜晚背景的图片路径。
-
管道图片路径: 定义了一个字典,包含了绿色和红色管道的图片路径。
-
其他图片路径: 定义了一个字典,包含了游戏结束画面、消息提示和基座的图片路径。
-
音频路径: 定义了一个字典,包含了游戏中可能用到的各种音效文件路径,如死亡、撞击、得分、管道声音和翅膀声音。
代码中使用了os.path.join
和os.getcwd()
来构建资源文件的完整路径,这样可以确保无论代码在哪个目录下运行,都能找到正确的资源文件。这样的配置文件有助于游戏开发过程中对资源的管理和使用。
小鸟类的定义
import pygame
import itertools
'''小鸟的类'''
class Bird(pygame.sprite.Sprite):
def __init__(self, images, idx, position, **kwargs):
pygame.sprite.Sprite.__init__(self)
self.images = images
self.image = list(images.values())[idx]
self.rect = self.image.get_rect()
self.mask = pygame.mask.from_surface(self.image)
self.rect.left, self.rect.top = position
#垂直移动所需的变量
self.is_flapped = False
self.speed = -9
#鸟类状态开关所需的变量
self.bird_idx = idx
self.bird_idx_cycle = itertools.cycle([0, 1, 2, 1])
self.bird_idx_change_count = 0
'''更新小鸟'''
def update(self, boundary_values):
#垂直更新位置
if not self.is_flapped:
self.speed = min(self.speed+1, 10)
self.is_flapped = False
self.rect.top += self.speed
# 确定这只鸟是否是因为撞到了上下边界而死亡
is_dead = False
if self.rect.bottom > boundary_values[1]:
is_dead = True
self.rect.bottom = boundary_values[1]
if self.rect.top < boundary_values[0]:
is_dead = True
self.rect.top = boundary_values[0]
# 模拟翅膀的振动
self.bird_idx_change_count += 1
if self.bird_idx_change_count % 3 == 0:
self.bird_idx = next(self.bird_idx_cycle)
self.image = list(self.images.values())[self.bird_idx]
self.bird_idx_change_count = 0
return is_dead
'''设置飞行模型'''
def setFlapped(self):
self.is_flapped = True
self.speed = -9
'''绑定到屏幕'''
def draw(self, screen):
screen.blit(self.image, self.rect)
这段代码是使用Pygame库编写的,定义了一个名为Bird
的类,用于表示游戏中的小鸟。以下是代码的主要功能和相关知识点:
-
类继承:
Bird
类继承自Pygame的Sprite
类,这使得小鸟可以作为游戏中的一个精灵(sprite)进行管理。 -
初始化方法:
__init__
方法初始化小鸟对象,包括设置小鸟的图片、位置、速度、飞行状态等。 -
图片和动画: 小鸟的图片存储在一个字典
images
中,通过索引idx
选择特定的图片。小鸟的翅膀振动通过在不同图片之间循环切换来模拟,这是通过itertools.cycle
实现的,它创建了一个无限循环的迭代器。 -
位置和速度: 小鸟的垂直位置由
self.rect.top
控制,速度self.speed
用于控制小鸟的上升和下降。每次调用update
方法时,小鸟的位置会根据速度更新。 -
碰撞检测:
update
方法还负责检测小鸟是否撞到了游戏的上下边界。如果小鸟超出了边界,is_dead
变量会被设置为True
,表示小鸟死亡。 -
飞行动作:
setFlapped
方法用于模拟小鸟的飞行动作。当小鸟被点击或按键触发时,会调用此方法,设置小鸟的初始上升速度。 -
绘制小鸟:
draw
方法将小鸟的当前图像绘制到屏幕上。screen.blit
是Pygame中用于将图像绘制到屏幕表面的方法。 -
碰撞处理: 如果小鸟的位置超出了预设的边界值,
update
方法会将is_dead
设置为True
,这可以用来在游戏中触发小鸟死亡的逻辑。 -
相关知识点:
-
Pygame库: Pygame是一个跨平台的Python模块,专门用于编写视频游戏,包括图形和声音库。
-
精灵(Sprite): 在Pygame中,精灵是游戏中的一个图形对象,可以独立于其他对象进行移动和绘制。
-
迭代器(Iterator):
itertools.cycle
创建了一个迭代器,它可以无限循环地返回给定的序列。 -
碰撞检测: 在游戏开发中,碰撞检测是判断两个或多个对象是否接触的重要机制。
-
图像处理: Pygame提供了多种图像处理功能,如加载图像、图像转换、绘制图像等。
-
事件处理: Pygame通过事件循环来处理用户的输入,如键盘按键、鼠标点击等。
这段代码是游戏开发中的一个典型例子,展示了如何使用Pygame库来创建和管理游戏中的对象。
管道类的定义
import random
import pygame
'''管道类'''
class Pipe(pygame.sprite.Sprite):
def __init__(self, image, position, type_, **kwargs):
pygame.sprite.Sprite.__init__(self)
self.image = image
self.rect = self.image.get_rect()
self.mask = pygame.mask.from_surface(self.image)
self.rect.left, self.rect.top = position
self.type_ = type_
self.used_for_score = False
@staticmethod
def randomPipe(cfg, image):
base_y = 0.79 * cfg.SCREENHEIGHT
up_y = int(base_y * 0.2) + random.randrange(0, int(base_y * 0.6 - cfg.PIPE_GAP_SIZE))
return {'top': (cfg.SCREENWIDTH+10, up_y-image.get_height()), 'bottom': (cfg.SCREENWIDTH+10, up_y+cfg.PIPE_GAP_SIZE)}
这段代码定义了一个名为Pipe
的类,用于表示游戏中的管道。以下是代码的主要功能和相关知识点:
-
类继承:
Pipe
类继承自Pygame的Sprite
类,使得管道可以作为游戏中的一个精灵进行管理。 -
初始化方法:
__init__
方法初始化管道对象,包括设置管道的图片、位置、类型等。 -
图片和位置: 管道的图片通过
image
参数传入,位置通过position
参数传入。self.rect
是管道图像的矩形区域,用于碰撞检测和绘制。 -
类型:
self.type_
用于区分管道的类型,例如可能是“上”或“下”管道。 -
分数相关:
self.used_for_score
是一个布尔值,用于标记这个管道是否已经被用来计算分数。 -
随机管道生成:
randomPipe
是一个静态方法,用于生成随机位置的管道。它根据屏幕高度和管道间隙大小来计算上管道和下管道的Y坐标。 -
随机数: 使用
random.randrange
生成一个随机数,用于确定上管道的Y坐标,从而使得管道的位置每次游戏时都是随机的。
这段代码是游戏开发中的一个典型例子,展示了如何使用Pygame库来创建和管理游戏中的静态对象(如管道)。通过随机生成管道的位置,增加了游戏的挑战性和不可预测性。
开始和结束画面的定义
import sys
import pygame
import itertools
'''游戏开始界面'''
def startGame(screen, sounds, bird_images, other_images, backgroud_image, cfg, mode):
base_pos = [0, cfg.SCREENHEIGHT*0.79]
base_diff_bg = other_images['base'].get_width() - backgroud_image.get_width()
msg_pos = [(cfg.SCREENWIDTH-other_images['message'].get_width())/2, cfg.SCREENHEIGHT*0.12]
bird_idx = 0
bird_idx_change_count = 0
bird_idx_cycle = itertools.cycle([0, 1, 2, 1])
bird_pos = [cfg.SCREENWIDTH*0.2, (cfg.SCREENHEIGHT-list(bird_images.values())[0].get_height())/2]
bird_y_shift_count = 0
bird_y_shift_max = 9
shift = 1
clock = pygame.time.Clock()
if mode == 'train':
return {'bird_pos': bird_pos, 'base_pos': base_pos, 'bird_idx': bird_idx}
while True:
for event in pygame.event.get():
if event.type == pygame.QUIT or (event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE):
pygame.quit()
sys.exit()
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_SPACE or event.key == pygame.K_UP:
return {'bird_pos': bird_pos, 'base_pos': base_pos, 'bird_idx': bird_idx}
sounds['wing'].play()
bird_idx_change_count += 1
if bird_idx_change_count % 5 == 0:
bird_idx = next(bird_idx_cycle)
bird_idx_change_count = 0
base_pos[0] = -((-base_pos[0] + 4) % base_diff_bg)
bird_y_shift_count += 1
if bird_y_shift_count == bird_y_shift_max:
bird_y_shift_max = 16
shift = -1 * shift
bird_y_shift_count = 0
bird_pos[-1] = bird_pos[-1] + shift
screen.blit(backgroud_image, (0, 0))
screen.blit(list(bird_images.values())[bird_idx], bird_pos)
screen.blit(other_images['message'], msg_pos)
screen.blit(other_images['base'], base_pos)
pygame.display.update()
clock.tick(cfg.FPS)
import sys
import pygame
'''游戏结束界面'''
def endGame(screen, sounds, showScore, score, number_images, bird, pipe_sprites, backgroud_image, other_images, base_pos, cfg, mode):
if mode == 'train':
return
sounds['die'].play()
clock = pygame.time.Clock()
while True:
for event in pygame.event.get():
if event.type == pygame.QUIT or (event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE):
pygame.quit()
sys.exit()
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_SPACE or event.key == pygame.K_UP:
return
boundary_values = [0, base_pos[-1]]
bird.update(boundary_values)
screen.blit(backgroud_image, (0, 0))
pipe_sprites.draw(screen)
screen.blit(other_images['base'], base_pos)
showScore(screen, score, number_images)
bird.draw(screen)
pygame.display.update()
clock.tick(cfg.FPS)
代码主要定义了一个名为startGame(以及endGame)
的函数,用于处理游戏开始界面的逻辑和显示。以下是代码的主要功能和相关知识点:
-
初始化变量:
base_pos
: 基座(地面)的位置。msg_pos
: 消息(如“按空格开始”)的位置。bird_idx
: 小鸟动画的当前索引。bird_idx_change_count
: 用于控制小鸟动画的计数器。bird_idx_cycle
: 小鸟动画的循环迭代器。bird_pos
: 小鸟的位置。bird_y_shift_count
: 用于控制小鸟上下移动的计数器。bird_y_shift_max
: 小鸟上下移动的最大次数。shift
: 小鸟上下移动的方向。clock
: 用于控制游戏帧率的时钟对象。
-
模式判断:
- 如果
mode
是'train'
,则直接返回小鸟的初始位置和索引。
- 如果
-
事件处理:
- 检查是否有退出事件(如关闭窗口或按下Esc键)。
- 检查是否有按键按下事件,如果按下空格或上箭头,则返回游戏开始的参数。
-
小鸟动画:
- 通过
bird_idx_change_count
控制小鸟动画的切换,每5帧切换一次小鸟的图像。
- 通过
-
背景滚动:
- 通过改变
base_pos[0]
的值来模拟背景滚动的效果。
- 通过改变
-
小鸟上下移动:
- 通过
bird_y_shift_count
和shift
控制小鸟的上下移动,每达到一定次数后改变移动方向。
- 通过
-
绘制界面:
- 将背景、小鸟、消息和基座绘制到屏幕上。
- 使用
pygame.display.update()
更新显示。
-
帧率控制:
- 使用
clock.tick(cfg.FPS)
控制游戏的帧率。
- 使用
相关知识点:
-
Pygame事件处理: Pygame通过事件循环来处理用户输入和系统事件,如键盘按键、鼠标点击等。
-
动画: 通过在不同图像之间切换来实现动画效果。
-
背景滚动: 通过改变背景图像的位置来模拟滚动效果。
-
帧率控制: 使用Pygame的时钟对象来控制游戏的帧率,确保游戏运行的流畅性。
-
图像绘制: 使用
blit
方法将图像绘制到Pygame的屏幕对象上。 -
迭代器: 使用
itertools.cycle
创建一个无限循环的迭代器,用于控制动画的循环。 -
条件判断: 使用
if
语句进行条件判断,如检查模式、处理事件等。
这段代码也是游戏开发中的一个典型例子,结束画面的内容和这个类似。展示了如何使用Pygame库来创建和管理游戏开始界面的逻辑和显示。通过处理用户输入、控制动画和绘制界面,为玩家提供了一个交互式的开始界面。
Q-Learning智能体的构建
这一部分是项目里最为关键的地方
import pickle
import random
import numpy as np
import matplotlib.pyplot as plt
'''q learning 智能体'''
class QLearningAgent():
def __init__(self, mode, **kwargs):
self.mode = mode
# 学习率
self.learning_rate = 0.7
# 折现参数(也叫折现率)
self.discount_factor = 0.95
# 储存必要的历史数据,参数为[previous_state, previous_action, state, reward]
self.history_storage = []
# 储存qvalues,最后一个维度是[value_for_do_nothing, value_for_flappy]
self.qvalues_storage = np.zeros((130, 130, 20, 2))
# 储存每集的分数
self.scores_storage = []
# 之前的状态
self.previous_state = []
# 0表示没有任何动作,1表示飞行
self.previous_action = 0
# 集的数量
self.num_episode = 0
# 到目前的最高得分
self.max_score = 0
'''做出决定'''
def act(self, delta_x, delta_y, bird_speed):
if not self.previous_state:
self.previous_state = [delta_x, delta_y, bird_speed]
return self.previous_action
if self.mode == 'train':
state = [delta_x, delta_y, bird_speed]
self.history_storage.append([self.previous_state, self.previous_action, state, 0])
self.previous_state = state
# 根据qvalues做出一个决定
if self.qvalues_storage[delta_x, delta_y, bird_speed][0] >= self.qvalues_storage[delta_x, delta_y, bird_speed][1]:
self.previous_action = 0
else:
self.previous_action = 1
return self.previous_action
'''设定奖赏'''
def setReward(self, reward):
if self.history_storage:
self.history_storage[-1][3] = reward
'''每一集之后更新qvalues_storage'''
def update(self, score, is_logging=True):
self.num_episode += 1
self.max_score = max(self.max_score, score)
self.scores_storage.append(score)
if is_logging:
print('Episode: %s, Score: %s, Max Score: %s' % (self.num_episode, score, self.max_score))
if self.mode == 'train':
history = list(reversed(self.history_storage))
# 惩罚碰撞前的最后一个num_penalization状态
num_penalization = 2
for item in history:
previous_state, previous_action, state, reward = item
if num_penalization > 0:
num_penalization -= 1
reward = -1000000
x_0, y_0, z_0 = previous_state
x_1, y_1, z_1 = state
self.qvalues_storage[x_0, y_0, z_0, previous_action] = (1 - self.learning_rate) * self.qvalues_storage[x_0, y_0, z_0, previous_action] +\
self.learning_rate * (reward + self.discount_factor * max(self.qvalues_storage[x_1, y_1, z_1]))
self.history_storage = []
'''保存模型'''
def saveModel(self, modelpath):
data = {
'num_episode': self.num_episode,
'max_score': self.max_score,
'scores_storage': self.scores_storage,
'qvalues_storage': self.qvalues_storage
}
with open(modelpath, 'wb') as f:
pickle.dump(data, f)
print('[INFO]: save checkpoints in %s...' % modelpath)
'''加载模型'''
def loadModel(self, modelpath):
print('[INFO]: load checkpoints from %s...' % modelpath)
with open(modelpath, 'rb') as f:
data = pickle.load(f)
self.num_episode = data.get('num_episode')
self.qvalues_storage = data.get('qvalues_storage')
'''带有ε-greedy策略的qlearning智能体'''
class QLearningGreedyAgent(QLearningAgent):
def __init__(self, mode, **kwargs):
super(QLearningGreedyAgent, self).__init__(mode, **kwargs)
self.epsilon = 0.1
self.epsilon_end = 0.0
self.epsilon_decay = 1e-5
'''做出决定'''
def act(self, delta_x, delta_y, bird_speed):
if not self.previous_state:
self.previous_state = [delta_x, delta_y, bird_speed]
return self.previous_action
if self.mode == 'train':
state = [delta_x, delta_y, bird_speed]
self.history_storage.append([self.previous_state, self.previous_action, state, 0])
self.previous_state = state
# 贪心策略
if random.random() <= self.epsilon:
self.previous_action = random.choice([0, 1])
else:
if self.qvalues_storage[delta_x, delta_y, bird_speed][0] >= self.qvalues_storage[delta_x, delta_y, bird_speed][1]:
self.previous_action = 0
else:
self.previous_action = 1
return self.previous_action
else:
super().act(delta_x, delta_y, bird_speed)
'''每一集之后更新qvalues_storage'''
def update(self, score, is_logging=True):
self.num_episode += 1
self.max_score = max(self.max_score, score)
self.scores_storage.append(score)
if is_logging:
print('Episode: %s, Epsilon: %s, Score: %s, Max Score: %s' % (self.num_episode, self.epsilon, score, self.max_score))
if self.mode == 'train':
history = list(reversed(self.history_storage))
# 惩罚崩溃前的最后一个num_penalization状态
num_penalization = 2
for item in history:
previous_state, previous_action, state, reward = item
if num_penalization > 0:
num_penalization -= 1
reward = -1000000
x_0, y_0, z_0 = previous_state
x_1, y_1, z_1 = state
self.qvalues_storage[x_0, y_0, z_0, previous_action] = (1 - self.learning_rate) * self.qvalues_storage[x_0, y_0, z_0, previous_action] +\
self.learning_rate * (reward + self.discount_factor * max(self.qvalues_storage[x_1, y_1, z_1]))
self.history_storage = []
if self.epsilon > self.epsilon_end:
self.epsilon -= self.epsilon_decay
'''保存模型'''
def saveModel(self, modelpath):
data = {
'num_episode': self.num_episode,
'max_score': self.max_score,
'scores_storage': self.scores_storage,
'qvalues_storage': self.qvalues_storage,
'epsilon': self.epsilon
}
with open(modelpath, 'wb') as f:
pickle.dump(data, f)
print('[INFO]: save checkpoints in %s...' % modelpath)
'''加载模型'''
def loadModel(self, modelpath):
print('[INFO]: load checkpoints from %s...' % modelpath)
with open(modelpath, 'rb') as f:
data = pickle.load(f)
self.num_episode = data.get('num_episode')
self.qvalues_storage = data.get('qvalues_storage')
self.epsilon = data.get('epsilon')
这段代码定义了两个类,QLearningAgent
和QLearningGreedyAgent
,它们都是使用Q学习算法来训练智能体在特定环境中(例如这个Flappy Bird游戏)进行决策的。
QLearningAgent
类
- 初始化: 设置学习率、折现因子、历史存储、Q值存储、分数存储等。
- 决策: 根据当前状态和Q值表来选择动作(0表示不操作,1表示操作)。
- 设置奖赏: 更新历史记录中的奖赏值。
- 更新Q值: 在每一集结束后,根据历史记录更新Q值表。
- 保存和加载模型: 使用pickle库保存和加载训练模型。
QLearningGreedyAgent
类
- 初始化: 继承
QLearningAgent
,并添加ε-greedy策略相关的参数,如ε(探索率)、ε_end(最小探索率)、ε_decay(探索率衰减)。 - 决策: 引入ε-greedy策略,以一定的概率随机选择动作,以一定的概率选择当前最优动作。
- 更新Q值: 与
QLearningAgent
类似,但增加了ε的更新逻辑。 - 保存和加载模型: 除了保存
QLearningAgent
的参数外,还保存和加载ε值。
相关知识点
- Q学习: 一种无模型的强化学习算法,通过与环境的交互来学习最优策略。
- ε-greedy策略: 在探索(随机选择动作)和利用(选择当前最优动作)之间进行权衡的策略。
- 学习率: 用于控制新学到的知识对原有知识的影响程度。
- 折现因子: 用于衡量未来奖励相对于即时奖励的重要性。
- 奖赏: 智能体在执行动作后从环境中获得的反馈。
- 状态: 智能体在环境中的当前情况。
- 动作: 智能体在给定状态下可以执行的操作。
- Q值: 表示在特定状态下执行特定动作的预期效用。
- 历史记录: 用于存储智能体的状态、动作、奖赏等信息,以便更新Q值。
- 模型保存和加载: 使用pickle库来保存训练过程中的参数,以便后续继续训练或评估。
这段代码展示了如何使用Q学习算法来训练一个智能体在特定环境中进行决策,并通过ε-greedy策略来平衡探索和利用。这就是机器学习中强化学习领域的一个典型应用。
主函数
最后呢,当然就是调用所有函数的主函数啦!!!也就是需要大家运行的主函数。
import os
import cfg
import sys
import random
import pygame
import argparse
from modules.sprites.Pipe import *
from modules.sprites.Bird import *
from modules.interfaces.endGame import *
from modules.interfaces.startGame import *
from modules.QLearningAgent.QLearningAgent import *
import matplotlib.pyplot as plt
'''解析参数'''
def parseArgs():
parser = argparse.ArgumentParser(description='Use q learning to play flappybird')
parser.add_argument('--mode', dest='mode', help='选择 <train> or <test> please', default='train', type=str)
parser.add_argument('--policy', dest='policy', help='选择 <plain> or <greedy> ', default='greedy', type=str)
args = parser.parse_args()
return args
'''初始化游戏'''
def initGame():
pygame.init()
pygame.mixer.init()
screen = pygame.display.set_mode((cfg.SCREENWIDTH, cfg.SCREENHEIGHT))
pygame.display.set_caption('Flappy Bird小游戏')
return screen
'''游戏画面显示'''
def showScore(screen, score, number_images):
digits = list(str(int(score)))
width = 0
for d in digits:
width += number_images.get(d).get_width()
offset = (cfg.SCREENWIDTH - width) / 2
for d in digits:
screen.blit(number_images.get(d), (offset, cfg.SCREENHEIGHT*0.1))
offset += number_images.get(d).get_width()
'''主函数定义'''
def main(mode, policy, agent, modelpath):
screen = initGame()
# 加载必要的游戏资源
# 加载游戏声音
sounds = dict()
for key, value in cfg.AUDIO_PATHS.items():
sounds[key] = pygame.mixer.Sound(value)
# 加载分数数字图片
number_images = dict()
for key, value in cfg.NUMBER_IMAGE_PATHS.items():
number_images[key] = pygame.image.load(value).convert_alpha()
# 管道
pipe_images = dict()
pipe_images['bottom'] = pygame.image.load(random.choice(list(cfg.PIPE_IMAGE_PATHS.values()))).convert_alpha()
pipe_images['top'] = pygame.transform.rotate(pipe_images['bottom'], 180)
# 小鸟图片
bird_images = dict()
for key, value in cfg.BIRD_IMAGE_PATHS[random.choice(list(cfg.BIRD_IMAGE_PATHS.keys()))].items():
bird_images[key] = pygame.image.load(value).convert_alpha()
# 背景图片
backgroud_image = pygame.image.load(random.choice(list(cfg.BACKGROUND_IMAGE_PATHS.values()))).convert_alpha()
# 其他图片
other_images = dict()
for key, value in cfg.OTHER_IMAGE_PATHS.items():
other_images[key] = pygame.image.load(value).convert_alpha()
# 游戏启动界面
game_start_info = startGame(screen, sounds, bird_images, other_images, backgroud_image, cfg, mode)
# 进入游戏主循环
score = 0
bird_pos, base_pos, bird_idx = list(game_start_info.values())
base_diff_bg = other_images['base'].get_width() - backgroud_image.get_width()
clock = pygame.time.Clock()
# 管道的实例化类
pipe_sprites = pygame.sprite.Group()
for i in range(2):
pipe_pos = Pipe.randomPipe(cfg, pipe_images.get('top'))
pipe_sprites.add(Pipe(image=pipe_images.get('top'),
position=(cfg.SCREENWIDTH+200+i*cfg.SCREENWIDTH/2,
pipe_pos.get('top')[-1]), type_='top'))
pipe_sprites.add(Pipe(image=pipe_images.get('bottom'),
position=(cfg.SCREENWIDTH+200+i*cfg.SCREENWIDTH/2,
pipe_pos.get('bottom')[-1]), type_='bottom'))
# 小鸟的实例化类
bird = Bird(images=bird_images, idx=bird_idx, position=bird_pos)
# 是否添加管道
is_add_pipe = True
# 是否运行游戏
is_game_running = True
while is_game_running:
for event in pygame.event.get():
if event.type == pygame.QUIT or (event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE):
if mode == 'train': agent.saveModel(modelpath)
pygame.quit()
sys.exit()
# 使用强化学习算法玩游戏
delta_x = 10000
delta_y = 10000
for pipe in pipe_sprites:
if pipe.type_ == 'bottom' and (pipe.rect.left-bird.rect.left+30) > 0:
if pipe.rect.right - bird.rect.left < delta_x:
delta_x = pipe.rect.left - bird.rect.left
delta_y = pipe.rect.top - bird.rect.top
delta_x = int((delta_x + 60) / 5)
delta_y = int((delta_y + 225) / 5)
if agent.act(delta_x, delta_y, int(bird.speed+9)):
bird.setFlapped()
sounds['wing'].play()
# --检查鸟和管子之间的碰撞
for pipe in pipe_sprites:
if pygame.sprite.collide_mask(bird, pipe):
sounds['hit'].play()
is_game_running = False
# --更新小鸟
boundary_values = [0, base_pos[-1]]
is_dead = bird.update(boundary_values)
if is_dead:
sounds['hit'].play()
is_game_running = False
# --如果游戏结束,存储参数
if not is_game_running:
agent.update(score, True) if mode == 'train' else agent.update(score, False)
# --向左移动背景实现小鸟飞行的效果
base_pos[0] = -((-base_pos[0] + 4) % base_diff_bg)
# --向左移动管道实现小鸟飞行的效果
flag = False
reward = 1
for pipe in pipe_sprites:
pipe.rect.left -= 4
if pipe.rect.centerx <= bird.rect.centerx and not pipe.used_for_score:
pipe.used_for_score = True
score += 0.5
reward = 5
if '.5' in str(score):
sounds['point'].play()
if pipe.rect.left < 5 and pipe.rect.left > 0 and is_add_pipe:
pipe_pos = Pipe.randomPipe(cfg, pipe_images.get('top'))
pipe_sprites.add(Pipe(image=pipe_images.get('top'), position=pipe_pos.get('top'), type_='top'))
pipe_sprites.add(Pipe(image=pipe_images.get('bottom'), position=pipe_pos.get('bottom'), type_='bottom'))
is_add_pipe = False
elif pipe.rect.right < 0:
pipe_sprites.remove(pipe)
flag = True
if flag: is_add_pipe = True
# --设定反馈
if mode == 'train' and is_game_running:
agent.setReward(reward)
# --绑定必要的游戏元素在屏幕上
screen.blit(backgroud_image, (0, 0))
pipe_sprites.draw(screen)
screen.blit(other_images['base'], base_pos)
showScore(screen, score, number_images)
bird.draw(screen)
pygame.display.update()
clock.tick(cfg.FPS)
# 游戏结束界面
endGame(screen, sounds, showScore, score, number_images, bird, pipe_sprites, backgroud_image, other_images, base_pos, cfg, mode)
'''运行'''
if __name__ == '__main__':
# 解析命令行中的参数
args = parseArgs()
mode = args.mode.lower()
policy = args.policy.lower()
assert mode in ['train', 'test'], '--mode should be <train> or <test>'
assert policy in ['plain', 'greedy'], '--policy should be <plain> or <greedy>'
# 强化学习的实例化类, 并且保存路径和加载模型
if not os.path.exists('checkpoints'):
os.mkdir('checkpoints')
agent = QLearningAgent(mode) if policy == 'plain' else QLearningGreedyAgent(mode)
modelpath = 'checkpoints/qlearning_%s.pkl' % policy
if policy == 'greedy':
modelpath = 'checkpoints/qlearning_%s.pkl' % policy
if os.path.isfile(modelpath):
agent.loadModel(modelpath)
# 开始游戏
while True:
main(mode, policy, agent, modelpath)
这段代码是最后完整的Flappy Bird游戏实现,它就是使用了前面介绍过的Q学习算法来训练一个智能体自动玩这个游戏。代码分为几个主要部分:参数解析、游戏初始化、游戏画面显示、主函数定义和运行。
参数解析 (parseArgs
)
- 使用
argparse
库来解析命令行参数,包括模式(训练或测试)和策略(普通或ε-greedy)。
游戏初始化 (initGame
)
- 初始化Pygame,设置游戏窗口和标题。
游戏画面显示 (showScore
)
- 在屏幕上显示当前分数。
主函数定义 (main
)
- 初始化游戏资源,包括声音、图片、小鸟、管道等。
- 根据选择的模式(训练或测试)和策略(普通或ε-greedy)初始化智能体。
- 进入游戏主循环,处理事件、更新游戏状态、检测碰撞、更新分数、绘制游戏画面等。
- 游戏结束后,显示结束界面。
运行
- 检查并创建模型保存目录。
- 根据命令行参数初始化智能体,并加载模型(如果存在)。
- 调用
main
函数开始游戏。
总结
好的,这个项目的主要部分呢其实就总结的差不多了,我也不知道该不该这么总结一个项目,总之,有什么意见或者建议以及大家个人的一些有趣的想法,都可以和我交流,希望大家可以完美地运行这个项目哦!!!