Pytorch如何加载自己的数据集(复写Dataset类)

import torch
import os
import glob
import random
import csv

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

#加载自己的数据集
class Pokemon(Dataset):
    #定义自己的主函数,函数内变量名:根目录,图像的规模,以及模式(训练,验证或者测试)
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        #创建字典
        self.name2label = {}
        #连接两个或更多的路径名组件,如果有一个组件是一个绝对路径,则它之前的所有组件均会被舍弃
        #os.listdir返回指定的文件夹或文件的名字的列表
        for name in sorted(os.listdir((os.path.join(root)))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            #将文件夹中是文件的文件名进行分类,按照字典中的键值对应规则,具体实现通过原先字典的长度,一开始为0,逐渐+1
            self.name2label[name] = len(self.name2label.keys())
        print(self.name2label)

        self.root=root
        self.resize=resize
        self.load_csv('images.csv')

        self.images, self.labels = self.load_csv('images.csv')

        if mode == 'train':
            self.images=self.images[:int(0.6*len(self.images))]
            self.labels=self.labels[:int(0.6*len(self.labels))]

        if mode=='val':
            self.images=self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels=self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]

        if mode=='test':
            self.images=self.images[int(0.8*len(self.images)):int(len(self.images))]
            self.labels=self.labels[int(0.8*len(self.labels)):int(len(self.labels))]
    #编写load_csv函数,filename是images.csv
    def load_csv(self, filename):
        #创建images列表,保存个图片的路径
        images = []
        for name in self.name2label.keys():
            #glob.glob查找符合特定规则的文件路径名
            images+=glob.glob(os.path.join(self.root,name,'*.png'))
            images+=glob.glob(os.path.join(self.root,name,'*.jpg'))
            images+=glob.glob(os.path.join(self.root,name,'*.jpeg'))

        random.shuffle(images)
        # print(len(images),images[0])
        #with open用来打开本地文件 CSV.WRITER写一个csv文件,
        with open(os.path.join(self.root, filename), mode='w', newline='') as f:
            writer= csv.writer(f)
            for img in images:
                name=img.split(os.sep)[-2]
                label=self.name2label[name]
                writer.writerow([img, label])
            #print('writen in to filename', filename)
        images, labels =[], []
        with open(os.path.join(self.root, filename)) as f:
             reader= csv.reader(f)
             for row in reader:
                 img, label=row
                 label=int(label)
                 images.append(img)
                 labels.append(label)
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    #图片格式的转化
    def denomalize (self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229,0.224,0.225]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x


    def __getitem__(self, idx):
        img,label=self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x:Image.open(x).convert('RGB'),    #将路径名转化为图片数据类型
            transforms.Resize((self.resize, self.resize)),
            transforms.RandomRotation(0),
            #以中心点按照原来的图片大小进行裁剪,操作过后,图片的大小跟原来的相同
            #transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            #希望图片数据转换到0  1之间,但是图片转换之后会产生偏差,座椅还需要denormanization
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229,0.224,0.225])
        ])
        #转换为图片类型后再转换为张量
        img = tf(img)
        label = torch.tensor(label)
        return img, label


def main():

     import time
     import visdom
     viz=visdom.Visdom()
     db=Pokemon('D://face recogniton',224,'train')   #实例化一个对象 人=>具体的一个人
      #x,y = next(iter(db))
      #print('sample',x.shape,y.shape)
      #print(x,y)
     #im = Image.open("D://360MoveData//Users//Gentle//Desktop//0.jpg")  ##文件存在的路径
     # # im.show()
     # viz.image(x, win='sample_x',opts=dict(title='sample_x'))
     # #viz.image(db.denomalize(y), win='sample_y',opts=dict(title='sample_y'))
     loader = DataLoader(db, batch_size=32, shuffle=True)

     for x,y in loader:
          viz.images(db.denomalize(x), nrow=8, win='batch', opts=dict(title='batch'))
          viz.text(str(y.numpy()), win='y-bay=tch', opts=dict(title='batch'))
          time.sleep(5)

```python
最后通过main函数试运行,终端输入 python-m visdom.server

if name == ‘main’:
main()


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Begin,again

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值