Playing the Chrome Dino Runner with Deep Reinforcement Learning

Table of Contents

Live Demo

Code Implementation


Live Demo

(Demo video: Playing the Chrome Dino Runner with Deep Reinforcement Learning)

Code Implementation
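
The implementation below is split across four files: the game environment (src/env.py), the Q-network (src/model.py), a test script, and a training script. First, src/env.py, which wraps a pygame clone of the T-Rex game as a reinforcement-learning environment: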

import os
import cv2
from pygame import RLEACCEL
from pygame.image import load
from pygame.sprite import Sprite, Group, collide_mask
from pygame import Rect, init, time, display, mixer, transform, Surface
from pygame.surfarray import array3d
import torch
from random import randrange, choice
import numpy as np

mixer.pre_init(44100, -16, 2, 2048)
init()

scr_size = (width, height) = (600, 150)
FPS = 60
gravity = 0.6

black = (0, 0, 0)
white = (255, 255, 255)
background_col = (235, 235, 235)

high_score = 0

screen = display.set_mode(scr_size)
clock = time.Clock()
display.set_caption("T-Rex Rush")


def load_image(
        name,
        sizex=-1,
        sizey=-1,
        colorkey=None,
):
    fullname = os.path.join("assets/sprites", name)
    image = load(fullname)
    image = image.convert()
    if colorkey is not None:
        if colorkey == -1:
            colorkey = image.get_at((0, 0))
        image.set_colorkey(colorkey, RLEACCEL)

    if sizex != -1 or sizey != -1:
        image = transform.scale(image, (sizex, sizey))

    return (image, image.get_rect())


def load_sprite_sheet(
        sheetname,
        nx,
        ny,
        scalex=-1,
        scaley=-1,
        colorkey=None,
):
    fullname = os.path.join("assets/sprites", sheetname)
    sheet = load(fullname)
    sheet = sheet.convert()

    sheet_rect = sheet.get_rect()

    sprites = []

    sizey = sheet_rect.height / ny
    if isinstance(nx, int):
        sizex = sheet_rect.width / nx
        for i in range(0, ny):
            for j in range(0, nx):
                rect = Rect((j * sizex, i * sizey, sizex, sizey))
                image = Surface(rect.size)
                image = image.convert()
                image.blit(sheet, (0, 0), rect)

                if colorkey is not None:
                    if colorkey == -1:
                        colorkey = image.get_at((0, 0))
                    image.set_colorkey(colorkey, RLEACCEL)

                if scalex != -1 or scaley != -1:
                    image = transform.scale(image, (scalex, scaley))

                sprites.append(image)

    else:  # nx is a list: slice the sheet at several different sprite widths
        sizex_ls = [sheet_rect.width / i_nx for i_nx in nx]
        for i in range(0, ny):
            for i_nx, sizex, i_scalex in zip(nx, sizex_ls, scalex):
                for j in range(0, i_nx):
                    rect = Rect((j * sizex, i * sizey, sizex, sizey))
                    image = Surface(rect.size)
                    image = image.convert()
                    image.blit(sheet, (0, 0), rect)

                    if colorkey is not None:
                        if colorkey == -1:
                            colorkey = image.get_at((0, 0))
                        image.set_colorkey(colorkey, RLEACCEL)

                    if i_scalex != -1 or scaley != -1:
                        image = transform.scale(image, (i_scalex, scaley))

                    sprites.append(image)

    sprite_rect = sprites[0].get_rect()

    return sprites, sprite_rect


def extractDigits(number):
    # Split a non-negative score into a list of at least 5 digits, most significant first
    if number > -1:
        digits = []
        while number // 10 != 0:
            digits.append(number % 10)
            number = number // 10

        digits.append(number % 10)
        for i in range(len(digits), 5):
            digits.append(0)
        digits.reverse()
        return digits


def pre_processing(image, w=84, h=84):
    # The frame comes from pygame.surfarray.array3d, so its shape is (width, height, 3);
    # keep only the leftmost 300 pixel columns, where the dino and the nearest obstacles are
    image = image[:300, :, :]
    # cv2.imwrite("ori.jpg", image)
    image = cv2.cvtColor(cv2.resize(image, (w, h)), cv2.COLOR_BGR2GRAY)
    # cv2.imwrite("color.jpg", image)
    # Binarize so the network sees a clean black-and-white silhouette
    _, image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
    # cv2.imwrite("bw.jpg", image)

    return image[None, :, :].astype(np.float32)


class Dino():
    def __init__(self, sizex=-1, sizey=-1):
        self.images, self.rect = load_sprite_sheet("dino.png", 5, 1, sizex, sizey, -1)
        self.images1, self.rect1 = load_sprite_sheet("dino_ducking.png", 2, 1, 59, sizey, -1)
        self.rect.bottom = int(0.98 * height)
        self.rect.left = width / 15
        self.image = self.images[0]
        self.index = 0
        self.counter = 0
        self.score = 0
        self.isJumping = False
        self.isDead = False
        self.isDucking = False
        self.isBlinking = False
        self.movement = [0, 0]
        self.jumpSpeed = 11.5

        self.stand_pos_width = self.rect.width
        self.duck_pos_width = self.rect1.width

    def draw(self):
        screen.blit(self.image, self.rect)

    def checkbounds(self):
        if self.rect.bottom > int(0.98 * height):
            self.rect.bottom = int(0.98 * height)
            self.isJumping = False

    def update(self):
        if self.isJumping:
            self.movement[1] = self.movement[1] + gravity

        if self.isJumping:
            self.index = 0
        elif self.isBlinking:
            if self.index == 0:
                if self.counter % 400 == 399:
                    self.index = (self.index + 1) % 2
            else:
                if self.counter % 20 == 19:
                    self.index = (self.index + 1) % 2

        elif self.isDucking:
            if self.counter % 5 == 0:
                self.index = (self.index + 1) % 2
        else:
            if self.counter % 5 == 0:
                self.index = (self.index + 1) % 2 + 2

        if self.isDead:
            self.index = 4

        if not self.isDucking:
            self.image = self.images[self.index]
            self.rect.width = self.stand_pos_width
        else:
            self.image = self.images1[(self.index) % 2]
            self.rect.width = self.duck_pos_width

        self.rect = self.rect.move(self.movement)
        self.checkbounds()

        if not self.isDead and self.counter % 7 == 6 and self.isBlinking == False:
            self.score += 1

        self.counter = (self.counter + 1)


class Cactus(Sprite):
    def __init__(self, speed=5, sizex=-1, sizey=-1):
        Sprite.__init__(self, self.containers)
        self.images, self.rect = load_sprite_sheet("cacti-small.png", [2, 3, 6], 1, sizex, sizey, -1)
        self.rect.bottom = int(0.98 * height)
        self.rect.left = width + self.rect.width
        self.image = self.images[randrange(0, 11)]
        self.movement = [-1 * speed, 0]

    def draw(self):
        screen.blit(self.image, self.rect)

    def update(self):
        self.rect = self.rect.move(self.movement)

        if self.rect.right < 0:
            self.kill()


class Ptera(Sprite):
    def __init__(self, speed=5, sizex=-1, sizey=-1):
        Sprite.__init__(self, self.containers)
        self.images, self.rect = load_sprite_sheet("ptera.png", 2, 1, sizex, sizey, -1)
        self.ptera_height = [height * 0.82, height * 0.75, height * 0.60, height * 0.48]
        self.rect.centery = self.ptera_height[randrange(0, 4)]
        self.rect.left = width + self.rect.width
        self.image = self.images[0]
        self.movement = [-1 * speed, 0]
        self.index = 0
        self.counter = 0

    def draw(self):
        screen.blit(self.image, self.rect)

    def update(self):
        if self.counter % 10 == 0:
            self.index = (self.index + 1) % 2
        self.image = self.images[self.index]
        self.rect = self.rect.move(self.movement)
        self.counter = (self.counter + 1)
        if self.rect.right < 0:
            self.kill()


class Ground():
    def __init__(self, speed=-5):
        self.image, self.rect = load_image("ground.png", -1, -1, -1)
        self.image1, self.rect1 = load_image("ground.png", -1, -1, -1)
        self.rect.bottom = height
        self.rect1.bottom = height
        self.rect1.left = self.rect.right
        self.speed = speed

    def draw(self):
        screen.blit(self.image, self.rect)
        screen.blit(self.image1, self.rect1)

    def update(self):
        self.rect.left += self.speed
        self.rect1.left += self.speed

        if self.rect.right < 0:
            self.rect.left = self.rect1.right

        if self.rect1.right < 0:
            self.rect1.left = self.rect.right


class Cloud(Sprite):
    def __init__(self, x, y):
        Sprite.__init__(self, self.containers)
        self.image, self.rect = load_image("cloud.png", int(90 * 30 / 42), 30, -1)
        self.speed = 1
        self.rect.left = x
        self.rect.top = y
        self.movement = [-1 * self.speed, 0]

    def draw(self):
        screen.blit(self.image, self.rect)

    def update(self):
        self.rect = self.rect.move(self.movement)
        if self.rect.right < 0:
            self.kill()


class Scoreboard():
    def __init__(self, x=-1, y=-1):
        self.score = 0
        self.tempimages, self.temprect = load_sprite_sheet("numbers.png", 12, 1, 11, int(11 * 6 / 5), -1)
        self.image = Surface((55, int(11 * 6 / 5)))
        self.rect = self.image.get_rect()
        if x == -1:
            self.rect.left = width * 0.89
        else:
            self.rect.left = x
        if y == -1:
            self.rect.top = height * 0.1
        else:
            self.rect.top = y

    def draw(self):
        screen.blit(self.image, self.rect)

    def update(self, score):
        score_digits = extractDigits(score)
        self.image.fill(background_col)
        if len(score_digits) == 6:
            score_digits = score_digits[1:]
        for s in score_digits:
            self.image.blit(self.tempimages[s], self.temprect)
            self.temprect.left += self.temprect.width
        self.temprect.left = 0


class ChromeDino(object):
    def __init__(self):
        self.gamespeed = 5
        self.gameOver = False
        self.gameQuit = False
        self.playerDino = Dino(44, 47)
        self.new_ground = Ground(-1 * self.gamespeed)
        self.scb = Scoreboard()
        self.highsc = Scoreboard(width * 0.78)
        self.counter = 0

        self.cacti = Group()
        self.pteras = Group()
        self.clouds = Group()
        self.last_obstacle = Group()

        Cactus.containers = self.cacti
        Ptera.containers = self.pteras
        Cloud.containers = self.clouds

        self.retbutton_image, self.retbutton_rect = load_image("replay_button.png", 35, 31, -1)
        self.gameover_image, self.gameover_rect = load_image("game_over.png", 190, 11, -1)

        self.temp_images, self.temp_rect = load_sprite_sheet("numbers.png", 12, 1, 11, int(11 * 6 / 5), -1)
        self.HI_image = Surface((22, int(11 * 6 / 5)))
        self.HI_rect = self.HI_image.get_rect()
        self.HI_image.fill(background_col)
        self.HI_image.blit(self.temp_images[10], self.temp_rect)
        self.temp_rect.left += self.temp_rect.width
        self.HI_image.blit(self.temp_images[11], self.temp_rect)
        self.HI_rect.top = height * 0.1
        self.HI_rect.left = width * 0.73

    def step(self, action, record=False):  # 0: Do nothing. 1: Jump. 2: Duck
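        # Reward scheme (as implemented below): a small living reward of 0.1,
        # a +0.01 bonus for doing nothing, +1 when the dino has just cleared an
        # obstacle, and -1 on collision; the step is treated as terminal
        # whenever the returned reward is not positive.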
        reward = 0.1
        if action == 0:
            reward += 0.01
            self.playerDino.isDucking = False
        elif action == 1:
            self.playerDino.isDucking = False
            if self.playerDino.rect.bottom == int(0.98 * height):
                self.playerDino.isJumping = True
                self.playerDino.movement[1] = -1 * self.playerDino.jumpSpeed

        elif action == 2:
            if not (self.playerDino.isJumping and self.playerDino.isDead) and self.playerDino.rect.bottom == int(
                    0.98 * height):
                self.playerDino.isDucking = True

        for c in self.cacti:
            c.movement[0] = -1 * self.gamespeed
            if collide_mask(self.playerDino, c):
                self.playerDino.isDead = True
                reward = -1
                break
            else:
                if c.rect.right < self.playerDino.rect.left < c.rect.right + self.gamespeed + 1:
                    reward = 1
                    break

        for p in self.pteras:
            p.movement[0] = -1 * self.gamespeed
            if collide_mask(self.playerDino, p):
                self.playerDino.isDead = True
                reward = -1
                break
            else:
                if p.rect.right < self.playerDino.rect.left < p.rect.right + self.gamespeed + 1:
                    reward = 1
                    break

        if len(self.cacti) < 2:
            if len(self.cacti) == 0 and len(self.pteras) == 0:
                self.last_obstacle.empty()
                self.last_obstacle.add(Cactus(self.gamespeed, [60, 40, 20], choice([40, 45, 50])))
            else:
                for l in self.last_obstacle:
                    if l.rect.right < width * 0.7 and randrange(0, 50) == 10:
                        self.last_obstacle.empty()
                        self.last_obstacle.add(Cactus(self.gamespeed, [60, 40, 20], choice([40, 45, 50])))

        # if len(self.pteras) == 0 and randrange(0, 200) == 10 and self.counter > 500:
        if len(self.pteras) == 0 and len(self.cacti) < 2 and randrange(0, 50) == 10 and self.counter > 500:
            for l in self.last_obstacle:
                if l.rect.right < width * 0.8:
                    self.last_obstacle.empty()
                    self.last_obstacle.add(Ptera(self.gamespeed, 46, 40))

        if len(self.clouds) < 5 and randrange(0, 300) == 10:
            Cloud(width, randrange(height // 5, height // 2))  # randrange needs integer bounds in Python 3

        self.playerDino.update()
        self.cacti.update()
        self.pteras.update()
        self.clouds.update()
        self.new_ground.update()
        self.scb.update(self.playerDino.score)

        state = display.get_surface()
        screen.fill(background_col)
        self.new_ground.draw()
        self.clouds.draw(screen)
        self.scb.draw()
        self.cacti.draw(screen)
        self.pteras.draw(screen)
        self.playerDino.draw()

        display.update()
        clock.tick(FPS)

        if self.playerDino.isDead:
            self.gameOver = True

        self.counter = (self.counter + 1)

        if self.gameOver:
            self.__init__()

        state = array3d(state)
        if record:
            return torch.from_numpy(pre_processing(state)), np.transpose(
                cv2.cvtColor(state, cv2.COLOR_RGB2BGR), (1, 0, 2)), reward, not (reward > 0)
        else:
            return torch.from_numpy(pre_processing(state)), reward, not (reward > 0)
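
Next, the Q-network (src/model.py): three convolutional layers followed by two fully connected layers. It takes a stack of four preprocessed 84x84 frames and outputs one Q-value per action (do nothing, jump, duck).
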
import torch.nn as nn

class DeepQNetwork(nn.Module):
    def __init__(self):
        super(DeepQNetwork, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))

        self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))
        self.fc2 = nn.Linear(512, 3)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.uniform_(m.weight, -0.01, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, input):
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.conv3(output)
        output = output.view(output.size(0), -1)
        output = self.fc1(output)
        output = self.fc2(output)

        return output
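
As a quick sanity check of the tensor shapes (a hypothetical snippet, not part of the original files), the network can be fed a dummy stack of four 84x84 frames:

if __name__ == "__main__":
    import torch  # only needed for this standalone check
    net = DeepQNetwork()
    dummy = torch.zeros(1, 4, 84, 84)  # a batch of one 4-frame stack
    print(net(dummy).shape)  # expected: torch.Size([1, 3]), one Q-value per action

The test script below imports the environment and network from src.env and src.model, loads a trained checkpoint, plays greedily with the learned Q-values, and records the run to a video file:
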
import argparse
import torch

from src.model import DeepQNetwork
from src.env import ChromeDino
import cv2


def get_args():
    parser = argparse.ArgumentParser(
        """Implementation of Deep Q Network to play Chrome Dino""")
    parser.add_argument("--saved_path", type=str, default="trained_models")
    parser.add_argument("--fps", type=int, default=60, help="frames per second")
    parser.add_argument("--output", type=str, default="output/chrome_dino.mp4", help="the path to output video")

    args = parser.parse_args()
    return args


def q_test(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    model = DeepQNetwork()
    checkpoint_path = "{}/chrome_dino.pth".format(opt.saved_path)
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    env = ChromeDino()
    state, raw_state, _, _ = env.step(0, True)
    state = torch.cat(tuple(state for _ in range(4)))[None, :, :, :]
    if torch.cuda.is_available():
        model.cuda()
        state = state.cuda()
    out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), opt.fps, (600, 150))
    done = False
    while not done:
        prediction = model(state)[0]
        action = torch.argmax(prediction).item()
        next_state, raw_next_state, reward, done = env.step(action, True)
        out.write(raw_next_state)
        if torch.cuda.is_available():
            next_state = next_state.cuda()
        next_state = torch.cat((state[0, 1:, :, :], next_state))[None, :, :, :]
        state = next_state

    # finalize the output video file
    out.release()


if __name__ == "__main__":
    opt = get_args()
    q_test(opt)
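
Finally, the training script: a standard deep Q-learning loop with epsilon-greedy exploration, an experience-replay buffer, and periodic checkpointing of both the model weights and the replay memory:
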
import argparse
import os
from random import random, randint, sample
import pickle
import numpy as np
import torch
import torch.nn as nn

from src.model import DeepQNetwork
from src.env import ChromeDino


def get_args():
    parser = argparse.ArgumentParser(
        """Implementation of Deep Q Network to play Chrome Dino""")
    parser.add_argument("--batch_size", type=int, default=64, help="The number of images per batch")
    parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="adam")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--initial_epsilon", type=float, default=0.1)
    parser.add_argument("--final_epsilon", type=float, default=1e-4)
    parser.add_argument("--num_decay_iters", type=float, default=2000000)
    parser.add_argument("--num_iters", type=int, default=2000000)
    parser.add_argument("--replay_memory_size", type=int, default=50000,
                        help="Number of epoches between testing phases")
    parser.add_argument("--saved_folder", type=str, default="trained_models")

    args = parser.parse_args()
    return args


def train(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    model = DeepQNetwork()
    if torch.cuda.is_available():
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    if not os.path.isdir(opt.saved_folder):
        os.makedirs(opt.saved_folder)
    checkpoint_path = os.path.join(opt.saved_folder, "chrome_dino.pth")
    memory_path = os.path.join(opt.saved_folder, "replay_memory.pkl")
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        iter = checkpoint["iter"] + 1
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        print("Load trained model from iteration {}".format(iter))
    else:
        iter = 0
    if os.path.isfile(memory_path):
        with open(memory_path, "rb") as f:
            replay_memory = pickle.load(f)
        print("Load replay memory")
    else:
        replay_memory = []
    criterion = nn.MSELoss()
    env = ChromeDino()
    state, _, _ = env.step(0)
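    # Build the initial state by stacking the first frame four times (the DQN expects 4 input channels)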
    state = torch.cat(tuple(state for _ in range(4)))[None, :, :, :]
    while iter < opt.num_iters:
        if torch.cuda.is_available():
            prediction = model(state.cuda())[0]
        else:
            prediction = model(state)[0]
        # Exploration vs. exploitation: epsilon is annealed linearly from
        # initial_epsilon down to final_epsilon over the first num_decay_iters iterations
        epsilon = opt.final_epsilon + (
                max(opt.num_decay_iters - iter, 0) * (opt.initial_epsilon - opt.final_epsilon) / opt.num_decay_iters)
        u = random()
        random_action = u <= epsilon
        if random_action:
            action = randint(0, 2)
        else:
            action = torch.argmax(prediction).item()

        next_state, reward, done = env.step(action)
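        # Roll the frame stack: drop the oldest of the 4 frames and append the new one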
        next_state = torch.cat((state[0, 1:, :, :], next_state))[None, :, :, :]
        replay_memory.append([state, action, reward, next_state, done])
        if len(replay_memory) > opt.replay_memory_size:
            del replay_memory[0]
        batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*batch)

        state_batch = torch.cat(tuple(state for state in state_batch))
        action_batch = torch.from_numpy(
            np.array([[1, 0, 0] if action == 0 else [0, 1, 0] if action == 1 else [0, 0, 1] for action in
                      action_batch], dtype=np.float32))
        reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])
        next_state_batch = torch.cat(tuple(state for state in next_state_batch))

        if torch.cuda.is_available():
            state_batch = state_batch.cuda()
            action_batch = action_batch.cuda()
            reward_batch = reward_batch.cuda()
            next_state_batch = next_state_batch.cuda()
        current_prediction_batch = model(state_batch)
        next_prediction_batch = model(next_state_batch)

        # Q-learning target: y = r for terminal transitions,
        # otherwise y = r + gamma * max_a' Q(s', a')
        y_batch = torch.cat(
            tuple(reward if done else reward + opt.gamma * torch.max(prediction) for reward, done, prediction in
                  zip(reward_batch, done_batch, next_prediction_batch)))

        # Q(s, a) for the actions that were actually taken (action_batch is one-hot)
        q_value = torch.sum(current_prediction_batch * action_batch, dim=1)
        optimizer.zero_grad()
        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()

        state = next_state
        iter += 1
        print("Iteration: {}/{}, Loss: {:.5f}, Epsilon {:.5f}, Reward: {}".format(
            iter + 1,
            opt.num_iters,
            loss,
            epsilon, reward))
        if (iter + 1) % 50000 == 0:
            checkpoint = {"iter": iter,
                          "model_state_dict": model.state_dict(),
                          "optimizer": optimizer.state_dict()}
            torch.save(checkpoint, checkpoint_path)
            with open(memory_path, "wb") as f:
                pickle.dump(replay_memory, f, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    opt = get_args()
    train(opt)
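
Both scripts expose their hyperparameters through argparse. Assuming the test and training scripts are saved as test.py and train.py (the original post does not name them), training can be started with: python train.py --lr 1e-4 --batch_size 64, and a recorded evaluation run with: python test.py --saved_path trained_models --output output/chrome_dino.mp4. Checkpoints are written to trained_models/chrome_dino.pth every 50,000 iterations, and the test script loads that same file for playback.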
