回归中使用的数据是torch的32位浮点型;分类问题中的预测标签是64位有符号整型数据。
一、数据类型的转换
train_x = torch.from_numpy(data_x.astype(np.float32))
# 训练数据浮点型32位
train_y = torch.from_numpy(data_y.astype(np.int64))
# 标签数据整型64位
二、整理数据
train_data = torch.utils.data.TensorDataset(train_x,train_y)
# 将x,y整理到一起
train_loader = torch.utils.data.DataLoader(
dataset = train_data, # 数据集
batch_size = 10, # 批处理样本大小
shuffle = True, # 迭代前打乱数据
num_workers = 1, # 进程使用数
)
以上是对于数组数据的部分预处理,针对图片数据和文本数据也有相应的处理方法。