Trajectory Transformer代码

文章讲述了作者在研究TrajectoryTransformer代码时对halfcheetah-medium-v2数据集中Transformer输入维度[256,249]的困惑,通过查阅官方文档和追踪PyTorchDataLoader,发现这249源于环境观察、动作、奖励等25维数据的拼接。强调了问题解决应从源头——数据处理开始,学习永无止境。
摘要由CSDN通过智能技术生成

在看Trajectory Transformer代码时候,使用halfcheetah-medium-v2数据集跑了一下,发现Transformer的输入维度为[256,249],其中256为batchsize,但是这个249是什么一直没理解。

首先从官方文档可以知道(Half Cheetah - Gym Documentation

halfcheetah-medium-v2的observation为17维,actions为6维,reward为1维,value是通过reward计算得到的,也是1维,它定义了一个transition_dim将以上拼接起来,得到25维。

(train.py)

通过追溯Dataloader,参考链接pytorch读取数据(Dataset, DataLoader, DataLoaderIter)_pytorch中怎么访问dataloader里的单个元素-CSDN博客发现在Dataset中,

(sequence.py)

至此,[256, 249]的来源已搞清。

总结,当发现问题时,还是需要从源头来解决,数据不知道怎么来的,就看dataloader。后面还需要好好看看,学无止境啊

  • 11
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值