pytorch加载数据类(Dataset类,Dataloader类)

Dataset函数

Dataset函数:提供一种方法去获取数据及其label
具体功能:
1、如何获取每一个数据及其label
2、告诉我们总共有多少的数据

dataset在程序中起到的作用是告诉程序数据在哪,每个索引所对应的数据是什么。相当于一系列的存储单元,每个单元都存储了数据。

代码实战
from torch.utils.data import Dataset
from PIL import Image
import os

#创建一个类并继承Dataset
class MyData(Dataset):

		#创建一个初始化方法为该类提供全局变量
	    def __init__(self, root_dir, label_dir):
        	self.root_dir = root_dir
        	self.label_dir = label_dir
        	#os.path.join拼接地址
        	self.path = os.path.join(self.root_dir, self.label_dir)
        	#列表存储地址文件夹中所有图片的名字
        	self.img_path = os.listdir(self.path)
        	
  		#创建一个获取每一张图片地址和标签的方法
   		def __getitem__(self, idx):
   			#获取idx索引中图片的名字
   			img_name = self.img_path[idx]
   			#拼接该图片的完整路径
        	img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        	#读取图片
        	img = Image.open(img_item_path)
        	label = self.label_dir
        	return img,label
        	
		#计算该列表的长度
		def __len__(self):
    		return len(self.img_path)

#图片文件件的地址
root_dir = "train"
bees_label_dir = "bees"

#创建一个对象存储蜜蜂对应文件夹的信息
ants_dataset = MyData(root_dir, ants_label_dir)

#获取列表中第一张图片的信息
img, label = bees_dataset[1]

#查看这张图片
print(img.show())

一般在处理数据时,一张图片对应的标签通常存储在另一个对应的文件中对应的txt文件中。这里就将图片存在bees_image文件夹中,然后创建一个新的bees_label文件夹,然后通过以下小程序为每张图片创建存储对应标签的txt文件。

import os

root_dir = "train"
target_path = "bees_image"

#列表存储地址文件夹中所有图片的名字
img_path = os.listdir(os.path.join(root_dir, target_path))

#读取目标文件夹的类型名(ants/bees)
label = target_path.split('_')[0]

#拼接输出目录
out_dir = label + "_label"

for i in img_path:
    #获取每张图片的名字
    file_name = i.split('.jpg')[0]
    #在对应的文件夹中的对应文件中写入该图片对应的txt的label
    with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
        f.write(label)

Dataloader函数

Dataloader函数:Dataloader是一个装载数据集的一个工具,从dataset中取数据

代码实战
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

#准备的测试数据集
test_data = torchvision.datasets.CIFAR10(root = "./dataset", train=False, transform=torchvision.transforms.ToTensor())

#dataset:告诉程序数据集的位置
#batch_size:每一批加载多少数据
#shuffle:每一批数据是否乱序
#num_workers:加载数据是单进程还是多进程,默认0,采用主进程加载
#drop_last:若批加载后有余是否舍弃
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

获取数据,并查看数据格式及标签

img, target = test_data[0]
print(img.shape)
print(target)

结果:

torch.Size([3, 32, 32])
3

即图片为RGB三通道,彩色图片,像素大小为32*32,tag为3

dataset和dataloader取数据对比
datasetdataloader
getitem()dataloader(batch_size=4)
return img,targetreturn imgs,targets

dataloader返回的img0,target0=dataset[0] img1,target1=dataset[1] img2,target2=dataset[2] img3,target3=dataset[3],分别将img0,1,2,3和target0,1,2,3打包成imgs和targets
接着看打包的数据信息:

for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)


通过tensorboard显示抓取结果

writer = SummaryWriter("logs")
step = 0
for data in test_loader:
    imgs, targets = data
    # print(imgs.shape)
    # print(targets)
    writer.add_images("test_data", imgs, step)
    step = step + 1

writer.close()
tensorboard --logdir=logs --port=6007

因为drop_last设置的为False,所以最后一组图片个数不足64时仍然保留

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值