批量处理
import torch
import torch.utils.data as Data

batch_size = 5

if __name__ == '__main__':
    # Ten evenly spaced points 1..10, paired with the same points reversed,
    # so each (x, y) sample always sums to 11 — easy to eyeball in the output.
    x = torch.linspace(1, 10, 10)
    y = torch.linspace(10, 1, 10)
    torch_dataset = Data.TensorDataset(x, y)

    # num_workers=2 loads batches in two worker subprocesses, which is why
    # this code must sit under the __main__ guard.
    loader = Data.DataLoader(
        dataset=torch_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )

    # Three full passes over the dataset; 10 samples / batch of 5 = 2 steps each.
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            print('Epoch:', epoch, '|step:', step, '|batch x:',
                  batch_x.numpy(), '|batch y:', batch_y.numpy())
输出结果为:
Epoch: 0 |step: 0 |batch x: [9. 3. 8. 6. 2.] |batch y: [2. 8. 3. 5. 9.]
Epoch: 0 |step: 1 |batch x: [ 1. 7. 10. 4. 5.] |batch y: [10. 4. 1. 7. 6.]
Epoch: 1 |step: 0 |batch x: [ 9. 10. 8. 2. 7.] |batch y: [2. 1. 3. 9. 4.]
Epoch: 1 |step: 1 |batch x: [5. 3. 6. 4. 1.] |batch y: [ 6. 8. 5. 7. 10.]
Epoch: 2 |step: 0 |batch x: [7. 1. 2. 6. 5.] |batch y: [ 4. 10. 9. 5. 6.]
Epoch: 2 |step: 1 |batch x: [ 9. 10. 8. 4. 3.] |batch y: [2. 1. 3. 7. 8.]
关于参数 num_workers 应该设置为多少
如果不加
if __name__ == '__main__':
就会出错。
在这里 num_workers 设置的是加载数据的工作子进程个数(并非线程个数)。可以查看自己电脑的 CPU 核心数作为参考。
import threading


def main():
    """Print how many Thread objects are currently alive in this process."""
    alive = threading.active_count()
    print(alive)


if __name__ == '__main__':
    main()
输出结果为1.
所以我以为是我的电脑只有一个线程,但即使设置 num_workers=1,依旧会出错。
所以
if __name__ == '__main__':
实际上是主进程的入口保护:当 num_workers > 0 时,DataLoader 会启动子进程加载数据,子进程在 Windows 的 spawn 方式下会重新导入主模块,缺少该保护就会递归地创建进程而报错。如果想避免这些麻烦,也可以直接将这一参数去掉(默认 num_workers=0,即在主进程中加载数据)。
import torch
import torch.utils.data as Data

batch_size = 5

# Same demo without the __main__ guard: leaving num_workers at its default
# of 0 keeps data loading in the main process, so no guard is required.
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
torch_dataset = Data.TensorDataset(x, y)

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch_size,
    shuffle=True,
    # num_workers=1,
)

for epoch in range(3):  # number of full passes over the dataset
    for step, (batch_x, batch_y) in enumerate(loader):
        print('Epoch:', epoch, '|step:', step, '|batch x:',
              batch_x.numpy(), '|batch y:', batch_y.numpy())