本文记录学习过程中遇到的问题、我的解决过程以及学习心得,如有错误之处,欢迎指正!
在学习用pytorch进行数据批处理的过程中用到了torch.utils.data.TensorDataset()和torch.utils.data.DataLoader()函数,练习的代码如下:
import torch
import torch.utils.data as Data
torch.manual_seed(1) # 这句有关生成随机数,他会使得随机生成的结果是确定的
BATCH_SIZE= 5 # 设置批次训练数量
# 定义数据
x = torch.linspace(1, 10, steps=10) # torch.linspace()线性等分向量,前两个参数是向量的开始和结束值,steps是分割出的点数,缺省值100
y = torch.linspace(10, 1, steps=10) # x,y都是十维向量
torch_dataset = Data.TensorDataset(x, y) # x,y对应整合进数据集,应该是一个二维数据的队列(10*2矩阵)
loader = Data.DataLoader(
dataset=torch_dataset, # 加载数据集
batch_size=BATCH_SIZE, # 批次大小
shuffle=True, # 是否打乱顺序训练
num_workers=2 # 设置线程数
)
def show_batch():
for epoch in range(3): # 进行三轮训练
for step, (batch_x, batch_y) in enumerate(loader): #