d3rlpy离线强化学习算法库安装及使用

GitHub - takuseno/d3rlpy: An offline deep reinforcement learning library

d3rlpy,离线强化学习算法库

我装在windows下用anaconda,按照官网教程

conda install -c conda-forge d3rlpy

第一次安装报错CondaSSLError: OpenSSL appears to be unavailable on this machine

[报错解决]CondaSSLError: OpenSSL appears to be unavailable on this machine. OpenSSL is required to downl_一件迷途小书童的博客-CSDN博客

参考这篇文章解决后正常安装没问题,值得注意的是d3rkpy安装时包含cudatoolkit11.几,我在想这个在不同电脑上可能之后会出错,不过后面运行算法时可以选择是否使用GPU

我是打算用离线强化学习算法,安装后测试,官网上也有测试代码

import d3rlpy

# prepare dataset
dataset, env = d3rlpy.datasets.get_d4rl('hopper-medium-v0')

# prepare algorithm
cql = d3rlpy.algos.CQL(use_gpu=True)

# train
cql.fit(
    dataset,
    eval_episodes=dataset,
    n_epochs=100,
    scorers={
        'environment': d3rlpy.metrics.evaluate_on_environment(env),
        'td_error': d3rlpy.metrics.td_error_scorer,
    },
)

看得出来,这接口用起来非常方便啊

因为我没装d4rl所以肯定是失败了,d4rl数据集查了下资料可能无法装在windows环境下,有点难办。可以使用下面这个在测试,用的是d3rlpy自带用于测试的数据集,也是比较常用的两个环境,具体是在d3rlpy的文档上找到的

import d3rlpy

# prepare dataset
# dataset, env = d3rlpy.datasets.get_d4rl('CartPole-v0')
dataset, env = d3rlpy.datasets.get_pendulum("random")

# prepare algorithm
cql = d3rlpy.algos.CQL(use_gpu=True)

# train
cql.fit(
    dataset,
    eval_episodes=dataset,
    n_epochs=100,
    scorers={
        'environment': d3rlpy.metrics.evaluate_on_environment(env),
        'td_error': d3rlpy.metrics.td_error_scorer,
    },
)

资料很充分,d3rlpy文档:d3rlpy.datasets.get_cartpole — d3rlpy documentation

 成功运行:

如果失败的话可能是下载失败,

在这找到下载网址,自己下载到本地,改成规定的名字即可,放到对d3rlpy_data文件夹里,再运行时就不需要在线下载了,比如这样

 

之后回到d4rl,我打算把自己的数据集按照d4rl的格式来编写,但我不打算装d4rl

可以看到在d3rlpy中读取d4rl的数据集主要是用d4rl中的get_dataset函数,于是我索性把d4rl中这个函数搬到d3rlpy中,其实就是读取h5格式的函数,也挺好移植,主要也就这一段

        data_dict = {}
        with h5py.File(h5path, 'r') as dataset_file:
            for k in tqdm(get_keys(dataset_file), desc="load datafile"):
                try:  # first try loading as an array
                    data_dict[k] = dataset_file[k][:]
                except ValueError as e:  # try loading as a scalar
                    data_dict[k] = dataset_file[k][()]

注意还需要

import h5py
from tqdm import tqdm


def get_keys(h5file):
    keys = []

    def visitor(name, item):
        if isinstance(item, h5py.Dataset):
            keys.append(name)

    h5file.visititems(visitor)
    return keys

至于原先是个类,我感觉好像也不需要,同时还是把在线改掉,直接变成一个绝对位置(这个在d4rl中也可以找到下载的网址)

h5path = "D:\xxx_project\pycharm\offline_RL\d3rlpy_data\hopper_random.hdf5"

运行成功

我考虑下一步制作自己的hdf5格式数据集,及做下自己的gym环境

甚至不能算是入门,希望没有问题,欢迎指正

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

大大马猴

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值