问题描述: 最早一批数据是2500 多张,新一批的数据有36000 多张。
然后最近训练的时候,设置train_nums 设置为36000
但是实际训练的时候,貌似不太行,总是报各种错误
train_nums = 36000
train_sampler = torch.utils.data.WeightedRandomSampler(weights=sampler_weights, num_samples=train_nums,
replacement=True)
原因分析:
看下WeightedRandomSampler 这个函数的具体实现:
反思:
现在才有3w 的数据 运行起来就这么慢
那么像imagenet 这种百万级别的,是不是更慢
应该是要换个效率更高的写法