Code
import random
import torch

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # The examples are read in random order, with no particular structure
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(
            indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]
batch_size = 10

for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break
In short, a function containing yield is a generator: each yield returns a value and suspends execution, and the next iteration resumes from that exact point.
On every full pass the function produces the features and labels in a fresh random order, yielding batch after batch until every example has been returned.
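The pause-and-resume behavior of yield can be sketched without tensors at all. The snippet below is a minimal, plain-Python analogue of data_iter that yields only the shuffled index slices (the helper name batch_indices and the toy sizes are illustrative, not from the original):

```python
import random

def batch_indices(batch_size, num_examples):
    indices = list(range(num_examples))
    random.shuffle(indices)  # random order, as in data_iter
    for i in range(0, num_examples, batch_size):
        # Execution pauses here; the next call to next() resumes from this point
        yield indices[i: min(i + batch_size, num_examples)]

# With 25 examples and batch_size=10, the last batch is smaller.
sizes = [len(b) for b in batch_indices(10, 25)]
print(sizes)  # [10, 10, 5]

# Every example index appears exactly once across all batches.
all_seen = sorted(i for b in batch_indices(10, 25) for i in b)
print(all_seen == list(range(25)))  # True
```

Note that min(i + batch_size, num_examples) is what makes the final, possibly short batch safe: without it, the slice would still work in Python, but the explicit clamp documents the intent.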
The output is as follows:
tensor([[ 0.1649, -1.1651],
        [-2.0755, -1.0165],
        [-0.2189,  0.7607],
        [ 0.6833,  0.3537],
        [-0.2736, -2.0485],
        [-0.3026,  0.9771],
        [ 2.4795,  0.6881],
        [-0.2045, -0.8509],
        [-0.1353,  0.5476],
        [ 0.3371, -0.0479]])
 tensor([[ 8.4901],
        [ 3.5015],
        [ 1.1779],
        [ 4.3752],
        [10.6125],
        [ 0.2845],
        [ 6.8094],
        [ 6.6776],
        [ 2.0598],
        [ 5.0189]])
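For reference, PyTorch ships the same shuffle-and-batch pattern as torch.utils.data.DataLoader. A rough equivalent of the hand-written data_iter could look like the sketch below (the random toy features and labels here are placeholders, not the ones from the run above):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Placeholder data standing in for the synthetic regression dataset.
features = torch.randn(25, 2)
labels = torch.randn(25, 1)

# shuffle=True reshuffles the examples at the start of every epoch,
# mirroring random.shuffle(indices) in data_iter.
dataset = TensorDataset(features, labels)
loader = DataLoader(dataset, batch_size=10, shuffle=True)

for X, y in loader:
    print(X.shape, y.shape)  # torch.Size([10, 2]) torch.Size([10, 1])
    break
```

One practical difference: DataLoader handles workers, pinning, and collation for real datasets, while the hand-written generator makes the shuffling and slicing logic explicit, which is the point of this exercise.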