先来看一下:出现 init_epoch 函数调用的代码段如下:
from torchtext import data
def mt_iterator(opt, train=True):
    """Build torchtext iterators and vocabulary for a translation dataset.

    Args:
        opt: options namespace; must provide ``data`` (dataset root path),
            ``batch_size``, ``cuda`` and ``reverse``.
        train: when True, return ``(train_iter, valid_iter)`` BucketIterators;
            otherwise return a single non-shuffling, non-sorting test Iterator.

    Returns:
        (iterator, vocab): the iterator(s) and the target-side (EN) vocabulary.
    """
    # Source field; optionally reverse each source sentence when opt.reverse
    # is set (classic seq2seq trick).
    DE = data.Field(eos_token=EOS, lower=True,
                    preprocessing=(lambda x: x[::-1]) if opt.reverse else None)
    # Target field: the same EOS token is used as both init and eos token.
    EN = data.Field(init_token=EOS, eos_token=EOS, lower=True)
    train_data, val_data, test_data = datasets.TranslationDataset.splits(
        path=opt.data, train='train', validation='valid', test='test',
        exts=('.input', '.output'), fields=(DE, EN))
    # Bug fix: the vocabularies must be built before EN.vocab is accessed
    # below, otherwise the final return raises AttributeError.
    DE.build_vocab(train_data)
    EN.build_vocab(train_data)
    if train:
        iterator = data.BucketIterator.splits(
            (train_data, val_data), batch_size=opt.batch_size,
            device=0 if opt.cuda else -1)
        # Do not repeat the training iterator across epochs: the training
        # loop calls init_epoch() itself once per epoch.
        iterator[0].repeat = False
    else:
        iterator = data.Iterator(
            test_data, batch_size=opt.batch_size,
            device=0 if opt.cuda else -1, train=False,
            shuffle=False, sort=False)
    return iterator, EN.vocab
# Build the (train, valid) iterators; train=True is the default.
iterator, vocab_en = mt_iterator(opt)
# iterator is a tuple from BucketIterator.splits; index 0 is the train split.
train_iter = iterator[0]
# Set up the batch generator for the first epoch (called once per epoch).
train_iter.init_epoch()
由于 init_epoch 函数出自 torchtext.data 的 Iterator 类,下面查看其源代码。
Iterator:迭代器类,用来从数据集中按批次(batch)加载数据。
init_epoch() 函数的作用是为每个 epoch 建立一个新的 batch 生成器;训练时每个 epoch 初始化一次。
class Iterator(object):
    """Defines an iterator that loads batches of data from a Dataset.

    Attributes:
        batch_size_fn: Function of three arguments (new example to add, current
            count of examples in the batch, and current effective batch size)
            that returns the new effective batch size resulting from adding
            that example to a batch. This is useful for dynamic batching, where
            this function would add to the current effective batch size the
            number of tokens in the new example.
        sort_key: Key used for sorting, so that examples of similar length are
            batched together, reducing padding.
        train: Whether this iterator represents the training set.
        repeat: Whether to repeat the iterator for multiple epochs
            (defaults to ``train``).
        shuffle: Whether to shuffle examples between epochs
            (defaults to ``train``).
        sort: Whether to sort examples by ``self.sort_key``
            (defaults to ``not train``).
        sort_within_batch: Whether to sort (in descending order according to
            ``self.sort_key``) within each batch; defaults to ``self.sort``.
            If ``self.sort`` is True and this is False, the batch is left in
            the original (ascending) sorted order.
        device: Device to create batches on; use -1 for CPU (defaults to GPU).
    """

    def __init__(self, dataset, batch_size, sort_key=None, device=None,
                 batch_size_fn=None, train=True,
                 repeat=None, shuffle=None, sort=None,
                 sort_within_batch=None):
        self.batch_size, self.train, self.dataset = batch_size, train, dataset
        self.batch_size_fn = batch_size_fn
        self.iterations = 0
        # Behaviour flags default according to whether this is a training
        # iterator: train => repeat & shuffle; eval => sort.
        self.repeat = train if repeat is None else repeat
        self.shuffle = train if shuffle is None else shuffle
        self.sort = not train if sort is None else sort
        if sort_within_batch is None:
            self.sort_within_batch = self.sort
        else:
            self.sort_within_batch = sort_within_batch
        if sort_key is None:
            # Fall back to the dataset's own sort key.
            self.sort_key = dataset.sort_key
        else:
            self.sort_key = sort_key
        self.device = device
        self.random_shuffler = RandomShuffler()

        # For state loading/saving only.
        self._iterations_this_epoch = 0
        self._random_state_this_epoch = None
        self._restored_from_state = False

    @classmethod
    def splits(cls, datasets, batch_sizes=None, **kwargs):
        """Create Iterator objects for multiple splits of a dataset.

        Arguments:
            datasets: Tuple of Dataset objects corresponding to the splits.
                The first such object should be the train set.
            batch_sizes: Tuple of batch sizes to use for the different splits,
                or None to use the same batch_size for all splits.
            Remaining keyword arguments: Passed to the constructor of the
                iterator class being used.
        """
        if batch_sizes is None:
            batch_sizes = [kwargs.pop('batch_size')] * len(datasets)
        ret = []
        for i in range(len(datasets)):
            # Only the first split is treated as the training set.
            train = i == 0
            ret.append(cls(
                datasets[i], batch_size=batch_sizes[i], train=train, **kwargs))
        return tuple(ret)

    def data(self):
        """Return the examples in the dataset in order, sorted, or shuffled."""
        if self.sort:
            xs = sorted(self.dataset, key=self.sort_key)
        elif self.shuffle:
            xs = [self.dataset[i]
                  for i in self.random_shuffler(range(len(self.dataset)))]
        else:
            xs = self.dataset
        return xs

    def init_epoch(self):
        """Set up the batch generator for a new epoch."""
        if self._restored_from_state:
            # Resuming from a saved state: restore the RNG state that was
            # recorded with it, so shuffling is reproduced exactly.
            self.random_shuffler.random_state = self._random_state_this_epoch
        else:
            # Fresh epoch: remember the current RNG state so this epoch can
            # later be restored.
            self._random_state_this_epoch = self.random_shuffler.random_state

        self.create_batches()

        if self._restored_from_state:
            # The saved state has now been consumed; reset the flag.
            self._restored_from_state = False
        else:
            # Starting from scratch: no batches consumed in this epoch yet.
            self._iterations_this_epoch = 0

        if not self.repeat:
            # Non-repeating iterators reset the global iteration counter at
            # the start of every epoch.
            self.iterations = 0

    def create_batches(self):
        # `batch` groups self.data() into batches of self.batch_size
        # (or by self.batch_size_fn when one is given).
        self.batches = batch(self.data(), self.batch_size, self.batch_size_fn)

    @property
    def epoch(self):
        # Fractional number of epochs completed so far.
        return self.iterations / len(self)