Reinforcement Learning Notes

################## Writing a Reinforcement Learning Environment ######################

Reference tutorial:

https://www.bilibili.com/video/BV1eT4y197JB/?spm_id_from=333.788&vd_source=5d203a60b40a771f5c500c9e05ddd612

The script below implements tabular Q-learning on a SIZE x SIZE grid: a player blob learns to reach a food blob while avoiding an enemy blob, observing only its relative (x, y) offsets to each of them.

import numpy as np
import cv2
from PIL import Image
import time
import pickle
import matplotlib.pyplot as plt
from matplotlib import style
style.use('ggplot')


SIZE=10                # the grid is SIZE x SIZE
EPISODES=30000         # number of training episodes
SHOW_EVERY=3000        # render and report stats every N episodes

FOOD_REWARD=25         # reward for reaching the food
ENEMY_PENALTY=300      # penalty for hitting the enemy
MOVE_PENALTY=1         # small penalty for every other step
epsilon=0.6            # initial exploration rate
EPS_DECAY=0.9998       # multiplicative epsilon decay applied after each episode
DISCOUNT=0.95          # discount factor (gamma)
LEARNING_RATE=0.1      # learning rate (alpha)
# q_table='qtable_1726822987.pickle'  # path of a previously saved Q-table to resume from
q_table=None
d={1:(255,0,0), #blue
   2:(0,255,0), #green
   3:(0,0,255)} #red
PLAYER_N=1
FOOD_N=2
ENEMY_N=3


class Cube:
    """A blob (player, food or enemy) spawned at a random cell of the grid."""
    def __init__(self):
        self.x=np.random.randint(0,SIZE)
        self.y=np.random.randint(0,SIZE)

    def __str__(self):
        return f'{self.x},{self.y}'

    def __sub__(self,other):
        # relative offset to another cube; used to build the observation
        return (self.x-other.x,self.y-other.y)
    def action(self,choice):
        # the four actions are the four diagonal moves
        if choice==0:
            self.move(x=1,y=1)
        elif choice==1:
            self.move(x=-1,y=1)
        elif choice==2:
            self.move(x=1,y=-1)
        elif choice==3:
            self.move(x=-1,y=-1)
 
    def move(self,x=False,y=False):
        # if no step is given for an axis, take a random step in {-1, 0, 1} on it
        if not x:
            self.x+=np.random.randint(-1,2)
        else:
            self.x+=x
        if not y:
            self.y+=np.random.randint(-1,2)
        else:
            self.y+=y

        # clamp the position to the grid
        if self.x<0:
            self.x=0
        elif self.x>=SIZE:
            self.x=SIZE-1
        if self.y<0:
            self.y=0
        elif self.y>=SIZE:
            self.y=SIZE-1

        

# The observation is a pair of relative offsets: (player-food, player-enemy).
# Each coordinate ranges over [-(SIZE-1), SIZE-1], so the table holds
# (2*SIZE-1)**4 states, each with 4 randomly initialised action values.
if q_table is None:
    q_table={}
    for x1 in range(-SIZE+1,SIZE):
        for y1 in range(-SIZE+1,SIZE):
            for x2 in range(-SIZE+1,SIZE):
                for y2 in range(-SIZE+1,SIZE):
                    q_table[((x1,y1),(x2,y2))]=[np.random.uniform(-5,0) for i in range(4)]
else:
    # load the previously saved Q-table file
    with open(q_table,'rb') as f:
        q_table=pickle.load(f)



episode_rewards=[]
for episode in range(EPISODES):
    player=Cube()
    food=Cube()
    enemy=Cube()

    # periodically report progress and turn on rendering for this episode
    if episode%SHOW_EVERY==0:
        print(f'episode #{episode}, epsilon: {epsilon}')
        if episode_rewards:  # avoid taking the mean of an empty list on the first report
            print(f'mean reward: {np.mean(episode_rewards[-SHOW_EVERY:])}')
        show=True
    else:
        show=False

    episode_reward=0
    for i in range(200):  # at most 200 steps per episode
        obs=(player-food,player-enemy)
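        # epsilon-greedy action selection: exploit the best known action with
        # probability 1-epsilon, otherwise explore with one of the 4 random actions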
        if np.random.random()>epsilon:
            action=np.argmax(q_table[obs])
        else:
            action= np.random.randint(0,4)

        player.action(action)

        # reward: +FOOD_REWARD for reaching the food, -ENEMY_PENALTY for hitting
        # the enemy, and a small -MOVE_PENALTY for any other step
        if player.x==food.x and player.y==food.y:
            reward=FOOD_REWARD
        elif player.x==enemy.x and player.y==enemy.y:
            reward=-ENEMY_PENALTY
        else:
            reward=-MOVE_PENALTY



        #Update the Q_table
        current_q=q_table[obs][action]

        new_obs=(player-food,player-enemy)
 
        max_future_q=np.max(q_table[new_obs])
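        # Q-learning update computed below:
        #   Q(s,a) <- (1-LEARNING_RATE)*Q(s,a) + LEARNING_RATE*(reward + DISCOUNT*max_a' Q(s',a'))
        # When the food is reached, new_q is set directly to FOOD_REWARD instead of
        # bootstrapping from future values, since that step ends the episode.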

        
        if reward==FOOD_REWARD:
            new_q=FOOD_REWARD
        else:
            new_q=(1-LEARNING_RATE)*current_q+LEARNING_RATE*(reward+DISCOUNT*max_future_q)
        ## render the game board
        if show:
            env=np.zeros((SIZE,SIZE,3),dtype=np.uint8)
            env[food.x][food.y]=d[FOOD_N]        # cv2.imshow interprets these colours as BGR
            env[player.x][player.y]=d[PLAYER_N]
            env[enemy.x][enemy.y]=d[ENEMY_N]

            img=Image.fromarray(env,'RGB')
            img=img.resize((800,800))
            cv2.imshow('',np.array(img))
            # pause longer on terminal steps (food or enemy reached); press q to stop the episode
            if reward==FOOD_REWARD or reward==-ENEMY_PENALTY:
                if cv2.waitKey(500) & 0xFF ==ord('q'):
                    break
            else:
                if cv2.waitKey(1) & 0xFF ==ord('q'):
                    break


        q_table[obs][action]=new_q
        episode_reward+=reward

        # the episode ends when the player reaches the food or the enemy
        if reward==FOOD_REWARD or reward==-ENEMY_PENALTY:
            break

    episode_rewards.append(episode_reward)
    epsilon*=EPS_DECAY  # decay the exploration rate after every episode



# moving average of the episode reward over a SHOW_EVERY-episode window
moving_avg=np.convolve(episode_rewards,np.ones((SHOW_EVERY,))/SHOW_EVERY,mode='valid')
# print(len(moving_avg))
plt.plot([i for i in range(len(moving_avg))],moving_avg)
plt.xlabel('episode #')
plt.ylabel(f'mean {SHOW_EVERY} reward')
plt.show()

# save the Q-table so training can be resumed or evaluated later
with open(f'qtable_{int(time.time())}.pickle','wb') as f:
    pickle.dump(q_table,f)
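
To sanity-check a trained table, the saved pickle can be reloaded and the agent run greedily (no exploration). This is a minimal sketch, not part of the original script: it assumes it runs after the definitions above (Cube, the constants and imports), and the pickle file name is a placeholder to be replaced with an actual saved file.

# Minimal greedy-evaluation sketch (assumes Cube, the constants and imports above;
# the file name below is a placeholder for a real saved Q-table)
with open('qtable_XXXXXXXXXX.pickle','rb') as f:
    eval_q_table=pickle.load(f)

player,food,enemy=Cube(),Cube(),Cube()
total_reward=0
for step in range(200):
    obs=(player-food,player-enemy)
    action=np.argmax(eval_q_table[obs])  # always exploit: epsilon is effectively 0
    player.action(action)
    if player.x==food.x and player.y==food.y:
        total_reward+=FOOD_REWARD
        break
    elif player.x==enemy.x and player.y==enemy.y:
        total_reward-=ENEMY_PENALTY
        break
    else:
        total_reward-=MOVE_PENALTY
print(f'greedy episode reward: {total_reward}')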

Game screen:

Reward curve:
