代码和解释如下,最后附上了输出结果:
# coding=gbk
import torch
import torch.utils.data as Data #将数据分批次需要用到它
torch.manual_seed(1) # 种子,可复用
BATCH_SIZE = 8 #设置批次大小
x = torch.linspace(1, 15, 15) # 1到15共15个点
y = torch.linspace(15, 1, 15) # 15到1共15个点
torch_dataset = Data.TensorDataset(x, y) #将x,y读取,转换成Tensor格式
loader = Data.DataLoader(
dataset=torch_dataset, # torch TensorDataset format
batch_size=BATCH_SIZE, # 最新批数据
shuffle=True, # 是否随机打乱数据
num_workers=2, # 用于加载数据的子进程
)
def show_batch():
for epoch in range(3): # 对整个数据集进行3次培训
for step, (batch_x, batch_y) in enumerate(loader): # 每个训练步骤
# 此处省略一些训练数据步骤...
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
show_batch()
'''
(1)每次训练5个数据,打乱数据;每进行一次完整的训练需要进行3个训练步骤:
Epoch: 0 | Step: 0 | batch x: [10. 12. 9. 5. 1.] | batch y: [ 6. 4. 7. 11. 15.]
Epoch: 0 | Step: 1 | batch x: [ 7. 15. 8. 13. 3.] | batch y: [ 9. 1. 8. 3. 13.]
Epoch: 0 | Step: 2 | batch x: [ 2. 6. 14. 4. 11.] | batch y: [14. 10. 2. 12. 5.]
Epoch: 1 | Step: 0 | batch x: [ 3. 10. 8. 13. 2.] | batch y: [13. 6. 8. 3. 14.]
Epoch: 1 | Step: 1 | batch x: [ 5. 4. 12. 14. 1.] | batch y: [11. 12. 4. 2. 15.]
Epoch: 1 | Step: 2 | batch x: [15. 9. 11. 6. 7.] | batch y: [ 1. 7. 5. 10. 9.]
Epoch: 2 | Step: 0 | batch x: [ 8. 7. 3. 10. 12.] | batch y: [ 8. 9. 13. 6. 4.]
Epoch: 2 | Step: 1 | batch x: [ 6. 13. 9. 4. 15.] | batch y: [10. 3. 7. 12. 1.]
Epoch: 2 | Step: 2 | batch x: [14. 2. 5. 1. 11.] | batch y: [ 2. 14. 11. 15. 5.]
(2)每次训练8个数据,打乱数据;每进行一次完整的训练需要进行2个训练步骤,一次8个数据,一次7个数据:
Epoch: 0 | Step: 0 | batch x: [10. 12. 9. 5. 1. 7. 15. 8.] | batch y: [ 6. 4. 7. 11. 15. 9. 1. 8.]
Epoch: 0 | Step: 1 | batch x: [13. 3. 2. 6. 14. 4. 11.] | batch y: [ 3. 13. 14. 10. 2. 12. 5.]
Epoch: 1 | Step: 0 | batch x: [ 3. 10. 8. 13. 2. 5. 4. 12.] | batch y: [13. 6. 8. 3. 14. 11. 12. 4.]
Epoch: 1 | Step: 1 | batch x: [14. 1. 15. 9. 11. 6. 7.] | batch y: [ 2. 15. 1. 7. 5. 10. 9.]
Epoch: 2 | Step: 0 | batch x: [ 8. 7. 3. 10. 12. 6. 13. 9.] | batch y: [ 8. 9. 13. 6. 4. 10. 3. 7.]
Epoch: 2 | Step: 1 | batch x: [ 4. 15. 14. 2. 5. 1. 11.] | batch y: [12. 1. 2. 14. 11. 15. 5.]
'''