Python深度学习:数据加载
1、Dataset基类torch.utils.data.Dataset
下面通过一个案例说明如何使用 Dataset 加载数据:自定义数据集需继承 Dataset,并实现 `__getitem__`(按索引返回一条样本)和 `__len__`(返回样本总数)两个方法。
数据来源:https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
数据介绍:用于骚扰短信识别的经典数据集。每行一条短信,行首用 ham 标识正常短信、用 spam 标识骚扰短信,标签与短信内容之间以制表符分隔。
from torch.utils.data import Dataset

# Default location of the SMS Spam Collection file.
data_path = "./smsspamcollection/SMSSpamCollection"


class MyDataset(Dataset):
    """Map-style dataset over the SMS Spam Collection text file.

    Each line of the file starts with a label (``ham`` or ``spam``) followed
    by whitespace and the message text. ``dataset[i]`` returns the pair
    ``(label, content)`` for line ``i``; ``len(dataset)`` is the line count.
    """

    def __init__(self, path=None):
        """Read all lines of the data file into memory.

        Args:
            path: File to read. Defaults to the module-level ``data_path``
                so existing ``MyDataset()`` calls keep working.
        """
        if path is None:
            path = data_path
        # Context manager closes the handle as soon as the lines are read
        # (the previous bare open() leaked it until garbage collection).
        with open(path, encoding="utf8") as f:
            self.lines = f.readlines()

    def __getitem__(self, item):
        """Return ``(label, content)`` for line ``item``.

        The label is the first whitespace-delimited token, so labels of any
        length are supported (not just 4-character ones). A blank line
        yields ``("", "")``.

        Raises:
            IndexError: if ``item`` is out of range.
        """
        cur_line = self.lines[item].strip()
        parts = cur_line.split(maxsplit=1)
        label = parts[0] if parts else ""
        content = parts[1] if len(parts) > 1 else ""
        return label, content

    def __len__(self):
        """Return the number of lines (samples) in the file."""
        return len(self.lines)
if __name__ == '__main__':
    # Print every (label, content) sample numbered from 1, then the total count.
    myDataset = MyDataset()
    total = len(myDataset)
    for idx in range(total):
        sample = myDataset[idx]
        print(f"No.{idx + 1}: {sample}")
    print(f"total:{total}")
2、使用 DataLoader 迭代数据集
`DataLoader(dataset=myDataset, batch_size=2, shuffle=True)`:dataset 为要加载的数据集,batch_size 为每个批次包含的样本数,shuffle=True 表示每个 epoch 开始时都会打乱数据顺序。
from torch.utils.data import DataLoader

# Wrap the dataset in a DataLoader that yields batches of 2 samples and
# reshuffles the sample order on every pass over the data.
myDataset = MyDataset()
data_loader = DataLoader(myDataset, batch_size=2, shuffle=True)

if __name__ == '__main__':
    # Each iteration yields one collated batch (the last may be smaller).
    for batch in data_loader:
        print(batch)
3、PyTorch 自带的数据集
- torchvision:图像,在torchvision.datasets
- torchtext:文本,在torchtext.datasets
下面以 torchvision.datasets 中的 MNIST 手写数字数据集为例:
import torchvision

# Load the MNIST training split into ./data, downloading it if absent.
dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True)

# A sample is an (image, label) pair; without a transform the image is a PIL image.
first_sample = dataset[0]
print(first_sample)  # (<PIL.Image.Image image mode=L size=28x28 at 0x1A106A67438>, 5)
img = first_sample[0]
img.show()