前言:
今天我继续学习Pytorch框架的部分功能,分析了 数据加载器 train_dataloader 的数据结构。下面与大家一起分享分析过程,期望能给您带来一些帮助。
思路:
根据Pytorch官网上的tutorials部分介绍,首先下载训练和测试数据集并安装好,然后是编程分析数据加载器的数据结构。Pytorch框架共有两个数据加载器,这里分享的是第一个加载器 train_dataloader 的数据结构,第二个加载器 test_dataloader 的数据结构与前者的是一样的,不再赘述。
代码:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
from torchvision.transforms import ToTensor
from torchvision.utils import save_image
#下面的url讲了如何下载安装训练和测试数据集
#https://pytorch.org/vision/stable/generated/torchvision.datasets.FashionMNIST.html
# root (string) – Root directory of dataset
# where FashionMNIST/raw/train-images-idx3-ubyte
# and FashionMNIST/raw/t10k-images-idx3-ubyte exist.
training_data = datasets.FashionMNIST(
root="./data",
train=True,
download=False, #True,
transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
root="./data",
train=False,
download=False, #True,
transform=ToTensor(),
)
#一批一批地处理。一批图片的数量
batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
print('\n\n\n(一)分析 train_dataloader 的数据结构')
print('###########################################################')
print('type(train_dataloader) = ',type(train_dataloader))
print('len(train_dataloader) = ',len(train_dataloader))
print(f'''分析:
train_dataloader 是一个可迭代对象,转换成列表后发现:
每个元素是 1 batch(批)数据,共包含有 {len(train_dataloader)} batch(批)数据。''')
b=2
print(f'\n下面是train_dataloader的前 {b} 个元素:')
print(f'list(train_dataloader)[:{b}] = ',list(train_dataloader)[:b])
print(f'\n下面分析前 {b} batch 数据:')
print('==============================================================\n')
#分批分析
i=0
for batch in train_dataloader:
print(f'\t 下面是 batch[{i}] 的数据:\n\t '+'-'*40)
print(f'\t batch[{i}] = ',batch)
print(f'\t type(batch[{i}]) = ',type(batch))
print(f'\t len(batch[{i}]) = ',len(batch))
print(f'\t 分析:batch[{i}] 是一个列表,包含两个tensor。')
print('\t '+'-'*40+'\n')
j=0
for tsr in batch:
print(f'\t\t 下面是 batch[{i}]的tensor[{j}] 中的数据:\n\t\t '+'-'*40)
print(f'\t\t tensor[{j}] = ',tsr)
print(f'\t\t type(tensor[{j}]) = ',type(tsr))
print(f'\t\t len(tensor[{j}]) = ',len(tsr))
if j==0:
print(f'\t\t 分析:batch[{i}]的tensor[{j}]包含64张图片的颜色值')
print('\t\t '+'-'*40+'\n')
print('\t\t\t 下面是前3张图片的数据:\n\t\t\t '+'-'*25)
k=0
for img in tsr:
print(f'\t\t\t 第{k}张图片 = ',img)
k+=1
if k==3:
break
print('\t\t\t '+'-'*25+'\n')
else:
print(f'\t\t 分析:batch[{i}]的tensor[{j}]包含64个label(类号)')
print('\t\t '+'-'*40+'\n')
j+=1
#break
i+=1
if i==b:
break
运行结果:
运行结果显示的就是 train_dataloader 的数据结构
(一)分析 train_dataloader 的数据结构
###########################################################
type(train_dataloader) = <class 'torch.utils.data.dataloader.DataLoader'>
len(train_dataloader) = 938
分析:
train_dataloader 是一个可迭代对象,转换成列表后发现:
每个元素是 1 batch(批)数据,共包含有 938 batch(批)数据。
下面是train_dataloader的前 2 个元素:
list(train_dataloader)[:2] = <数据很大省略了>
下面分析前 2 batch 数据:
==============================================================
下面是 batch[0] 的数据:
----------------------------------------
batch[0] = <数据很大省略了>
type(batch[0]) = <class 'list'>
len(batch[0]) = 2
分析:batch[0] 是一个列表,包含两个tensor。
----------------------------------------
下面是 batch[0]的tensor[0] 中的数据:
----------------------------------------
tensor[0] = <数据很大省略了>
type(tensor[0]) = <class 'torch.Tensor'>
len(tensor[0]) = 64
分析:batch[0]的tensor[0]包含64张图片的颜色值
----------------------------------------
下面是前3张图片的数据:
-------------------------
第0张图片 = <数据很大省略了>
第1张图片 = <数据很大省略了>
第2张图片 = <数据很大省略了>
-------------------------
下面是 batch[0]的tensor[1] 中的数据:
----------------------------------------
tensor[1] = tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5, 0, 9, 5, 5, 7, 9, 1, 0, 6, 4, 3, 1, 4, 8,
4, 3, 0, 2, 4, 4, 5, 3, 6, 6, 0, 8, 5, 2, 1, 6, 6, 7, 9, 5, 9, 2, 7, 3,
0, 3, 3, 3, 7, 2, 2, 6, 6, 8, 3, 3, 5, 0, 5, 5])
type(tensor[1]) = <class 'torch.Tensor'>
len(tensor[1]) = 64
分析:batch[0]的tensor[1]包含64个label(类号)
----------------------------------------
下面是 batch[1] 的数据:
----------------------------------------
batch[1] = <数据很大省略了>
type(batch[1]) = <class 'list'>
len(batch[1]) = 2
分析:batch[1] 是一个列表,包含两个tensor。
----------------------------------------
下面是 batch[1]的tensor[0] 中的数据:
----------------------------------------
tensor[0] = <数据很大省略了>
type(tensor[0]) = <class 'torch.Tensor'>
len(tensor[0]) = 64
分析:batch[1]的tensor[0]包含64张图片的颜色值
----------------------------------------
下面是前3张图片的数据:
-------------------------
第0张图片 = <数据很大省略了>
第1张图片 = <数据很大省略了>
第2张图片 = <数据很大省略了>
-------------------------
下面是 batch[1]的tensor[1] 中的数据:
----------------------------------------
tensor[1] = tensor([0, 2, 0, 0, 4, 1, 3, 1, 6, 3, 1, 4, 4, 6, 1, 9, 1, 3, 5, 7, 9, 7, 1, 7,
9, 9, 9, 3, 2, 9, 3, 6, 4, 1, 1, 8, 8, 0, 1, 1, 6, 8, 1, 9, 7, 8, 8, 9,
6, 6, 3, 1, 5, 4, 6, 7, 5, 5, 9, 2, 2, 2, 7, 6])
type(tensor[1]) = <class 'torch.Tensor'>
len(tensor[1]) = 64
分析:batch[1]的tensor[1]包含64个label(类号)
----------------------------------------
以上是运行代码得到的数据结构,其中“<数据很大省略了>”部分,在线下运行代码后,是可以再现的。测得的数据都是来自官方数据集的真实数据。
例图:
结论:
通过加载了真实数据集并测试分析,数据加载器 train_dataloader 的数据结构如下:
1.train_dataloader 共包含 938 批(batch)数据,是一个可迭代结构,可转换成列表。
2.每批数据是一个列表,包含两个tensor。其中第一个tensor包含64张图片的颜色值;第二个tensor包含64个label,就是64张图片的对应类号(整个数据集包含10大类图片,类号为0-9)。
3.一张图片是一个3维列表,其中的一个2维部分是一个通道,包含有28*28个点的颜色值。正常的图片有三个通道的数据,但是测试图片是黑白图片,只使用了一个通道。
参考:
https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html