强化学习-DQN

CartPole-v0任务一共有4个状态 车的位置、车的速度、杆的速度和杆的角速度
动作只有一个 要么向左要么向右

在这里插入图片描述
DQN更新公式为
在这里插入图片描述

流程说明图
此处参考了强化学习–从DQN到PPO, 流程详解
在这里插入图片描述
具体代码

from collections import namedtuple

import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import gym
import numpy as np
import pdb





Transition=namedtuple(
                      'Transition',('state','action','next_state','reward'))

#本代码参考
#https://github.com/YutaroOgawa/Deep-Reinforcement-Learning-Book

#常量的设定
ENV='CartPole-v0'   #要使用的任务名称
GAMMA=0.99          #时间折扣率
MAX_STEPS=200        #一次最多走200步
NUM_EPISODES=500    


#定义用于存储经验的内存类
class ReplayMemory:

    def __init__(self,CAPACITY):
        self.capacity=CAPACITY    #下面memory的最大长度
        self.memory=[] #存储过往经验
        self.index=0 #表示要保存的索引



    def push(self,state,action,state_next,reward):
        "将transition=(state,action,state_next,reward)保存在存储器中"

        if len(self.memory)<self.capacity:
            self.memory.append(None)   #内存未满时添加

        #使用namedtuple对象Transition(state,action,state_next,reward)
        self.memory[self.index]=Transition(state,action,state_next,reward)
        

        self.index=(self.index+1)%self.capacity #将保存的index移动一位 相当于self.index=self.index+1

    def sample(self,batch_size):
        "随机抽取Batch_size大小的样本并返回"
        return random.sample(self.memory,batch_size)

    def __len__(self):
        "返回当前memory的长度"
        return len(self.memory)


BATCH_SIZE=32
CAPACITY=10000

class Brain:

    def __init__(self,num_states,num_actions):
        self.num_actions=num_actions  #获取CartPole的2个动作(向左或向右)

        #创建存储经验的对象
        self.memory=ReplayMemory(CAPACITY)
        #构建一个神经网络
        self.model=nn.Sequential()
        self.model.add_module('fc1',nn.Linear(num_states,32))
        self.model.add_module('relu1',nn.ReLU())
        self.model.add_module('fc2',nn.Linear(32,32))
        self.model.add_module('relu2',nn.ReLU())
        self.model.add_module('fc3',nn.Linear(32,num_actions))


        print(self.model)   #输出网络形状

        #设定最优化方法的设定
        self.optimizer=optim.Adam(self.model.parameters(),lr=0.0001)

    def replay(self):
        "通过Experience Replay学习网络的连接参数"

        #1 检查经验池大小
        #1.1经验池大小小于批数量不执行任何操作
        if len(self.memory)<BATCH_SIZE:
            return 
        
        #2.创建小批量数据
        #2.1从经验池中获取小批量数据
        transitions=self.memory.sample(BATCH_SIZE)

        #经过下面这句话之后 batch变为(state*BATCH_SIZE,action*BATCH_SIZE,state_next*BATCH_SIZE,reward*BATCH_SIZE)
        batch=Transition(*zip(*transitions))

        state_batch=torch.cat(batch.state)
        #state_batch.size()=[32,4]

        action_batch=torch.cat(batch.action)
        #action_batch.size()=[32,1]

        reward_batch=torch.cat(batch.reward)
        non_final_next_states=torch.cat([s for s in batch.next_state if s is not None])

        self.model.eval()   #将模型切换到推理模式

        #self.model(state_batch).size()=[32,2]
        #Q(s(t),a(t))
        state_action_values=self.model(state_batch).gather(1,action_batch)  #根据action获取响应的Q值  对应流程图中所说的步骤2
        #1代表列方向
        


        #创建索引掩码以检查cartpole是否未完成且具有next_state
        #map和lambda的使用方法
        #https://www.runoob.com/python/python-func-map.html
        non_final_mask=torch.ByteTensor(
            tuple(map(lambda s:s is not None,batch.next_state)))
        #大概是说,如果batch.next_state=None,这些状态就不要输入到网络中去 只有将那些非none输入到网络中去

        next_state_values=torch.zeros(BATCH_SIZE)

        #得到每行中的最大Q值
        #maxQ(s(t+1),a)
        next_state_values[non_final_mask]=self.model(non_final_next_states).max(1)[0].detach()    
        #torch.max(1)[0], 返回列方向的最大值
        #troch.max(1)[1], 返回列方向的最大值索引

        expected_state_action_values=reward_batch+GAMMA*next_state_values #对应流程图所说的步骤3
        

        self.model.train()   #切换到训练模式

        #计算损失函数
        loss=F.smooth_l1_loss(state_action_values,expected_state_action_values.unsqueeze(1))

        print(f'loss:{loss.item()}')

        self.optimizer.zero_grad()
        loss.backward()   #计算反向传播
        self.optimizer.step() #更新连接参数

    def decide_action(self,state,episode):

        epsilon=0.5*(1/(episode+1))
        
        if epsilon<=np.random.uniform(0,1):

            self.model.eval()   #将网络切换到推理模式
            with torch.no_grad():
               
                #self.model(state).shape=([1, 2])
                #https://jianzhuwang.blog.csdn.net/article/details/103267516
                action=self.model(state).max(1)[1].view(1,1)
                #获取网络输出最大值的索引 index=max(1)[1]
                #.view(1,1)[torch.LOngTensor of size 1]转换为size 1*1大小
        else:
            action=torch.LongTensor([[random.randrange(self.num_actions)]])


        return action


class Agent:
    def __init__(self,num_states,num_actions):
        "设置任务状态和动作的数量"
        self.brain=Brain(num_states,num_actions)

    def update_q_function(self):
        self.brain.replay()
    
    def get_action(self,state,episode):

        action=self.brain.decide_action(state,episode)
        return action
    
    def memorize(self,state,action,state_next,reward):
        
        self.brain.memory.push(state,action,state_next,reward)



class Enviroment:

    def __init__(self):
        self.env=gym.make(ENV)
        self.num_states=self.env.observation_space.shape[0]
        #设定任务状态和动作的数量
        self.num_actions=self.env.action_space.n

        #创建Agent在环境中执行的动作
        self.agent=Agent(self.num_states,self.num_actions)

    def run(self):
        
        episode_10_list=np.zeros(10)
        complete_episodes=0
        episode_final=False
        frames=[]  #用于存储图像的变量,以使最后一轮成为画面

        for episode in range(NUM_EPISODES):

            print("==================>",episode)

            observation=self.env.reset()
            state=observation

            state=torch.from_numpy(state).type(torch.FloatTensor)   #将Numpy变量转换为pytorch tensor
            #https://www.cnblogs.com/datasnail/p/13086803.html
            #在第0维上进行扩张
            #torch.Size([4])
            state=torch.unsqueeze(state,0)
            #torch.Size([1, 4])

            for step in range(MAX_STEPS):

                if episode_final is True:
                    print("frames.append==========>")

                action=self.agent.get_action(state,episode)   #求取动作
               
                #假如action=[[1]] action.item()=1
                observation_next,_,done,_=self.env.step(action.item())

                #done=True 大概分为两种情况 第一种是走路超过了195步 第二种是摔倒了
                if done:

                    state_next=None 

                    episode_10_list=np.hstack((episode_10_list[1:],step+1))
                    
                    #判断是哪种情况
                    if step<195:
               
                        reward=torch.FloatTensor([-1.0])

                        complete_episodes=0
                    else:
                        reward=torch.FloatTensor([1.0])
                        complete_episodes=complete_episodes+1

                else:

                    reward=torch.FloatTensor([0.0])
                    state_next=observation_next
                    state_next=torch.from_numpy(state_next).type(torch.FloatTensor)
                    state_next=torch.unsqueeze(state_next,0)

                self.agent.memorize(state,action,state_next,reward)

                self.agent.update_q_function()

                state=state_next

                if done:

                    #print('%d Episode:Finished after %d steps: 10次试验的平均数step数=%.1lf'%(episode,step+1,episode_10_list))
                    break 


            if episode_final is True:
                print("display frames===========>")
                break

            if complete_episodes>=10:
                
                print('10轮连续成功')   #连续10次跑成功就不要训练了
                episode_final=True

cartpole_env=Enviroment()
cartpole_env.run()
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值