Pytorch-自定义数据集

自定义数据集

Pytorch将数据集的处理过程标准化

数据加载的具体过程:

  1. 继承Dataset类
    Pytorch中提供了torch.utils.data.Dataset抽象类,使用时需要继承这个类,并重写__len__和__geiitem__函数。
  2. 增加数据变换
    Pytorch提供了torchvision.transforms可以比较方便进行图像的缩放、裁剪、随机旋转、填充及张量的归一化操作等,操作对象是PIL的Image或者Tensor。可以使用transforms.Compose将多个变换整合。使用的时候一般集成到Dataset的继承类中。
  3. 继承DataLoader
    需要进行批量处理、随机选取等等,所以还需要这一步。

代码

import argparse
import os
import glob
import csv
import PIL
import visdom
import matplotlib.pyplot as plt
import torch
import time
from torch.utils.data import Dataset
from torchvision import transforms

class MyData(Dataset):
    def __init__(self,root,transform=None):
        super(MyData,self).__init__()
        self.root=root
        self.transform=transform
        self.name2label={}  # 映射
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root,name)):
                continue
            self.name2label[name]=len(self.name2label.keys())
            #{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
        print(self.name2label)

        # image+label
        self.images,self.labels=self.load_csv('Image2Label.csv')


    def load_csv(self,fliename):
        if not os.path.exists(os.path.join(self.root,fliename)):
            images=[]
            for name in self.name2label.keys():
                images+=glob.glob(os.path.join(self.root,name,'*.jpg'))

            print(len(images),images)
            with open(os.path.join(self.root,fliename),mode='w',newline='') as f:
                writer=csv.writer(f)
                for img in images:#'D:\\Projects\\DeepLearning\\Dataset\\flower_photos\\val\\daisy\\00000.jpg'
                    name=img.split(os.sep)[-2]
                    label=self.name2label[name]
                    # 'D:\\Projects\\DeepLearning\\Dataset\\flower_photos\\val\\daisy\\00000.jpg', '0'
                    writer.writerow([img,label])
                print('writen into csv file:',fliename)

        images,labels=[],[]
        with open(os.path.join(self.root,fliename)) as f:
            reader=csv.reader(f)
            for row in reader:
                # 'D:\\Projects\\DeepLearning\\Dataset\\flower_photos\\val\\daisy\\00000.jpg', '0'
                img,label=row
                label=int(label)

                images.append(img)
                labels.append(label)

        assert len(images)==len(labels)
        return images,labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        #idx[0,len(images)]
        #img:'D:\\Projects\\DeepLearning\\Dataset\\flower_photos\\val\\daisy\\00000.jpg'
        #label:0
        input_image,input_label=self.images[idx],self.labels[idx]
        #路径->图像数据类型
        input_image=PIL.Image.open(input_image).convert('RGB')
        if self.transform:
            input_image=self.transform(input_image)
        return input_image,input_label

def main():
    parser = argparse.ArgumentParser(description='训练参数')
    parser.add_argument('--batchsize', type=int, default=20, help='The number of batch_size')
    parser.add_argument('--epochs', type=int, default=20, help='The number of epochs')
    args = parser.parse_args()

    viz=visdom.Visdom() #将一个窗口类实例化

    image_path = r'D:\Projects\DeepLearning\Dataset\flower_photos\train'
    tf=transforms.Compose([
                            transforms.Resize((224,224)),
                            transforms.ToTensor()
                          ])
    sample=MyData(image_path,tf) #sample=MyData(image_path,None)
    #x,y=next(iter(sample))
    #viz.image(x,win='sample_x',opts=dict(title='sample_x'))
    train_loader=torch.utils.data.DataLoader(sample,batch_size=args.batchsize,shuffle=True)

    #必须将图片大小提前调整为一样才可以显示
    for x,y in train_loader:
        viz.images(x,nrow=5,win='batch',opts=dict(title='batch'))
        viz.text(str(y),win='label',opts=dict(title='haha'))
        time.sleep(10)


if __name__ =='__main__':
    main()
python -m visdom.server

pytorch可视化工具visdom启动失败解决方法
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值