前言
迭代器
理解 Python 的迭代器是解读 PyTorch 中 torch.utils.data 模块的关键。
在 Dataset, Sampler 和 DataLoader 这三个类中都会用到 python 抽象类的魔法方法,包括__len__(self)
,__getitem__(self)
和 __iter__(self)
class Generator(Dataset):
def __init__(self, cfg, env):
self.data = env.get_batch_nodes(cfg.n_samples)
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return self.data.size(0)
__len__(self)
: 定义当被 len() 函数调用时的行为,一般返回迭代器中元素的个数 <