编程基础很弱,需要机器学习,学习记录,按自己理解写的,希望以后能学懂吧,要是有大神看到还请赐教。
Pytorch加载数据初认识
读取数据两个类:
Datasets:提取某些数据并编号,并获取对应Label。可知总共有多少数据。
DataLoader:数据打包,为网络提取不同数据形式。
如何使用Datasets类:
# 从torch.utils.data仓库拿出Dataset工具
from torch.utils.data import Dataset
代码实战:
#Pycharm console
#使用好处:可以看清内部结构。
from PIL import Image
img_path=""
img=Image.open(img_path)
img.size
#从右侧可以看到在变量img内部结构,发现有size选项。
#由此,知道仓库内的东西。(抽象说法,类比ABAQUS)
这里相对路径没整明白,改成拼接式。
# dataset类
from torch.utils.data import Dataset
from PIL import Image
import os
#继承类还是什么?
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # 相对地址
self.label_dir = label_dir # 相对地址内的文件夹名
self.path = os.path.join(self.root_dir, self.label_dir) # 此函数将以上两地址拼接
self.img_path = os.listdir(self.path) # 将地址内取列表?
def __getitem__(self, idx):
img_name = self.img_path[idx] # 根据索引获取文件name
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 拼接为文件路径
img = Image.open(img_item_path) # 使用Image打开文件路径,此时变量img为图片
label = self.label_dir # 获取标签
return img, label # 返回图片与图片对应的标签
def __len__(self):
return len(self.img_path) # 记录img_path列表长度,即文件个数
#由于视频中相对地址没懂,此处使用绝对地址来拼接
root_dir = "D:\\Scientific_Research\\Pycharmproject\\Biji\\dataset\\train"
ants_label_dir = "ants"
bees_label_dir = "bees"
#Mydata()需要两个输入,分别为root_dir,label_dir
ants_dataset = MyData(root_dir=root_dir, label_dir=ants_label_dir)
bees_dataset = MyData(root_dir=root_dir, label_dir=bees_label_dir)
#通过索引访问将以上两个列表内容,返回元组
ants_dataset[0]#此为一个元组,(图片,标签)
ants_dataset[0][0].show()#显示图片
train_dataset = ants_dataset + bees_dataset#将蚂蚁、蜜蜂数据集相加,即用于数据集拼接