利用pytorch框架训练,有一个重要步骤就是批量训练
要用到torch.nn.data模块
import torch.nn.data as Data
创建一个TensorDataset对象存放训练数据x,y
mydataset=Data.TensorDataset(x,y)
创建一个DataLoader对象加载数据,设置dataset,batch_size,shuffle,num_work等参数
BATCH_SIZE=5 #设置批量训练的数量,超参数用大写字母表示
data_loader=Data.Dataloder(dataset=mydataset, #数据集
batch_size=BATCH_SIZE, #批量训练数量
shuffle=True,#(打乱数据顺序)
num_workers=2 #(2个线程)
数据集和数据加载器都写好后,可以开始训练了
for