网上找了很久,没有能解决的。这里给出亲测可用的方法:
# 定义一个自定义的采样器对象
class CustomBatchSampler(BatchSampler):
def __iter__(self):
batch = []
para = 0 # 参数
for idx in self.sampler:
batch.append((idx, para))
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0:
yield batch
# 定义Dataloader的时候使用这个采样器
loader = DataLoader(dataset, batch_sampler=batch_sampler, **loader_args)
# dataset加上获取额外参数的代码
def __getitem__(self, index):
index, para= index[0]
849

被折叠的 条评论
为什么被折叠?



