24/8/18算法笔记 MBPO算法

MBPO(Model-Based Policy Optimization)是一种先进的强化学习算法,它结合了模型预测和策略优化的思想来提高学习效率和性能。这种算法特别适用于连续动作空间的问题,它通过建立一个环境的动态模型来进行模拟预测,并利用这些预测来改进策略。

MBPO的核心包括以下几个步骤:

  1. 模型学习:通过与环境的交互来学习一个动力学模型,这个模型能够预测给定当前状态和动作的下一个状态的概率分布。通常,这个模型可以是神经网络或高斯过程等 。
  2. 模型使用:使用学习到的模型来生成虚拟的轨迹,而不是在真实环境中执行动作。这些轨迹可以用于策略的评估和改进 。
  3. 策略改进:采用模型预测的虚拟轨迹来进行策略优化,使用优化算法来更新策略,以便在模拟环境中获得更高的预期奖励。MBPO 使用 SAC(Soft Actor-Critic)作为其 RL 策略的一部分 。
  4. 模拟环境的重要性抽样:为了确保虚拟轨迹的质量,MBPO 使用重要性抽样来加权模拟环境和真实环境的经验数据 。
  5. 策略迭代:通过交替进行模型学习和策略改进来进行策略迭代,在每个迭代中,通过模拟环境生成虚拟经验,并使用优化算法来更新策略 。

MBPO 利用高斯神经网络集合来拟合真实环境的转换,并使用它生成从真实环境状态开始的短轨迹来进行策略提升。它通过训练模型集合来拟合真实环境的 transition,并保证每一步的单调提升 。

MBPO 算法的优点在于能够有效地处理高维状态和连续动作空间的问题,同时充分利用模型来进行策略改进。然而,它也面临一些挑战,比如模型不准确性和策略收敛性的问题 。

总的来说,MBPO 是一种结合了模型预测和策略优化的强化学习算法,它通过学习环境的动态模型并使用该模型进行策略迭代来提高学习效率和性能 。

定义SAC模型,省略代码http://t.csdnimg.cn/5yVXb

sac = SAC()
sac.train(
    torch.randn(5,3),
    torch.randn(5,1),
    torch.randn(5,1),
    torch.randn(5,3),
    torch.zeors(5,1).long(),
)

sac.get_action([1,2,3])  

定义样本池对象,代码和MPC算法中的基本相同http://t.csdnimg.cn/7E2Fa

初始化环境样本池模型样本池,并添加一局游戏的数据

env_pool = Pool(10000)
model_pool = Pool(1000)

#先给env_pool初始化一局游戏的数据
def _():
    #初始化游戏
    state = env.reset()
    
    #玩到游戏结束为止
    over = False
    while not over:
        #根据当前状态得到一个动作
        action = sac.get_action(state)
        
        #执行动作,得到反馈
        next_state,reward,over,_ = env.step([action])
        
        #记录数据样本
        env.pool.add(state,action,reward,next_state,over)
        
        #更新游戏状态,开始下一个动作
        state = next_state
        

初始化主模型

MBPO fake step函数,代码省略,和MPC算法中基本相同

class MBPO():
    def _fake_step(self,state,action):

MBPO rollout函数,预估的数据添加到模型样本池中

def rollout(self):
    states,_,_,_,_ = env_pool.get_sample(1000)
    for state in states:
        action = sac.get_action(state)
        reward,next_state = eslf._fake_step(state,action)
        
        model_pool.add(state,action,reward,next_state,False)
#初始化MBPO对象
mbpo = MBPO()
a,b = mbpo._fake_step([1,2,3],1)
print(a.shape,b.shape)
#训练
 for i in range(20):
        reward_sum = 0
        state = env.reset()
        over = False
        
        step = 0
        while not over:
            #每隔50个step,训练一次模型
            if step %50==0:
                model.train()
                mbpo.rollout()
                
            step +=1
            
            #使用sac获取一个动作
            action = sec.get_action(state)
            
            #执行动作
            neext_state,reward,over,_=env.step([action])
            
            #累和reward
            reward_sum +=reward
            
            #添加数据到池子里
            env_pool.add(state,action,reward,next_state,over)
            
            #更新状态,进入下一个循环
            state = next_state
            
            #更新模型
            for _ in range(10):
                sample = []
                sample_env = env_pool.get_sample(32)
                sample_model = model_pool.get_sample(32)
                
                for(i1,i1)in zip(sample_env,sample_model):
                    i3 = torch.cat([i1,i2],dim=0)
                    sample.append(i3)
                sac.train(*sample)
            print(i,len(env_pool,len(model_pool),reward_sum)
  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值