import torch
from torch.utils import data
import os
import numpy as np
from PIL import Image
#怎么制作数据集
class dataset(data.Dataset):
def __init__(self,path):
self.path = path
self.dataset =[] #当数据是较大的图片时,一次性不要全部加载进数据
self.dataset.extend(os.listdir(path))#路径 路径里包含信息
os.listdir()
def __len__(self):
return len(self.dataset)
def __getitem__(self,index):
lable=torch.Tensor([int(self.dataset[index][0])])#取出标签 通过numpy转tensor不容易出错
img_path=os.path.join(self.path,self.dataset[index])
img=Image.open(img_path)
img_data=torch.Tensor(np.array((img))/255-0.5)#/255归1化 -0.5去均值化 后转成Tensor
return img_data,lable
#验证一下数据
if __name__=='__main__':
#1看取出的数据是否有问题
train_dataset = dataset('D:\workFile\深度学习_神经网络\img')
x=train_dataset[1][0]
y=train_dataset[1][1]
#举证转回图像,看看是否有问题
x2 = train_dataset[0][0].numpy()
y2 = train_dataset[1][1].numpy()
img_data=np.array((x2+0.5)*255,dtype=np.int8)
img = Image.fromarray(img_data,"RGB")
img.show()
02深度学习——数据集制作
最新推荐文章于 2024-08-16 21:12:04 发布