# 题解 (Solution)
from typing import Iterator, List
def data_pipeline(
    data: List[List[int]], batch_size: int
) -> Iterator[List[List[int]]]:
    """Yield consecutive mini-batches of ``data``.

    Rows are grouped in their original order into batches of ``batch_size``
    rows. The final batch is shorter when ``len(data)`` is not a multiple of
    ``batch_size``. Each batch is a new list, but the rows inside it are the
    same objects as in ``data`` (shallow copy).

    Args:
        data: The rows to split into batches.
        batch_size: Maximum number of rows per batch; must be at least 1.

    Yields:
        Lists of up to ``batch_size`` rows, in order.

    Raises:
        ValueError: If ``batch_size`` is less than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    for start in range(0, len(data), batch_size):
        # Slicing clamps at the end of the list, so the last batch simply
        # comes out shorter without an explicit bounds check.
        yield data[start:start + batch_size]
# Small demo dataset: five rows, split into batches of two rows each.
data = [
    [1, 2],
    [1, 3],
    [3, 5],
    [2, 1],
    [3, 3],
]
batch_size = 2

# Run the demo only when executed as a script, so that importing this
# module does not print anything.
if __name__ == "__main__":
    for batch in data_pipeline(data, batch_size):
        print(batch)