在看Trajectory Transformer代码时候,使用halfcheetah-medium-v2数据集跑了一下,发现Transformer的输入维度为[256,249],其中256为batchsize,但是这个249是什么一直没理解。
首先从官方文档可以知道(Half Cheetah - Gym Documentation)
halfcheetah-medium-v2的observation为17维,actions为6维,reward为1维,value是通过reward计算得到的,也是1维,它定义了一个transition_dim将以上拼接起来,得到25维。
(train.py)
通过追溯Dataloader,参考链接pytorch读取数据(Dataset, DataLoader, DataLoaderIter)_pytorch中怎么访问dataloader里的单个元素-CSDN博客发现在Dataset中,
(sequence.py)
至此,[256, 249]的来源已搞清。
总结,当发现问题时,还是需要从源头来解决,数据不知道怎么来的,就看dataloader。后面还需要好好看看,学无止境啊