pytorch的数据加载模块主要包括dataset和dataloader、sampler,可以说是dataloader包含dataset和sampler,sampler用来生成索引,dataloader中默认的是randosampler,dataset根据生成的索引读取数据和标签,在dataset中必须要写进__getitem__、init、__len__这几个方法,__getitem__用来迭代数据,dataloader用来加载数据,其中有collect_fn这个参数,如果数据中只包含img和label就不用重写这个,但是向目标检测还需要bbox这些信息,所以一般要重写collect_fn
pytorch中dataset和dataloader、sampler关系
最新推荐文章于 2024-04-24 05:01:16 发布