1. for 循环遍历每个文件夹
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
class MyDataSet(Dataset):
    """In-memory image classification dataset built from one folder per class.

    All images are loaded eagerly in ``__init__``: each one is converted to
    RGB, resized to 32x32 and turned into a tensor with
    ``transforms.ToTensor``. The label of an image is the position of its
    folder inside ``datapath``.
    """

    def __init__(self, datapath):
        """Load every image found directly inside each folder of ``datapath``.

        Args:
            datapath: Iterable of directory paths, one directory per class.
        """
        super().__init__()
        to_tensor = transforms.ToTensor()
        imgs = []
        # The folder index doubles as the class label.
        for label, folder in enumerate(datapath):
            for file in os.listdir(folder):
                if file == 'desktop.ini':  # hidden Windows metadata file, not an image
                    continue
                # os.path.join keeps the path portable; the context manager
                # closes the underlying file once the pixels have been copied.
                with Image.open(os.path.join(folder, file)) as raw:
                    # Force RGB so every tensor has the same channel count and
                    # the samples can be batched. Image.LANCZOS is the same
                    # filter as Image.ANTIALIAS, which Pillow 10 removed.
                    img = raw.convert("RGB").resize((32, 32), Image.LANCZOS)
                imgs.append((to_tensor(img), label))
        self.imgs = imgs  # list of (tensor, label) tuples

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return the ``(tensor, label)`` tuple stored at ``index``."""
        return self.imgs[index]
# Class folders; the position of a folder in this list becomes its label.
data_path = [
    "C:\\Users\\KilLogic\\summer-2023\\HW3\\data\\车辆分类数据集\\bus",
    "C:\\Users\\KilLogic\\summer-2023\\HW3\\data\\车辆分类数据集\\car",
    "C:\\Users\\KilLogic\\summer-2023\\HW3\\data\\车辆分类数据集\\truck",
]
data = MyDataSet(data_path)

# The original snippet used batch_size without ever defining it (NameError);
# 32 is only a placeholder default, tune as needed.
batch_size = 32
train_size = 800  # 1567 images in total: 800 for training, the rest for testing
test_size = len(data) - train_size

# Randomly split the data into a training set and a test set.
train_dataset, test_dataset = torch.utils.data.random_split(
    data, [train_size, test_size]
)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
)
# Shuffling brings nothing at evaluation time, so the test order stays fixed.
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
)
2. 使用 datasets.ImageFolder
datasets.ImageFolder 会把根目录下的每个子文件夹当作一个类别并自动分配标签,得到一个可按下标访问的数据集,其中每个元素都是 (PIL 图片, 标签) 元组。
- 遍历key,value,将key复制到img
- 将img resize(将所有图片大小统一)
- 将img 转化为Tensor
- 将处理过的数据组成 imgs 列表,每个元素是 (img, value) 元组
- 之后生成 data_loader,每个 data_loader 都是按 batch 遍历这些 (img, value) 样本的迭代器
from torchvision import datasets
# Build the dataset straight from the class sub-folders, then print every
# sample together with its index.
data = datasets.ImageFolder(
    root="C:\\Users\\KilLogic\\summer-2023\\HW3\\data\\车辆分类数据集"
)
for i, sample in enumerate(data):
    print(i, sample)
输出(节选,仅前 10 条):
0 (<PIL.Image.Image image mode=RGB size=85x120 at 0x14703C438B0>, 0)
1 (<PIL.Image.Image image mode=RGB size=140x195 at 0x14703C43A00>, 0)
2 (<PIL.Image.Image image mode=RGB size=86x93 at 0x14703C43A30>, 0)
3 (<PIL.Image.Image image mode=RGB size=152x205 at 0x14703C439A0>, 0)
4 (<PIL.Image.Image image mode=RGB size=150x174 at 0x14703C43A00>, 0)
5 (<PIL.Image.Image image mode=RGB size=85x108 at 0x14703C43A30>, 0)
6 (<PIL.Image.Image image mode=RGB size=60x85 at 0x14703C439A0>, 0)
7 (<PIL.Image.Image image mode=RGB size=95x142 at 0x14703C43A00>, 0)
8 (<PIL.Image.Image image mode=RGB size=145x205 at 0x14703C43A30>, 0)
9 (<PIL.Image.Image image mode=RGB size=84x119 at 0x14703C439A0>, 0)
class MyDataSet(Dataset):
    """In-memory image dataset built on top of ``datasets.ImageFolder``.

    ``ImageFolder`` supplies ``(PIL image, label)`` pairs; each image is
    resized to 32x32 and converted to a tensor once, in ``__init__``, and the
    results are kept in ``self.imgs``.
    """

    def __init__(self, datapath):
        """Load and preprocess every sample under ``datapath``.

        Args:
            datapath: Root directory passed to ``datasets.ImageFolder``
                (expected to contain one sub-folder per class).
        """
        super().__init__()
        folder = datasets.ImageFolder(root=datapath)
        to_tensor = transforms.ToTensor()
        # Image.LANCZOS is the same filter as Image.ANTIALIAS, which Pillow 10
        # removed. The debug print of the whole list was dropped as well.
        self.imgs = [
            (to_tensor(img.resize((32, 32), Image.LANCZOS)), label)
            for img, label in folder
        ]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return the ``(tensor, label)`` tuple stored at ``index``."""
        return self.imgs[index]
# Root folder that holds one sub-folder per class.
data_path = "C:\\Users\\KilLogic\\summer-2023\\HW3\\data\\车辆分类数据集"
data = MyDataSet(data_path)

# The original snippet used batch_size without ever defining it (NameError);
# 32 is only a placeholder default, tune as needed.
batch_size = 32
train_size = 800
test_size = len(data) - train_size

# Randomly split the data into a training set and a test set.
train_dataset, test_dataset = torch.utils.data.random_split(
    data, [train_size, test_size]
)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
)
# Shuffling brings nothing at evaluation time, so the test order stays fixed.
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=0
)
3. 后续添加
能力有限,欢迎吐槽交流。