# 手动创建dataset类 — Manually create a custom Dataset class
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt
# 首先需要使用glob得到需要传入数据的路径(path) — First, use glob to collect the paths of the input images
import glob
# Collect every .jpg in the dataset folder. `*` matches any file name, and the
# raw-string prefix r'...' keeps the backslashes in the Windows path from being
# interpreted as escape sequences. glob's result order depends on the file
# system, so sort it to get a reproducible listing.
all_img_path = sorted(glob.glob(r'C:\Users\Doctor Jin\dataset2\*.jpg'))
# 其次创建其对应标签labels — Next, create the corresponding labels
# The weather classes; a class's position in this list is its integer label.
species = [
    'cloudy',
    'rain',
    'shine',
    'sunrise',
]
# 用字典给标签创建index — Use dictionaries to map class names to integer indices
# Class name -> integer label, in the order the names appear in `species`.
species_to_index = {name: idx for idx, name in enumerate(species)}
# Integer label -> class name. Built from species_to_index so the two mappings
# always agree.
idx_to_species = {idx: name for name, idx in species_to_index.items()}
# 给图片创建标签 — Assign a label to each image
import os

# Build exactly one label per image path, in the same order as all_img_path,
# so all_labels[i] is the label of all_img_path[i]. The label is the index of
# the first class name found in the image's file name. Only the file name is
# searched, so directory names cannot produce a false match.
all_labels = []
for img_path in all_img_path:
    file_name = os.path.basename(img_path)
    label = next((i for i, c in enumerate(species) if c in file_name), None)
    if label is None:
        # Fail loudly: skipping the image would shift every later label.
        raise ValueError(f'no class name from {species} found in {img_path!r}')
    all_labels.append(label)
# transform转化 — Image transforms
# Preprocessing for every image: resize to 96x96, then convert to a tensor.
transform = transforms.Compose([transforms.Resize(size=(96, 96)), transforms.ToTensor()])
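# 创建一个dataset类的方法,需要继承自torch.utils.data.Dataset类,并且必须创建__getitem__和__len__方法 — A custom dataset class must inherit from torch.utils.data.Dataset and implement __getitem__ and __len__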
class Mydataset(data.Dataset):
    """Map-style dataset that pairs image files on disk with labels.

    Subclasses of ``torch.utils.data.Dataset`` must implement ``__getitem__``
    and ``__len__``. Only paths are stored; each image file is opened when its
    item is requested.
    """

    def __init__(self, img_paths, labels, transform):
        """Store the samples.

        Args:
            img_paths: indexable sequence of image file paths (this script
                passes both a list and a NumPy array of strings).
            labels: indexable sequence of labels, aligned with ``img_paths``.
            transform: callable applied to each RGB PIL image.

        Raises:
            ValueError: if ``img_paths`` and ``labels`` differ in length.
        """
        if len(img_paths) != len(labels):
            raise ValueError(
                f'got {len(img_paths)} image paths but {len(labels)} labels'
            )
        self.img = img_paths
        self.labels = labels
        self.transforms = transform

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for sample ``index``."""
        img_path = self.img[index]
        label = self.labels[index]
        # A path carries no pixel data. Open the file with PIL to load the
        # image content. convert() returns a loaded copy, so the file handle
        # can be closed when the `with` block exits.
        with Image.open(img_path) as raw_img:
            pil_img = raw_img.convert('RGB')
        img_tensor = self.transforms(pil_img)
        return img_tensor, label

    def __len__(self):
        """Return the number of samples."""
        return len(self.img)
# Dataset and shuffled loader over all images (before the train/test split).
wheather_dataset = Mydataset(all_img_path, all_labels, transform)
# The DataLoader keyword is `shuffle`; an unknown keyword raises TypeError.
wheather_dl = data.DataLoader(wheather_dataset, batch_size=16, shuffle=True)
# 设置train数据和test数据 — Split the data into training and test sets
# One random permutation of the sample indices, applied to BOTH the paths and
# the labels so that every image stays paired with its own label.
index = np.random.permutation(len(all_img_path))
all_imgs_path = np.array(all_img_path)[index]
all_labels_shuffled = np.array(all_labels)[index]
s = int(len(all_imgs_path) * 0.8)  # number of training samples (80%)
train_imgs = all_imgs_path[:s]
train_labels = all_labels_shuffled[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_labels_shuffled[s:]
train_ds = Mydataset(train_imgs, train_labels, transform)
test_ds = Mydataset(test_imgs, test_labels, transform)
train_dl = data.DataLoader(train_ds, batch_size=16, shuffle=True)
# The test set is evaluated in a fixed order, so it is not shuffled.
test_dl = data.DataLoader(test_ds, batch_size=16)
# 批量化调整dataset — Wrap an existing dataset to adjust every sample it returns
class new_dataset(data.Dataset):
    """Wrap another dataset and move each image's first axis to the end.

    Assumes the wrapped dataset yields ``(image_tensor, label)`` pairs with
    channels-first images (e.g. from ``transforms.ToTensor``). The permuted,
    channels-last layout is presumably meant for display with matplotlib;
    confirm against the caller.
    """

    def __init__(self, some_dataset):
        """Keep a reference to the dataset being wrapped."""
        self.dataset = some_dataset

    def __getitem__(self, index):
        """Return ``(permuted_image, label)`` for sample ``index``."""
        img, label = self.dataset[index]
        # (C, H, W) -> (H, W, C)
        img = img.permute(1, 2, 0)
        return img, label

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)