Problems encountered when debugging with a newer gym version (0.26.2)

A reinforcement-learning example downloaded from the web ran fine under an older version of gym, but after a recent environment update it started failing with errors.

Source code:

def train():
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
    loss_fn = torch.nn.MSELoss()
    # train for N epochs
    for epoch in range(500):
        # refresh N entries of replay data
        update_count, drop_count = update_data()
        # after each data refresh, learn N times
        for i in range(200):
            # sample a batch of data
            state, action, reward, next_state, over = get_sample()
            # compute value and target for the batch
            value = get_value(state, action)
            target = get_target(reward, next_state, over)
            # update the parameters
            loss = loss_fn(value, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if epoch % 50 == 0:
            test_result = sum([test(play=False) for _ in range(20)]) / 20
            print(epoch, len(datas), update_count, drop_count, test_result)

train()

The error output:

F:\Python3.11\Lib\site-packages\gym\envs\registration.py:555: UserWarning: WARN: The environment CartPole-v0 is out of date. You should consider upgrading to version `v1`.
  logger.warn(
E:\我的文件\PycharmProjects\DQN\main.py:45: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ..\torch\csrc\utils\tensor_new.cpp:248.)
  state = torch.FloatTensor(state).reshape(1, 4)
Traceback (most recent call last):
  File "E:\我的文件\PycharmProjects\DQN\main.py", line 167, in <module>
    train()
  File "E:\我的文件\PycharmProjects\DQN\main.py", line 151, in train
    update_count, drop_count = update_data()
                               ^^^^^^^^^^^^^
  File "E:\我的文件\PycharmProjects\DQN\main.py", line 60, in update_data
    action = get_action(state)
             ^^^^^^^^^^^^^^^^^
  File "E:\我的文件\PycharmProjects\DQN\main.py", line 45, in get_action
    state = torch.FloatTensor(state).reshape(1, 4)
            ^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: expected sequence of length 4 at dim 1 (got 0)

Process finished with exit code 1

It eventually turned out that:

(1) state = env.reset() has changed: reset() now also returns an info dict, so it must be written as:

state, info = env.reset()
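In gym ≥ 0.26, reset() returns an (observation, info) tuple, and forgetting to unpack it is exactly what triggers the ValueError above: the variable ends up holding a 2-tuple whose second element is an empty dict. A minimal sketch of the old versus new call pattern, using a stub function in place of a real environment so it runs without gym installed:

```python
# Stub mimicking the gym >= 0.26 reset() signature; a real environment
# would be created with gym.make("CartPole-v1").
def fake_reset():
    observation = [0.01, -0.02, 0.03, 0.04]  # 4-dim CartPole-like state
    info = {}                                # auxiliary diagnostics, often empty
    return observation, info

# Pre-0.26 style: 'state' silently becomes the whole (observation, info) tuple.
state_old_style = fake_reset()
print(len(state_old_style))   # 2, not the 4-dim state the code expects

# New style: unpack both values.
state, info = fake_reset()
print(len(state))             # 4
```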

(2) env = gym.make('CartPole-v0') is out of date and must be changed to:

env = gym.make('CartPole-v1')

(3) state, reward, over, _ = env.step(action) gained an extra return value (step() now reports terminated and truncated separately, followed by info), so it must be changed to:

state, reward, over, _, _ = env.step(action)
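step() in gym ≥ 0.26 returns five values: (observation, reward, terminated, truncated, info). A sketch with a stub in place of a real env, so it runs without gym; note that binding only terminated to over, as in the one-line fix above, silently discards truncation (CartPole-v1's 500-step time limit), so a safer done flag combines both:

```python
# Stub mimicking the gym >= 0.26 step() signature; the episode-end flag is
# split into terminated (pole fell / failure) and truncated (time limit hit).
def fake_step(action):
    observation = [0.0, 0.0, 0.0, 0.0]
    reward = 1.0
    terminated = False
    truncated = False
    info = {}
    return observation, reward, terminated, truncated, info

# As in the fix above: 'over' only reflects terminated.
state, reward, over, _, _ = fake_step(0)

# Safer variant: treat the episode as done on either condition.
obs, reward, terminated, truncated, info = fake_step(0)
done = terminated or truncated
```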

(4) For the "converting the list to a single numpy.ndarray" warning, state and next_state need to be built as:

state = torch.FloatTensor(np.array([i[0] for i in samples]))
next_state = torch.FloatTensor(np.array([i[3] for i in samples]))
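The warning arises because torch is slow when handed a Python list of separate ndarrays; stacking them into one array with np.array() first avoids it. A sketch of just the stacking step, using hypothetical replay tuples shaped like (state, action, reward, next_state, over) and leaving torch out so it runs with numpy alone:

```python
import numpy as np

# Hypothetical replay-buffer samples: (state, action, reward, next_state, over).
samples = [
    (np.array([0.1, 0.2, 0.3, 0.4]), 1, 1.0, np.array([0.2, 0.3, 0.4, 0.5]), False),
    (np.array([0.5, 0.6, 0.7, 0.8]), 0, 1.0, np.array([0.6, 0.7, 0.8, 0.9]), False),
]

# Stack the per-sample vectors into a single (batch, 4) ndarray first;
# passing this one array to torch.FloatTensor avoids the warning.
state_batch = np.array([s[0] for s in samples])
next_state_batch = np.array([s[3] for s in samples])
print(state_batch.shape)  # (2, 4)
```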
