深度学习图片数据读入

1.有N种类别的图片分别放在N个文件夹里,每个文件夹放一类标签

  (1)首先将图片划分为训练集和验证集

import os
from shutil import copy, rmtree
import random


def mk_dir(file_path: str):
    # 如果创建的文件夹存在递归删除整个文件夹下所有文件,包括此文件夹
    if os.path.exists(file_path):
        rmtree(file_path)
    os.makedirs(file_path)

def main(data_dir, new_dir):
    random.seed(0)

    # 以训练集:测试集 = 7:3 的比例划分
    train_rate = 0.7
    
    # 存放所有类图片的地址
    
    assert os.path.exists(data_dir), "该地址不存在"
    
    # 获取所有的类别名称
    class_list = os.listdir(data_dir)

    # 创建train和val文件夹,每个文件夹里再创建所有类别的文件
    # 这里为了保留原来的数据集因此新建了一个文件存放分好的图片
    mk_dir(new_dir)
    
    train_dir = os.path.join(new_dir, "train")
    val_dir = os.path.join(new_dir, "val")
    mk_dir(train_dir)

    for cls in class_list:
        # 图片新的存放地址
        train_path = os.path.join(train_dir, cls)
        mk_dir(train_path)
        val_path = os.path.join(val_dir, cls)
        mk_dir(val_path)
        
        # 当前处理类的文件夹
        cls_path = os.path.join(data_dir, cls)

        # 获取某一个类别图片的所有图片名
        all_img = os.listdir(cls_path)
        num_img = len(all_img)
        img_idx = list(range(0, num_img))

        # 随机打乱索引
        random.shuffle(img_idx)
        num_train = int(num_img * train_rate)

        # 遍历所有当前处理类的所有图片
        for idx in range(num_img):
            if idx < num_train:
                copy(os.path.join(cls_path, all_img[img_idx[idx]]), train_path)
            else:
                copy(os.path.join(cls_path, all_img[img_idx[idx]]), val_path)

            print("\r正在处理{}类:{}/{}".format(cls, idx+1, num_img), end="")

    print("\n训练集,验证集划分完成!!!")

if __name__ == '__main__':
    data_dir = './afterResizeVehcileClassificationDataset/'
    new_dir = './newvehcileClassificationDataset/'
    main(data_dir, new_dir)

        划分完成后数据文件变为:

(2)读取数据

import torch
from torchvision import transforms, datasets, utils
import json

data_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

img_path = './newvehcileClassificationDataset'

# 加载训练集
train_dataset = datasets.ImageFolder(root=img_path + '/train',transform=data_transform)

cls_list = train_dataset.class_to_idx 
cls_dict = dict((val, key) for key, val in cls_list.items())

# 将python对象编码成Json字符串
json_str = json.dumps(cls_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# 加载验证集
val_dataset = datasets.ImageFolder(root=img_path + '/val', transform=data_transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

2.所有的图片都在一个文件内类别名保存在图片名中

        (1)自定义数据类型

from PIL import Image
from torch.utils.data import Dataset
import os

class MyDataset(Dataset):
    
    def __init__(self,img_path,transform):
        super(MyDataset).__init__()
        self.img_path = img_path
        
        # 获取文件夹下的所有图片名称
        self.all_names = [x for x in os.listdir(img_path) if x.endswith(".jpg")]
        self.transform = transform
  
    def __len__(self):
        return len(self.all_names)
  
    def __getitem__(self,idx):
        img_name = self.all_names[idx]
        img = Image.open(os.path.join(self.img_path, img_name))
        img = self.transform(img)
        try:
            label = int(img_name.split("_")[0])
        except:
            label = -1
        return img,label

(2)加载数据

import torch
import os
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

val_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# train_transform可以和test_transform不一样
train_transform= transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

batch_size = 32
dataset_dir = "./food"

train_set = FoodDataset(os.path.join(dataset_dir,"train"), transform=train_transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
val_set = FoodDataset(os.path.join(dataset_dir,"val"), traindform=val_transform)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)

  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值