1.有N种类别的图片,分别放在N个文件夹里,每个文件夹对应一个类别标签
(1)首先将图片划分为训练集和验证集
import os
from shutil import copy, rmtree
import random
def mk_dir(file_path: str):
    """Create ``file_path`` as a fresh, empty directory.

    If the path already exists it is first removed with ``rmtree`` (the
    folder itself and everything below it), then recreated with
    ``os.makedirs``, which also creates any missing parent directories.
    """
    already_present = os.path.exists(file_path)
    if already_present:
        # Wipe the stale folder so the caller always starts from empty.
        rmtree(file_path)
    os.makedirs(file_path)
def main(data_dir, new_dir):
    """Split a folder-per-class image dataset into train and val copies.

    ``data_dir`` is expected to contain one sub-directory per class. For
    every class, 70% of its files (rounded down) are copied to
    ``new_dir/train/<class>`` and the rest to ``new_dir/val/<class>``.
    The original dataset is left untouched; ``new_dir`` is wiped and
    recreated on every run.

    Args:
        data_dir: Root folder holding one sub-folder per class.
        new_dir: Destination root; any existing content is deleted.

    Raises:
        AssertionError: If ``data_dir`` does not exist.
    """
    random.seed(0)
    # Train : validation = 7 : 3.
    train_rate = 0.7
    assert os.path.exists(data_dir), "该地址不存在"

    # Only sub-directories are classes; stray files (e.g. .DS_Store) used to
    # crash the inner os.listdir. Sorting makes the seeded shuffle below
    # reproducible, because os.listdir returns entries in arbitrary order.
    class_list = sorted(
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )

    # A separate output tree is built so the source dataset is preserved.
    mk_dir(new_dir)
    train_dir = os.path.join(new_dir, "train")
    val_dir = os.path.join(new_dir, "val")
    mk_dir(train_dir)
    mk_dir(val_dir)

    for cls in class_list:
        # Destination folders for this class.
        train_path = os.path.join(train_dir, cls)
        mk_dir(train_path)
        val_path = os.path.join(val_dir, cls)
        mk_dir(val_path)

        # Source folder for this class; keep regular files only, since
        # shutil.copy cannot copy a directory.
        cls_path = os.path.join(data_dir, cls)
        all_img = sorted(
            name for name in os.listdir(cls_path)
            if os.path.isfile(os.path.join(cls_path, name))
        )
        num_img = len(all_img)
        random.shuffle(all_img)
        num_train = int(num_img * train_rate)

        # The first num_train shuffled files go to train, the rest to val.
        for idx, img_name in enumerate(all_img):
            destination = train_path if idx < num_train else val_path
            copy(os.path.join(cls_path, img_name), destination)
            print("\r正在处理{}类:{}/{}".format(cls, idx+1, num_img), end="")
    print("\n训练集,验证集划分完成!!!")
if __name__ == '__main__':
    # Source dataset (one folder per class) and the destination that will
    # receive the train/val split.
    data_dir, new_dir = (
        './afterResizeVehcileClassificationDataset/',
        './newvehcileClassificationDataset/',
    )
    main(data_dir, new_dir)
划分完成后数据文件变为:
(2)读取数据
import torch
from torchvision import transforms, datasets, utils
import json
# Pre-processing shared by both splits: convert to a tensor, then normalise
# each of the three channels with mean 0.5 and std 0.5.
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
img_path = './newvehcileClassificationDataset'
batch_size = 16

# Training set, read from the "train" sub-folder.
train_dataset = datasets.ImageFolder(
    root=img_path + '/train',
    transform=data_transform,
)

# Invert class_to_idx so the index becomes the key and the class name the
# value, then store the mapping as JSON text.
cls_list = train_dataset.class_to_idx
cls_dict = {val: key for key, val in cls_list.items()}
json_str = json.dumps(cls_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# Validation set, read from the "val" sub-folder with the same transform.
val_dataset = datasets.ImageFolder(
    root=img_path + '/val',
    transform=data_transform,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)
2.所有的图片都在同一个文件夹内,类别名保存在图片名中
(1)自定义数据集类
from PIL import Image
from torch.utils.data import Dataset
import os
class MyDataset(Dataset):
    """Dataset over a flat folder of ``.jpg`` images labelled by file name.

    The label is the integer before the first underscore in the file name
    (``"3_cat.jpg"`` -> ``3``). Files whose prefix is not an integer get
    the label ``-1``.

    Args:
        img_path: Folder containing the images.
        transform: Callable applied to each opened PIL image.
    """

    def __init__(self, img_path, transform):
        # The previous ``super(MyDataset).__init__()`` built an unbound super
        # object and never reached the parent initialiser.
        super().__init__()
        self.img_path = img_path
        # os.listdir order is arbitrary; sort so that an index always refers
        # to the same file from run to run.
        self.all_names = sorted(
            name for name in os.listdir(img_path) if name.endswith(".jpg")
        )
        self.transform = transform

    def __len__(self):
        """Return the number of ``.jpg`` files found in ``img_path``."""
        return len(self.all_names)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the file at ``idx``."""
        img_name = self.all_names[idx]
        img = Image.open(os.path.join(self.img_path, img_name))
        img = self.transform(img)
        return img, self._parse_label(img_name)

    @staticmethod
    def _parse_label(img_name):
        """Return the integer prefix of ``img_name`` or ``-1`` if absent."""
        try:
            return int(img_name.split("_")[0])
        except ValueError:
            # Only a non-numeric prefix is expected here; a bare ``except``
            # would also hide KeyboardInterrupt and real bugs.
            return -1
(2)加载数据
import torch
import os
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Validation pre-processing: resize to 128x128 and convert to a tensor.
val_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
# The training transform may differ from the validation/test one (e.g. to add
# augmentation); here it is the same resize + tensor conversion.
train_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
batch_size = 32
dataset_dir = "./food"
# Use the MyDataset class defined in this file; the previous name
# ``FoodDataset`` is not defined anywhere in it.
train_set = MyDataset(os.path.join(dataset_dir, "train"), transform=train_transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
# Keyword fixed from the typo ``traindform``, which raised TypeError.
val_set = MyDataset(os.path.join(dataset_dir, "val"), transform=val_transform)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)