部分转发
https://www.cnblogs.com/demo-deng/p/10623334.html
PyTorch 中的数据类型 torch.utils.data.DataLoader
数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。
在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。
"""
批训练,把数据变成一小批一小批数据进行训练。
DataLoader就是用来包装所使用的数据,每次抛出一批数据
"""
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
# 把数据放在数据库中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
# 从数据库中每次抽出batch size个样本
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=2,
)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training
print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
if __name__ == '__main__':
show_batch()

根据设置每个epoch 进行shuffle
本文详细介绍了PyTorch中DataLoader的功能与用法,DataLoader是PyTorch提供的数据加载工具,能结合数据集和取样器,实现数据的批处理和多线程读取,适用于模型训练过程中的数据管理,通过示例展示了如何使用DataLoader将数据集分为多个批次,进行高效的数据迭代。
2508

被折叠的 条评论
为什么被折叠?



