CartPole-v1无限步数(Gym,pytorch、DQN)

基于常规DQN算法(无算法层面的优化)实现CartPole-v1无限步数

唯一的改动就是把奖励和当前位置产生了联系,从零开始的话大概训练1000轮左右就能无限步数

这段代码比较核心的改进也就是reward -= abs(state[0])*5这句话了,因为CartPole位置离中心太远(超出边界)的话就会寄掉,加上这句话能让模型一直保持在中间的位置,同时极大的提高了收敛速度

视频时长只有22分是因为只能存这么长,再长的话就会out of memory了

q网络的代码(参数不是特别重要,改成别的大概也能拟合),保存为qnet.py

import os
import random
import torch
import torch.nn as nn
device = torch.device("cuda")
Batch_size = 64 # 每次训练的样本数
LR = 1e-3 # 学习率
FCN_size = (256,256)
Path = 'cache/qnet'

class FCN(nn.Module):
    def __init__(self):
        super().__init__()
        self.sequense = nn.Sequential(
            nn.Linear(4, FCN_size[0]),
            nn.ReLU(),
            nn.Linear(FCN_size[0], FCN_size[1]),
            nn.ReLU(),
            nn.Linear(FCN_size[1], 2)
        )
    def forward(self, x):
        return self.sequense(x)
class Qnet:
    def __init__(self):
        self.q1 = FCN().to(device)
        self.q2 = FCN().to(device)
        if os.path.exists(Path):
            self.q1 = torch.load(Path)
            self.q2 = torch.load(Path)
        self.optimizer = torch.optim.Adam(self.q1.parameters(), lr=LR)
        self.loss_fn = nn.MSELoss().to(device)
    def update(self):
        torch.save(self.q1,Path)
        self.q2 = torch.load(Path)
    def start(self,pool):
        if len(pool) < Batch_size:return
        da = random.sample(pool,Batch_size)
        s0,a0,r1,s1 = zip(*da)
        s0,a0,r1,s1 = torch.stack(s0),torch.stack(a0),torch.stack(r1),torch.stack(s1)

        self.q1.train()
        y_p = self.q1(s0).gather(1,a0.long().reshape(-1,1))
        q_val = self.q1(s1).max(1)[1].view(-1,1)
        next_q_val = self.q1(s1).gather(1,q_val)
        y = r1.reshape(-1,1)+next_q_val

        loss = self.loss_fn(y_p, y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.q1.eval()

dqn的代码(这里需要注意的是经验池大小,太小的话容易卡在几千回合就结束了;还有就是贪婪度也别太高,不然容易因为运气太差中途寄掉,其他没什么特别重要的了),保存为dqn.py

import os
import random
import time
import torch
from collections import deque
import gymnasium as gym
import qnet
device = torch.device("cuda")
Start_time = time.time()

Episode = 2*10**6  # 训练次数
Path_best_model = 'cache/dqn_best'

class Dqn_agent:
    Epsilon = 0.01   #贪婪度,0贪心1随机
    Mx_pool = 3000 # 经验池大小
    Up_model_ = 1000 # 更新模型频率
    def __init__(self):
        self.pool = deque([])
        self.cnt = 0
        self.q = qnet.Qnet()
        if os.path.exists(Path_best_model):
            self.q.q1 = torch.load(Path_best_model)
            self.q.q2 = torch.load(Path_best_model)
    def update(self,state,action,reward,next_state,terminated):
        self.cnt += 1
        self.q.start(self.pool)
        if self.cnt % self.Up_model_ == 0:
            self.q.update()
        #更新池
        reward -= abs(state[0])*5
        self.pool.append([torch.tensor(x).float().to(device) for x in (state,action,reward,next_state)])
        if len(self.pool) > self.Mx_pool:
            self.pool.popleft()
    def action_choose(self,state):
        if random.random() < self.Epsilon:
            action = random.randint(0,1)
        else:
            lin = self.q.q2(torch.tensor(state).float().to(device)).tolist()
            action = lin.index(max(lin))
        return action

agent = Dqn_agent()
env = gym.make('CartPole-v1', render_mode='human')   #, render_mode='human'
total_mx = 0
for episode in range(1,Episode+1):
    state = env.reset()[0]
    total_reward = 0
    terminated = False
    while not terminated:
        action = agent.action_choose(state)
        next_state, reward, terminated, *_ = env.step(action)
        total_mx += reward
        agent.update(state, action, reward, next_state, terminated)
        state = next_state

    if total_reward > total_mx:# 更新最优模型
        torch.save(agent.q.q2,Path_best_model)
        total_mx = total_reward

env.close()

  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值