Dataset函数
Dataset函数:提供一种方法去获取数据及其label
具体功能:
1、如何获取每一个数据及其label
2、告诉我们总共有多少的数据
dataset在程序中起到的作用是告诉程序数据在哪,每个索引所对应的数据是什么。相当于一系列的存储单元,每个单元都存储了数据。
代码实战
from torch.utils.data import Dataset
from PIL import Image
import os
#创建一个类并继承Dataset
class MyData(Dataset):
#创建一个初始化方法为该类提供全局变量
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
#os.path.join拼接地址
self.path = os.path.join(self.root_dir, self.label_dir)
#列表存储地址文件夹中所有图片的名字
self.img_path = os.listdir(self.path)
#创建一个获取每一张图片地址和标签的方法
def __getitem__(self, idx):
#获取idx索引中图片的名字
img_name = self.img_path[idx]
#拼接该图片的完整路径
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
#读取图片
img = Image.open(img_item_path)
label = self.label_dir
return img,label
#计算该列表的长度
def __len__(self):
return len(self.img_path)
#图片文件件的地址
root_dir = "train"
bees_label_dir = "bees"
#创建一个对象存储蜜蜂对应文件夹的信息
ants_dataset = MyData(root_dir, ants_label_dir)
#获取列表中第一张图片的信息
img, label = bees_dataset[1]
#查看这张图片
print(img.show())
一般在处理数据时,一张图片对应的标签通常存储在另一个对应的文件中对应的txt文件中。这里就将图片存在bees_image文件夹中,然后创建一个新的bees_label文件夹,然后通过以下小程序为每张图片创建存储对应标签的txt文件。
import os
root_dir = "train"
target_path = "bees_image"
#列表存储地址文件夹中所有图片的名字
img_path = os.listdir(os.path.join(root_dir, target_path))
#读取目标文件夹的类型名(ants/bees)
label = target_path.split('_')[0]
#拼接输出目录
out_dir = label + "_label"
for i in img_path:
#获取每张图片的名字
file_name = i.split('.jpg')[0]
#在对应的文件夹中的对应文件中写入该图片对应的txt的label
with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
f.write(label)
Dataloader函数
Dataloader函数:Dataloader是一个装载数据集的一个工具,从dataset中取数据
代码实战
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
#准备的测试数据集
test_data = torchvision.datasets.CIFAR10(root = "./dataset", train=False, transform=torchvision.transforms.ToTensor())
#dataset:告诉程序数据集的位置
#batch_size:每一批加载多少数据
#shuffle:每一批数据是否乱序
#num_workers:加载数据是单进程还是多进程,默认0,采用主进程加载
#drop_last:若批加载后有余是否舍弃
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
获取数据,并查看数据格式及标签
img, target = test_data[0]
print(img.shape)
print(target)
结果:
torch.Size([3, 32, 32])
3
即图片为RGB三通道,彩色图片,像素大小为32*32,tag为3
dataset和dataloader取数据对比
dataset | dataloader |
---|---|
getitem() | dataloader(batch_size=4) |
return img,target | return imgs,targets |
dataloader返回的img0,target0=dataset[0] img1,target1=dataset[1] img2,target2=dataset[2] img3,target3=dataset[3],分别将img0,1,2,3和target0,1,2,3打包成imgs和targets
接着看打包的数据信息:
for data in test_loader:
imgs, targets = data
print(imgs.shape)
print(targets)
通过tensorboard显示抓取结果
writer = SummaryWriter("logs")
step = 0
for data in test_loader:
imgs, targets = data
# print(imgs.shape)
# print(targets)
writer.add_images("test_data", imgs, step)
step = step + 1
writer.close()
tensorboard --logdir=logs --port=6007
因为drop_last设置的为False,所以最后一组图片个数不足64时仍然保留