1、PyTorch中加载数据
1、主要涉及到的两个类:Dataset、Dataloader
实际中,有许多垃圾(数据),
Dataset 可以将某一类垃圾提取出来(如可回收垃圾),并编上号,同时也获取了每个数据的label。简言之,Dataset提供了一种方式去获取数据及其label。
主要实现的两个功能:
(1)如何获取每一个数据及其label;
(2)总共有多少个数据。
Dataloader 是将Dataset中的数据进行打包,为后面的网络提供不同的数据形式。以batch_size向网络传数据。
2、实战
1、读取图片:
from PIL import Image
# 记录图片的地址
img_path = 'hymenoptera_data/train/ants/0013035.jpg'
# 读取图片
img = Image.open(img_path)
# 显示这个图片
img.show()
2、重写类:
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset) # 继承Dataset类
def __init__(self, root_dir, label_dir):
# 为整个class提供全局变量
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir) # 获得路径
self.img_path = os.listdir(self.path) # 获得地址
def __getitem__(self, idx):
# 获取其中的每一个
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "hymenoptera_data/train"
ants_label_dir = 'ants'
bees_label_dir = 'bees'
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)
train_dataset = ants_dataset + bees_dataset
3、课后练习(img,label分开):
import os
root_dir = 'hymenoptera_data/train'
target_dir1 = 'ants_image'
target_dir2 = 'bees_image'
img_path1 = os.listdir(os.path.join(root_dir, target_dir1))
img_path2 = os.listdir(os.path.join(root_dir, target_dir2))
label1 = target_dir1.split('_')[0] # ants
label2 = target_dir2.split('_')[0] # bees
out_dir1 = 'ants_label'
out_dir2 = 'bees_label'
for i in img_path1:
file_name = i.split('.jpg')[0]
with open(os.path.join(root_dir, out_dir1,"{}.txt".format(file_name)),'w') as f:
f.write(label1)
for i in img_path2:
file_name = i.split('.jpg')[0]
with open(os.path.join(root_dir, out_dir2,"{}.txt".format(file_name)),'w') as f:
f.write(label2)