1. pytorch 对于数据的标签要求是长整形,因此要对标签进行转换
train_label = train_label.long()
2.对于数据的特征部分,要转换为tensor 形式,可以通过torch.tensor 将数据从numpy 转为tensor
train_fea = torch.tensor(train_fea, dtype=torch.float32)
3.封装数据
data = torch.utils.data.TensorDataset(train_fea, train_label)#(特征,标签)
4.封装成第三步的形式,就可以采用torch 中的数据加载器为模型提供数据,数据加载器可以自动分批喂给模型数据
train_loader = torch.utils.data.DataLoader(data,batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)