手动创建dataset类

手动创建dataset类

import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
import matplotlib.pyplot as plt

首先需要使用glob得到需要传入数据的路径(path)

import glob
all_img_path = glob.glob(r'C:\Users\Doctor Jin\dataset2\*.jpg') #获取路径  *代表匹配任意路径的名称 一定要加r防止转译

其次创建其对应标签labels

species = ['cloudy','rain','shine','sunrise']

用字典给标签创建index

species_to_index = dict((c,i)for i,c in enumerate(species))
idx_to_species = dict((v, k) for k, v in species_to_idx.items())

给图片创建标签

all_labels = [0]
for img in all_img_path:
    for i,c in enumerate(species): # 每次只循环4次 i从0-3
        if c in img:
            all_labels.append(i)

transform转化

transform = transforms.Compose([
                    transforms.Resize((96, 96)),
                    transforms.ToTensor(),
])

创建一个dataset类的方法,需要继承自torch.utils.data.Dataset类,并且必须创建__getitem__和__len__方法

class Mydataset(data.Dataset):
    def__init__(self,img_paths,labels,transform):
        self.img = img_paths
        self.labels = labels
        self.transforms = transform
    def__getitem__(self,index):
        img = self.imgs[index]
        label = self.imgs[index]
        pil_img = Image.open(img)  #用pil是因为前面只是设置了路径,并没有把图片的信息导入进来,如果不用pil仅仅是针对路径创建了dataset
        pil_img = pil_img.convert('RGB')
        data = self.transforms(pil_img)
        return data,label
    del__len__(self):
        return len(self.img)
wheather_dataset = Mydataset(all_imgs_path,all_labels,transform)
wheather_dl = data.DataLoader(wheather_dataset,batch_size=16,shuttle=True)

设置train数据和test数据

index = np.random.permutation(len(all_imgs_path)) #创造一个乱序的index
all_imgs_path = np.array(all_imgs_path)[index] #利用numpy对照片进行乱序
s = int(len(all_imgs_path)*0.8) #设置train数据的个数
train_imgs = all_imgs_path[:s]
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_labels[s:]
train_ds = Mydataset(train_imgs,train_labels,transform)
test_ds = Mydata(test_imgs,test_labels,transform)
train_dl = data.DataLoader(train_ds,batch_size=16,shuffle=True)
test_dl = data.DataLoader(test_ds,batch_size=16)

批量化调整dataset

class new_dataset(data.Dataset):
    def __init__(self,some_dataset):
        self.dataset = some_dataset
    def __getitem__(self,index): #getitem方法中实现
        img,label = self.dataset[index]
        img = img.permute(1,2,0)
        return img,label
    def __len__(self):
        return len()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值