Data Loading in PyTorch (the Dataset base class, and PyTorch's built-in datasets)

Data loading in PyTorch

The purpose of a data loader in a model

In the earlier linear regression model, the dataset was small, so we simply fed all of the data into the model at once.

In deep learning, however, the amount of data is usually very large. With that much data, it is impossible to run the forward computation and backpropagation over the whole dataset at once. Instead, we usually shuffle the whole dataset randomly, split it into batches, and often preprocess the data as well.

So next we will learn PyTorch's data loading utilities.

The dataset class

Introduction to the Dataset base class

PyTorch provides the dataset base class torch.utils.data.Dataset. By subclassing it, we can implement data loading very quickly.

The source code of torch.utils.data.Dataset is as follows:

class Dataset(object):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

    def __len__(self):
        raise NotImplementedError

From this we can see that a custom dataset class must inherit from Dataset and implement two methods (a minimal skeleton is sketched right after this list):

  1. The __len__ method, so that the global len() function can return the number of elements in the dataset
  2. The __getitem__ method, so that a sample can be fetched by index, e.g. dataset[i] returns the i-th sample
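
As a minimal sketch (the class name and the in-memory list here are made up purely for illustration), a map-style dataset only needs these two methods:

from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    def __init__(self):
        # toy in-memory data: the squares of 0..9
        self.samples = [i * i for i in range(10)]

    def __len__(self):
        # number of samples; used by len(dataset) and by samplers
        return len(self.samples)

    def __getitem__(self, index):
        # fetch one sample by integer index: dataset[index]
        return self.samples[index]

dataset = SquaresDataset()
print(len(dataset))  # 10
print(dataset[3])    # 9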

A data loading example

Let's look at an example of using Dataset to load data.

Data source: http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection

About the data: the SMS Spam Collection is a classic dataset for spam SMS detection, drawn entirely from real text messages. It contains 4,827 ham (legitimate) messages and 747 spam messages, all stored in a single text file. Each line records one complete message, and the tag at the start of the line, ham or spam, marks it as legitimate or spam.

Sample records:

ham    Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
ham    Ok lar... Joking wif u oni...
spam   Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
ham    U dun say so early hor... U c already then say...
ham    Nah I don't think he goes to usf, he lives around here though
spam   FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv
ham    Even my brother is not like to speak with me. They treat me like aids patent.
ham    As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune
spam   WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.
spam   Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free! Call The Mobile Update Co FREE on 08002986030
ham    I'm gonna be home soon and i don't want to talk about this stuff anymore tonight, k? I've cried enough today.
spam   SIX chances to win CASH! From 100 to 20,000 pounds txt> CSH11 and send to 87575. Cost 150p/day, 6days, 16+ TsandCs apply Reply HL 4 info
spam   URGENT! You have won a 1 week FREE membership in our £100,000 Prize Jackpot! Txt the word: CLAIM to No: 81010 T&C www.dbuk.net LCCLTD POBOX 4403LDNW1A7RW18
ham    I've been searching for the right words to thank you for this breather. I promise i wont take your help for granted and will fulfil my promise. You have been wonderful and a blessing at all times.
ham    I HAVE A DATE ON SUNDAY WITH WILL!!
spam   XXXMobileMovieClub: To use your credit, click the WAP link in the next txt message or click here>> http://wap. xxxmobilemovieclub.com?n=QJKGIGHJJGCBL
ham    Oh k...i'm watching here:)

The implementation is as follows:

from torch.utils.data import Dataset, DataLoader

data_path = r'data\SMSSpamCollection'  # location of the data file

class My_DataSet(Dataset):
    def __init__(self):
        # Read the whole file at construction time. On my machine the file
        # must be opened with encoding='UTF-8', otherwise decoding errors occur.
        self.lines = open(data_path, 'r', encoding='UTF-8').readlines()

    def __getitem__(self, index):
        # str.strip() removes leading/trailing characters (whitespace and
        # newlines by default); it never removes characters from the middle.
        cur_line = self.lines[index].strip()  # the raw line at this index
        label = cur_line[:4].strip()          # the leading tag: 'ham' or 'spam'
        context = cur_line[4:].strip()        # the actual message content
        return label, context

    def __len__(self):
        # number of samples = number of lines in the file
        return len(self.lines)

my_dataset = My_DataSet()
if __name__ == '__main__':
    print(my_dataset[0])
    print(len(my_dataset))

Output:

('ham', 'Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...')
5574

The data loader class

The approach above lets us read the data, but quite a lot is still missing:

  • Batching the data
  • Shuffling the data
  • Loading the data in parallel with multiple worker processes (multiprocessing)

PyTorch provides all of the above in torch.utils.data.DataLoader.

Example usage of DataLoader:

from torch.utils.data import Dataset, DataLoader

data_path = r'data\SMSSpamCollection'  # location of the data file

class My_DataSet(Dataset):
    def __init__(self):
        # Read the whole file; encoding='UTF-8' is required on my machine,
        # otherwise decoding errors occur.
        self.lines = open(data_path, 'r', encoding='UTF-8').readlines()

    def __getitem__(self, index):
        cur_line = self.lines[index].strip()  # the raw line at this index
        label = cur_line[:4].strip()          # the leading tag: 'ham' or 'spam'
        context = cur_line[4:].strip()        # the actual message content
        return label, context

    def __len__(self):
        # number of samples = number of lines in the file
        return len(self.lines)

my_dataset = My_DataSet()
my_dataloader = DataLoader(my_dataset, batch_size=2, shuffle=True, num_workers=2)
if __name__ == '__main__':
    #print(my_dataset[0])
    #print(len(my_dataset))
    for i, (label, context) in enumerate(my_dataloader):
        print(i, label, context)
        break

Output:

0 ('ham', 'ham') ("Alright. I'm out--have a good night!", 'Oh ok wait 4 me there... My lect havent finish')

Running it again:

0 ('ham', 'ham') ("Yeah, that's fine! It's £6 to get in, is that ok?", "Pathaya enketa maraikara pa'")

We can see the data really has been shuffled: the two runs print different samples at position 0.

my_dataloader = DataLoader(my_dataset, batch_size=2, shuffle=True, num_workers=2)

The parameters mean:

  1. dataset: the Dataset instance defined beforehand
  2. batch_size: the number of samples per batch; common values are 128, 256, etc.
  3. shuffle: bool, whether to shuffle the data before each pass over it
  4. num_workers: the number of worker processes used to load the data

Note:

  1. len(dataset) = the number of samples in the dataset
  2. len(dataloader) = math.ceil(number of samples / batch_size), i.e. rounded up (verified in the sketch below)
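
A quick sketch to check this relationship on the SMS dataset above (reusing my_dataset; note that passing drop_last=True to DataLoader would drop the final incomplete batch, rounding down instead):

import math
from torch.utils.data import DataLoader

loader = DataLoader(my_dataset, batch_size=100, shuffle=True)
print(len(my_dataset))                   # 5574 samples
print(len(loader))                       # 56 batches
print(math.ceil(len(my_dataset) / 100))  # 56, matches len(loader)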

PyTorch's built-in datasets

PyTorch's built-in datasets are provided through two higher-level APIs: torchvision and torchtext.

Specifically:

  1. torchvision provides APIs and datasets for image data

    • Location: torchvision.datasets, e.g. torchvision.datasets.MNIST (handwritten digit images)
  2. torchtext provides APIs and datasets for text data

    • Location: torchtext.datasets, e.g. torchtext.datasets.IMDB (movie review text)

Below we take the MNIST handwritten digits as an example to see how PyTorch loads one of its built-in datasets.

The usage pattern is the same as before:

  1. Prepare a dataset instance
  2. Hand the dataset to a DataLoader, which shuffles it and groups it into batches

torchvision.datasets

The dataset classes in torchvision.datasets (such as torchvision.datasets.MNIST) all inherit from Dataset, so instantiating torchvision.datasets.MNIST directly gives us a Dataset instance.

The parameters of the MNIST API deserve some attention:

torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=None)
  1. root: the directory where the files are stored
  2. train: bool, whether to use the training set or the test set
  3. download: bool, whether to download the data into the root directory
  4. transform: a function applied to each image when it is fetched (see the sketch below)
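
As a sketch of the transform parameter: torchvision ships ready-made transforms in torchvision.transforms, and passing transforms.ToTensor() converts each PIL image into a FloatTensor with values scaled to [0, 1]:

import torchvision
from torchvision import transforms

# transform is applied to every image at fetch time
dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True,
                                     transform=transforms.ToTensor())
img, target = dataset[0]
print(img.shape)  # torch.Size([1, 28, 28]) -- channel, height, width
print(target)     # 5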

About the MNIST dataset

Original home of the dataset: http://yann.lecun.com/exdb/mnist

MNIST is a free image recognition dataset provided by Yann LeCun et al. It contains 60,000 training samples and 10,000 test samples. The images have already been size-normalized: all are grayscale images of size 28×28.

Run the code below to download the data and inspect its type:

import torchvision
dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=None)
print(dataset[0])  # print the first element of the dataset
print(dataset)     # print summary information about the dataset

Output:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz
Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw
Processing...
Done!
(<PIL.Image.Image image mode=L size=28x28 at 0x211C235D208>, 5)
Dataset MNIST
Number of datapoints: 60000
Root location: ./data
Split: Train

As we can see, the program downloads the data from the corresponding site into the folder we specified.


Let's look a little deeper:

print(dataset[1])     # the second element of the dataset
dataset[1][0].show()  # open the PIL image with the default image viewer

The output here is:

(<PIL.Image.Image image mode=L size=28x28 at 0x23E0D88C748>, 0)

and the show() call displays the corresponding image of the digit 0.

Explanation: dataset[1], the second element of the dataset, is a tuple with two parts. <PIL.Image.Image image mode=L size=28x28 at 0x23E0D88C748> says the first part is a PIL image and reports its mode, size, and memory address; the second part, 0, is the label, i.e. the digit shown in the image, which the show() call above confirms.
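
To round off the two-step pattern from earlier (prepare the dataset, then hand it to a DataLoader), here is a sketch that wraps the transformed MNIST dataset in a DataLoader and inspects one batch:

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True,
                                     transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

for images, targets in dataloader:
    print(images.shape)   # torch.Size([128, 1, 28, 28]) -- a batch of images
    print(targets.shape)  # torch.Size([128]) -- a batch of labels
    break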

For a more systematic overview of PyTorch, see: PyTorch study notes

If you have questions, leave a comment below. If you repost, please credit the source and link to the original. Thanks! If anything here infringes, please contact me promptly.
