参考:https://blog.csdn.net/zw__chen/article/details/82806900
一、DataLoader
1、DataLoader 是 torch 给你用来包装你的数据的工具. 所以你要将自己的 (numpy array 或其他) 数据形式装换成 Tensor, 然后再放进这个包装器中.
2、Dataset是一个抽象类,不能实例化,要先继承。DataLoader可以直接实例化
3、Dataset有内置数据集:
这些内置数据集有对应的__getitem__()和__len__()方法
二、内置数据集的使用
比如:MNIST数据集
# 可以选择是否将数据集下载到本地。要把数据转化成tensor,DataLoader才能处理
train_data = torchvision.datasets.MNIST(
root = './mnist', train = True,
transform = torchvision.transforms.ToTensor(),
download = True
)
test_data = torchvision.datasets.MNIST(
root = './mnist', train = False,
transform = torchvision.transforms.ToTensor(),
download = True
)
train_loader = Data.DataLoader(dataset=train_data, batch_size=50, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=50, shuffle=False)
三、DataLoader加载非内置数据集
如果要使用其他数据集,需要继承Dataset类,并对以上两个方法进行重写
本文使用Titanic数据集:https://www.kaggle.com/c/titanic/data
把数据集下载到当前代码的同级目录下【读取csv数据集时,可以用np.loadtxt或pd.read_csv,loadtxt要指定分割符,且引号中的逗号也可能被识别成分隔符,所以这里我用的是read_csv】
1、基本包的导入
import numpy as np
import pandas