PyTorch官方Tutorials Datasets & DataLoders

PyTorch官方Tutorials

跟着PyTorch官方Tutorials码的,便于理解自己稍有改动代码并添加注释,IDE用的jupyter notebook

链接: tutorials-Datasets & DataLoders

DATASETS & DATALOADERS

Code for processing data samples can get messy and hard to maintain; we ideally want our dataset code to be decoupled from our model training code for better readability and modularity. PyTorch provides two data primitives:

torch.utils.data.DataLoader
torch.utils.data.Dataset

that allow you to use pre-loaded datasets as well as your own data.

  • Dataset stores the samples and their corresponding labels
  • DataLoader wraps an iterable around the Dataset to enable easy access to the samples.

PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that subclass torch.utils.data.Dataset and implement functions specific to the particular data. They can be used to prototype and benchmark your model. You can find them here: Image Datasets, Text Datasets, and Audio Datasets

Loading a Dataset

Here is an example of how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a dataset of Zalando’s article images consisting of of 60,000 training examples and 10,000 test examples. Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.

We load the FashionMNIST Dataset with the following parameters:
  • root is the path where the train/test data is stored,
  • train specifies training or test dataset,
  • download=True downloads the data from the internet if it’s not available at root.
  • transform and target_transform specify the feature and label transformations
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
#从FashionMNIST加载训练数据集
training_data=datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)
#training中数据为 三维tensor,int
#即 img label,其中img 为二维tensor
training_data[0]
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,
           0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0039, 0.0039, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333,
           0.4980, 0.2431, 0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118,
           0.0157, 0.0000, 0.0000, 0.0118],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000,
           0.6902, 0.5255, 0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0471, 0.0392, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255,
           0.8118, 0.6980, 0.4196, 0.6118, 0.6314, 0.4275, 0.2510, 0.0902,
           0.3020, 0.5098, 0.2824, 0.0588],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745,
           0.8549, 0.8471, 0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725,
           0.5529, 0.3451, 0.6745, 0.2588],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098,
           0.9137, 0.8980, 0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980,
           0.4824, 0.7686, 0.8980, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824, 0.8471,
           0.8745, 0.8941, 0.9216, 0.8902, 0.8784, 0.8706, 0.8784, 0.8667,
           0.8745, 0.9608, 0.6784, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.8941, 0.8549,
           0.8353, 0.7765, 0.7059, 0.8314, 0.8235, 0.8275, 0.8353, 0.8745,
           0.8627, 0.9529, 0.7922, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0039, 0.0118, 0.0000, 0.0471, 0.8588, 0.8627, 0.8314,
           0.8549, 0.7529, 0.6627, 0.8902, 0.8157, 0.8549, 0.8784, 0.8314,
           0.8863, 0.7725, 0.8196, 0.2039],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0235, 0.0000, 0.3882, 0.9569, 0.8706, 0.8627,
           0.8549, 0.7961, 0.7765, 0.8667, 0.8431, 0.8353, 0.8706, 0.8627,
           0.9608, 0.4667, 0.6549, 0.2196],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0157, 0.0000, 0.0000, 0.2157, 0.9255, 0.8941, 0.9020,
           0.8941, 0.9412, 0.9098, 0.8353, 0.8549, 0.8745, 0.9176, 0.8510,
           0.8510, 0.8196, 0.3608, 0.0000],
          [0.0000, 0.0000, 0.0039, 0.0157, 0.0235, 0.0275, 0.0078, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.8863, 0.8510, 0.8745,
           0.8706, 0.8588, 0.8706, 0.8667, 0.8471, 0.8745, 0.8980, 0.8431,
           0.8549, 1.0000, 0.3020, 0.0000],
          [0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.2431, 0.5686, 0.8000, 0.8941, 0.8118, 0.8353, 0.8667,
           0.8549, 0.8157, 0.8275, 0.8549, 0.8784, 0.8745, 0.8588, 0.8431,
           0.8784, 0.9569, 0.6235, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.1725, 0.3216, 0.4196,
           0.7412, 0.8941, 0.8627, 0.8706, 0.8510, 0.8863, 0.7843, 0.8039,
           0.8275, 0.9020, 0.8784, 0.9176, 0.6902, 0.7373, 0.9804, 0.9725,
           0.9137, 0.9333, 0.8431, 0.0000],
          [0.0000, 0.2235, 0.7333, 0.8157, 0.8784, 0.8667, 0.8784, 0.8157,
           0.8000, 0.8392, 0.8157, 0.8196, 0.7843, 0.6235, 0.9608, 0.7569,
           0.8078, 0.8745, 1.0000, 1.0000, 0.8667, 0.9176, 0.8667, 0.8275,
           0.8627, 0.9098, 0.9647, 0.0000],
          [0.0118, 0.7922, 0.8941, 0.8784, 0.8667, 0.8275, 0.8275, 0.8392,
           0.8039, 0.8039, 0.8039, 0.8627, 0.9412, 0.3137, 0.5882, 1.0000,
           0.8980, 0.8667, 0.7373, 0.6039, 0.7490, 0.8235, 0.8000, 0.8196,
           0.8706, 0.8941, 0.8824, 0.0000],
          [0.3843, 0.9137, 0.7765, 0.8235, 0.8706, 0.8980, 0.8980, 0.9176,
           0.9765, 0.8627, 0.7608, 0.8431, 0.8510, 0.9451, 0.2549, 0.2863,
           0.4157, 0.4588, 0.6588, 0.8588, 0.8667, 0.8431, 0.8510, 0.8745,
           0.8745, 0.8784, 0.8980, 0.1137],
          [0.2941, 0.8000, 0.8314, 0.8000, 0.7569, 0.8039, 0.8275, 0.8824,
           0.8471, 0.7255, 0.7725, 0.8078, 0.7765, 0.8353, 0.9412, 0.7647,
           0.8902, 0.9608, 0.9373, 0.8745, 0.8549, 0.8314, 0.8196, 0.8706,
           0.8627, 0.8667, 0.9020, 0.2627],
          [0.1882, 0.7961, 0.7176, 0.7608, 0.8353, 0.7725, 0.7255, 0.7451,
           0.7608, 0.7529, 0.7922, 0.8392, 0.8588, 0.8667, 0.8627, 0.9255,
           0.8824, 0.8471, 0.7804, 0.8078, 0.7294, 0.7098, 0.6941, 0.6745,
           0.7098, 0.8039, 0.8078, 0.4510],
          [0.0000, 0.4784, 0.8588, 0.7569, 0.7020, 0.6706, 0.7176, 0.7686,
           0.8000, 0.8235, 0.8353, 0.8118, 0.8275, 0.8235, 0.7843, 0.7686,
           0.7608, 0.7490, 0.7647, 0.7490, 0.7765, 0.7529, 0.6902, 0.6118,
           0.6549, 0.6941, 0.8235, 0.3608],
          [0.0000, 0.0000, 0.2902, 0.7412, 0.8314, 0.7490, 0.6863, 0.6745,
           0.6863, 0.7098, 0.7255, 0.7373, 0.7412, 0.7373, 0.7569, 0.7765,
           0.8000, 0.8196, 0.8235, 0.8235, 0.8275, 0.7373, 0.7373, 0.7608,
           0.7529, 0.8471, 0.6667, 0.0000],
          [0.0078, 0.0000, 0.0000, 0.0000, 0.2588, 0.7843, 0.8706, 0.9294,
           0.9373, 0.9490, 0.9647, 0.9529, 0.9569, 0.8667, 0.8627, 0.7569,
           0.7490, 0.7020, 0.7137, 0.7137, 0.7098, 0.6902, 0.6510, 0.6588,
           0.3882, 0.2275, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1569,
           0.2392, 0.1725, 0.2824, 0.1608, 0.1373, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000]]]),
 9)
#从FashionMNIST加载测试数据集
test_data=datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)

Iterating and Visualizing the Dataset

We can index Datasets manually like a list: training_data[index]. We use matplotlib to visualize some samples in our training data.

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

#8 inches * 8 inches
figure = plt.figure(figsize=(8, 8))
cols,rows=3,3

for i in range(1,cols*rows+1):
    # len(training_data)为随机数上界,size(tuple),单元素元组(1,),randint返回的是tensor,用item()转为数字
    sample_idx=torch.randint(0,len(training_data),size=(1,)).item()
    
    print(labels_map[training_data[sample_idx][1]])
    
    img,label=training_data[sample_idx]
    #rows行,cols列 第i个位置
    figure.add_subplot(rows,cols,i)
    plt.title(labels_map[label])
    plt.axis('off')
    
    #numpy.squeeze(a,axis = None)
    #https://blog.csdn.net/qq_38675570/article/details/80048650
    # 1)a表示输入的数组;
     #2)axis用于指定需要删除的维度,但是指定的维度必须为单维度,否则将会报错;
     #3)axis的取值可为None 或 int 或 tuple of ints, 可选。若axis为空,则删除所有单维度的条目;
     #4)返回值:数组
     #5) 不会修改原数组;
    
    #从数组的形状中删除单维度条目,即把shape中为1的维度去掉
    
    plt.imshow(img.squeeze())
plt.show()

img.type()
#'torch.FloatTensor'
Sneaker
Trouser
Pullover
Trouser
Shirt
Bag
Ankle Boot
Dress
Bag

在这里插入图片描述

'torch.FloatTensor'

Creating a Custom Dataset for your files

A custom Dataset class must implement three functions:

  • _init_
  • _len_
  • _getitem_

Take a look at this implementation; the FashionMNIST images are stored in a directory img_dir, and their labels are stored separately in a CSV file annotations_file.

In the next sections, we’ll break down what’s happening in each of these functions.

import os
import pandas as pd
from torchvision.io import read_image

#继承 torch.utils.data.Daraset
class CustomImageDataset(Dataset):
    #初始化对象
    def __init__(self , annotations_file , img_dir , transForm=None , target_transform=None):
        self.img_labels=pd.read_csv(annotations_file)
        self.img_dir=img_dir
        self.transform=transform
        self.target_transform=target_transform
        
    #返回数据集中数据个数
    def __len__(self):
        return len(self.img_labels)
    
    #返回数据集中的一项,包括:图片,标签 为sample字典
    def __getitem__(self,idx):
        #iloc:根据标签的所在位置,从0开始计数,先选取行再选取列
        #loc:根据DataFrame的具体标签选取行列,同样是先行标签,后列标签
        
        img_path=os.path.join(self.img_dir , self.img_labels.iloc[idx,0])
        image=read_image(img_path)
        
        #label从img_labels的第idx行,第1列取出
        label=self.img_labels.iloc[idx,1]
        
        #transform见tutorials下一节
        if self.transform:
            image=self.transform(image)
        if self.target_transform:
            label=self.target_transform(label)
        sample={"image:" , image , "label:" , label}
        return sample

_init_

The _init_ function is run once when instantiating the Dataset object. We initialize the directory containing the images, the annotations file, and both transforms (covered in more detail in the next section).

The labels.csv file looks like:

tshirt1.jpg, 0
tshirt2.jpg, 0

ankleboot999.jpg, 9

def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

_len_

The _len_ function returns the number of samples in our dataset.

Example:

def __len__(self):
    return len(self.img_labels)

_getitem_

The _getitem_ function loads and returns a sample from the dataset at the given index idx. Based on the index, it identifies the image’s location on disk, converts that to a tensor using read_image, retrieves the corresponding label from the csv data in self.img_labels, calls the transform functions on them (if applicable), and returns the tensor image and corresponding label in a Python dict.

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    sample = {"image": image, "label": label}
    return sample

Preparing your data for training with DataLoaders

The Datasetretrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessingto speed up data retrieval.

DataLoaderis an iterable that abstracts this complexity for us in an easy API.

from torch.utils.data import DataLoader

train_dataloader=DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader=DataLoader(test_data,batch_size=64,shuffle=True)

Iterate through the DataLoader

We have loaded that dataset into the Dataloaderand can iterate through the dataset as needed. Each iteration below returns a batch of train_featuresand train_labels(containing batch_size=64features and labels respectively). Because we specified shuffle=True, after we iterate over all batches the data is shuffled (for finer-grained control over the data loading order, take a look at Samplers.

#Display image and label
train_features,train_labels=next(iter(train_dataloader))
print(f'Feature batch shape: {train_features.size()}')
print(f'Labels batch shape:{train_labels.size()}')

img=train_features[0].squeeze()
label=train_labels[0]
plt.imshow(img)
plt.show()
print(f'Label: {label} {labels_map[label.item()]}')
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape:torch.Size([64])

在这里插入图片描述

Label: 4 Coat

torch.utils.data.dataloader._SingleProcessDataLoaderIter

总结:

dataset可以使用torch自带的数据集,比如FashionMINIST
__也可以自己重写Dataset类,重写时要实现三个函数: __

def init(self , annotations_file , img_dir , transForm=None , target_transform=None):

def len(self):

def getitem(self,idx):

Dataset定义完后要装入Dataloader,装入后就可以用迭代器next(iter(Dataloader))对数据集进行大小为batch_size的迭代

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
【优质项目推荐】 1、项目代码均经过严格本地测试,运行OK,确保功能稳定后才上传平台。可放心下载并立即投入使用,若遇到任何使用问题,随时欢迎私信反馈与沟通,博主会第一时间回复。 2、项目适用于计算机相关专业(如计科、信息安全、数据科学、人工智能、通信、物联网、自动化、电子信息等)的在校学生、专业教师,或企业员工,小白入门等都适用。 3、该项目不仅具有很高的学习借鉴价值,对于初学者来说,也是入门进阶的绝佳选择;当然也可以直接用于 毕设、课设、期末大作业或项目初期立项演示等。 3、开放创新:如果您有一定基础,且热爱探索钻研,可以在此代码基础上二次开发,进行修改、扩展,创造出属于自己的独特应用。 欢迎下载使用优质资源!欢迎借鉴使用,并欢迎学习交流,共同探索编程的无穷魅力! 基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip 基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip 基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip
提供的源码资源涵盖了安卓应用、小程序、Python应用和Java应用等多个领域,每个领域都包含了丰富的实例和项目。这些源码都是基于各自平台的最新技术和标准编写,确保了在对应环境下能够无缝运行。同时,源码中配备了详细的注释和文档,帮助用户快速理解代码结构和实现逻辑。 适用人群: 这些源码资源特别适合大学生群体。无论你是计算机相关专业的学生,还是对其他领域编程感兴趣的学生,这些资源都能为你提供宝贵的学习和实践机会。通过学习和运行这些源码,你可以掌握各平台开发的基础知识,提升编程能力和项目实战经验。 使用场景及目标: 在学习阶段,你可以利用这些源码资源进行课程实践、课外项目或毕业设计。通过分析和运行源码,你将深入了解各平台开发的技术细节和最佳实践,逐步培养起自己的项目开发和问题解决能力。此外,在求职或创业过程中,具备跨平台开发能力的大学生将更具竞争力。 其他说明: 为了确保源码资源的可运行性和易用性,特别注意了以下几点:首先,每份源码都提供了详细的运行环境和依赖说明,确保用户能够轻松搭建起开发环境;其次,源码中的注释和文档都非常完善,方便用户快速上手和理解代码;最后,我会定期更新这些源码资源,以适应各平台技术的最新发展和市场需求。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值