pytorch读取数据集

ImageFolder(通用类型)

使用ImageFolder读取数据的前提是root目录下,文件夹名对应类名,每个文件夹下存储同一类别的图片。

root/dog/xxx.png

root/dog/xxy.png

root/cat/123.png

root/cat/nsdf3.png 

torchvision.datasets.ImageFolder(root="root folder path", [transform, target_transform])

参数:

  • root:在root指定的路径下寻找图片
  • transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
  • target_transform:对label的转换

有三个成员变量:

  • self.classes - 用一个list保存 类名
  • self.class_to_idx - 类名对应的 索引
  • self.imgs - 保存(img-path, class) tuple的list

在cifar100的train/类别/图片上:

from torchvision import transforms as T
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder


# Build the dataset from the 'train' directory; ImageFolder expects one
# sub-folder per class, each holding that class's images.
dataset = ImageFolder('train')
# Uncomment to inspect the class names and the class-name -> index mapping.
#print(dataset.classes)
#print(dataset.class_to_idx)
# Print the list of (image path, class index) tuples.
print(dataset.imgs)

参考:

继承torch.utils.data.Dataset(通用类型)

基本流程:

  1. 根据数据集产生csv文件。格式 “文件路径 标签”
  2. 继承torch.utils.data.Dataset,通过上述csv文件读取图片。
  3. 使用

1、process.py产生train.csv,每行格式为“文件路径 标签”

import os
import pandas as pd
# Dataset directories, resolved against the current working directory by
# get_train_pic; each one is scanned for per-class sub-folders.
Train_DIR = "train"
Test_DIR = "test"

def get_train_pic(path) -> dict:
    """Map every file under the class folders of ``path`` to a class index.

    ``path`` must contain one sub-directory per class. Sub-directories are
    numbered 0, 1, 2, ... in sorted name order, and every file directly
    inside a sub-directory is assigned that sub-directory's number.

    Args:
        path: Dataset directory. A relative path is resolved against the
            current working directory; an absolute path is used as is.

    Returns:
        A dict mapping the absolute path of each file to its integer label.
    """
    base_dir = os.path.abspath('.')  # current working directory
    root = os.path.join(base_dir, path)

    # os.listdir returns entries in arbitrary order. Sorting makes the
    # class -> index mapping deterministic, so separate calls on the train and
    # test directories agree as long as they share the same class folder names.
    # Stray non-directory entries (e.g. .DS_Store) are skipped instead of
    # crashing the nested listdir.
    class_names = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )

    result = dict()
    for label, class_name in enumerate(class_names):
        real_path = os.path.join(root, class_name)
        print(real_path)
        for file_name in sorted(os.listdir(real_path)):
            result[os.path.join(real_path, file_name)] = label

    return result

# Collect {absolute image path: class index} for both splits.
train = get_train_pic(Train_DIR)
test = get_train_pic(Test_DIR)

# One row per image: the path becomes the index, the label the only column.
data_train = pd.DataFrame.from_dict(train, orient='index')
data_test = pd.DataFrame.from_dict(test, orient='index')

# Write plain "<file path> <label>" lines: space-separated and without a
# header row. The pandas defaults (comma separator plus header) produce lines
# that a whitespace-splitting reader cannot parse into two fields.
# NOTE: paths containing spaces are quoted by the CSV writer and will not
# split into exactly two whitespace-separated fields.
data_train.to_csv("train.csv", sep=" ", header=False)
data_test.to_csv("test.csv", sep=" ", header=False)

2、

import torch
import numpy as np
from torchvision import datasets, transforms
import torch.utils.data as data
from PIL import Image
import random
import os
import cv2
import random

def pil_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB."""
    # Pass Pillow an explicitly opened file object so the handle is closed
    # when the block exits, which avoids a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as stream, Image.open(stream) as picture:
        rgb_image = picture.convert('RGB')
    return rgb_image

def default_loader(path):
    """Load the image at ``path`` with the PIL-based loader.

    Raises:
        RuntimeError: if torchvision's image backend is set to 'accimage',
            which this loader does not support.
    """
    from torchvision import get_image_backend

    backend = get_image_backend()
    if backend != 'accimage':
        return pil_loader(path)
    raise RuntimeError('No Module named accimage')

class ImageNetData(data.Dataset):
    """Image dataset driven by an index file of ``<path> <label>`` lines.

    Each usable line of the index file holds exactly two whitespace-separated
    fields: an image path and an integer class id. Lines with any other
    number of fields are skipped.

    Args:
        img_root: Root directory of the dataset.
        img_file: Path of the UTF-8 index file.
        is_training: If true, image paths are resolved under
            ``<img_root>/train``; otherwise under ``<img_root>/test``.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the integer class id.
        loader: Callable that takes a file path and returns the image.
    """

    def __init__(self, img_root, img_file, is_training=False, transform=None, target_transform=None, loader=default_loader):
        super().__init__()
        self.root = img_root
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        # The split sub-directory is the same for every line, so pick it once.
        split_dir = 'train' if is_training else 'test'

        # List of (image path, class id) tuples, in index-file order.
        self.imgs = []
        with open(img_file, 'r', encoding='utf-8') as fd:
            for _line in fd:
                infos = _line.strip().split()
                # Skip blank, header, or otherwise malformed lines.
                if len(infos) != 2:
                    continue
                # os.path.join keeps an absolute infos[0] as is.
                real_path = os.path.join(self.root, split_dir, infos[0])
                self.imgs.append((real_path, int(infos[-1])))

    def __getitem__(self, index):
        """Return the ``(image, class_id)`` pair stored at ``index``."""
        path, class_id = self.imgs[index]

        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            # Apply the caller's transform; previously it was ignored and the
            # label was wrapped in a LongTensor instead.
            class_id = self.target_transform(class_id)

        return img, class_id

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.imgs)

 3、

# Build the training and validation datasets from their index files.
# With is_training=False the validation images are resolved under
# <data_root>/test.
train_datasets = ImageNetData(args.data_root, args.train_file, True, data_transforms['train'])
val_datasets = ImageNetData(args.data_root, args.val_file, False, data_transforms['val'])

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值