实验:分析 Pytorch 框架中 train_dataloader 的数据结构

前言:

今天我继续学习Pytorch框架的部分功能,分析了 数据加载器 train_dataloader 的数据结构。下面与大家一起分享分析过程,期望能给您带来一些帮助。

思路:

根据Pytorch官网上的tutorials部分介绍,首先下载训练和测试数据集并安装好,然后是编程分析数据加载器的数据结构。Pytorch框架共有两个数据加载器,这里分享的是第一个加载器 train_dataloader 的数据结构,第二个加载器 test_dataloader 的数据结构与前者的是一样的,不再赘述。

代码:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
from torchvision.transforms import ToTensor

from torchvision.utils import save_image

#下面的url讲了如何下载安装训练和测试数据集
#https://pytorch.org/vision/stable/generated/torchvision.datasets.FashionMNIST.html
# root (string) – Root directory of dataset
# where FashionMNIST/raw/train-images-idx3-ubyte
# and FashionMNIST/raw/t10k-images-idx3-ubyte exist.

training_data = datasets.FashionMNIST(
    root="./data",
    train=True,
    download=False, #True,
    transform=ToTensor(),
)

test_data = datasets.FashionMNIST(
    root="./data",
    train=False,
    download=False, #True,
    transform=ToTensor(),
)

#一批一批地处理。一批图片的数量
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)


print('\n\n\n(一)分析 train_dataloader 的数据结构')
print('###########################################################')

print('type(train_dataloader) = ',type(train_dataloader))
print('len(train_dataloader) = ',len(train_dataloader))
print(f'''分析:
train_dataloader 是一个可迭代对象,转换成列表后发现:
每个元素是 1 batch(批)数据,共包含有 {len(train_dataloader)} batch(批)数据。''')

b=2
print(f'\n下面是train_dataloader的前 {b} 个元素:')
print(f'list(train_dataloader)[:{b}] = ',list(train_dataloader)[:b])
print(f'\n下面分析前 {b} batch 数据:')
print('==============================================================\n')


#分批分析
i=0
for batch in train_dataloader:
    print(f'\t 下面是 batch[{i}] 的数据:\n\t '+'-'*40)
    print(f'\t batch[{i}] = ',batch)
    print(f'\t type(batch[{i}]) = ',type(batch))
    print(f'\t len(batch[{i}]) = ',len(batch))
    print(f'\t 分析:batch[{i}] 是一个列表,包含两个tensor。')
    print('\t '+'-'*40+'\n')
    
    j=0
    for tsr in batch:
        print(f'\t\t 下面是 batch[{i}]的tensor[{j}] 中的数据:\n\t\t '+'-'*40)
        print(f'\t\t tensor[{j}] = ',tsr)
        print(f'\t\t type(tensor[{j}]) = ',type(tsr))
        print(f'\t\t len(tensor[{j}]) = ',len(tsr))

        if j==0:
            print(f'\t\t 分析:batch[{i}]的tensor[{j}]包含64张图片的颜色值')
            print('\t\t '+'-'*40+'\n')

            print('\t\t\t 下面是前3张图片的数据:\n\t\t\t '+'-'*25)
            k=0
            for img in tsr:
                print(f'\t\t\t 第{k}张图片 = ',img)
                k+=1
                if k==3:
                    break
            print('\t\t\t '+'-'*25+'\n')
            
        else:
            print(f'\t\t 分析:batch[{i}]的tensor[{j}]包含64个label(类号)')
            print('\t\t '+'-'*40+'\n')
            
        j+=1
        #break
        
    i+=1
    if i==b:
        break

运行结果:

运行结果显示的就是 train_dataloader 的数据结构

(一)分析 train_dataloader 的数据结构
###########################################################
type(train_dataloader) =  <class 'torch.utils.data.dataloader.DataLoader'>
len(train_dataloader) =  938
分析:
train_dataloader 是一个可迭代对象,转换成列表后发现:
每个元素是 1 batch(批)数据,共包含有 938 batch(批)数据。

下面是train_dataloader的前 2 个元素:
list(train_dataloader)[:2] =  <数据很大省略了>

下面分析前 2 batch 数据:
==============================================================

	 下面是 batch[0] 的数据:
	 ----------------------------------------
	 batch[0] =  <数据很大省略了>
	 type(batch[0]) =  <class 'list'>
	 len(batch[0]) =  2
	 分析:batch[0] 是一个列表,包含两个tensor。
	 ----------------------------------------

		 下面是 batch[0]的tensor[0] 中的数据:
		 ----------------------------------------
		 tensor[0] =  <数据很大省略了>
		 type(tensor[0]) =  <class 'torch.Tensor'>
		 len(tensor[0]) =  64
		 分析:batch[0]的tensor[0]包含64张图片的颜色值
		 ----------------------------------------

			 下面是前3张图片的数据:
			 -------------------------
			 第0张图片 =  <数据很大省略了>
			 第1张图片 =  <数据很大省略了>
			 第2张图片 =  <数据很大省略了>
			 -------------------------

		 下面是 batch[0]的tensor[1] 中的数据:
		 ----------------------------------------
		 tensor[1] =  tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5, 0, 9, 5, 5, 7, 9, 1, 0, 6, 4, 3, 1, 4, 8,
        4, 3, 0, 2, 4, 4, 5, 3, 6, 6, 0, 8, 5, 2, 1, 6, 6, 7, 9, 5, 9, 2, 7, 3,
        0, 3, 3, 3, 7, 2, 2, 6, 6, 8, 3, 3, 5, 0, 5, 5])
		 type(tensor[1]) =  <class 'torch.Tensor'>
		 len(tensor[1]) =  64
		 分析:batch[0]的tensor[1]包含64个label(类号)
		 ----------------------------------------

	 下面是 batch[1] 的数据:
	 ----------------------------------------
	 batch[1] =  <数据很大省略了>
	 type(batch[1]) =  <class 'list'>
	 len(batch[1]) =  2
	 分析:batch[1] 是一个列表,包含两个tensor。
	 ----------------------------------------

		 下面是 batch[1]的tensor[0] 中的数据:
		 ----------------------------------------
		 tensor[0] =  <数据很大省略了>
		 type(tensor[0]) =  <class 'torch.Tensor'>
		 len(tensor[0]) =  64
		 分析:batch[1]的tensor[0]包含64张图片的颜色值
		 ----------------------------------------

			 下面是前3张图片的数据:
			 -------------------------
			 第0张图片 =  <数据很大省略了>
			 第1张图片 =  <数据很大省略了>
			 第2张图片 =  <数据很大省略了>
			 -------------------------

		 下面是 batch[1]的tensor[1] 中的数据:
		 ----------------------------------------
		 tensor[1] =  tensor([0, 2, 0, 0, 4, 1, 3, 1, 6, 3, 1, 4, 4, 6, 1, 9, 1, 3, 5, 7, 9, 7, 1, 7,
        9, 9, 9, 3, 2, 9, 3, 6, 4, 1, 1, 8, 8, 0, 1, 1, 6, 8, 1, 9, 7, 8, 8, 9,
        6, 6, 3, 1, 5, 4, 6, 7, 5, 5, 9, 2, 2, 2, 7, 6])
		 type(tensor[1]) =  <class 'torch.Tensor'>
		 len(tensor[1]) =  64
		 分析:batch[1]的tensor[1]包含64个label(类号)
		 ----------------------------------------

以上是运行代码得到的数据结构,其中“<数据很大省略了>”部分,在线下运行代码后,是可以再现的。测得的数据都是来自官方数据集的真实数据。

例图:

结论:

通过加载了真实数据集并测试分析,数据加载器 train_dataloader 的数据结构如下:

1.train_dataloader 共包含 938 批(batch)数据,是一个可迭代结构,可转换成列表。

2.每批数据是一个列表,包含两个tensor。其中第一个tensor包含64张图片的颜色值;第二个tensor包含64个label,就是64张图片的对应类号(整个数据集包含10大类图片,类号为0-9)。

3.一张图片是一个3维列表,其中的一个2维部分是一个通道,包含有28*28个点的颜色值。正常的图片有三个通道的数据,但是测试图片是黑白图片,只使用了一个通道。

参考:

https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
 

  • 8
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值