(1)获取数据集
数据集的获取地址:
https://www.kaggle.com/datasets/araraltawil/fruit-101-dataset/data
https://www.kaggle.com/datasets/olgabelitskaya/flower-color-images
也可以直接通过下面获取:
通过网盘分享的文件:花朵数据集和水果数据集(自定义数据集))
链接: https://pan.baidu.com/s/1bI6iY_6_ovPAY68MfINLUg?pwd=u8wr 提取码: u8wr
--来自百度网盘超级会员v5的分享
将上述下载的两个文件放到自己运行的这个目录dataset文件夹下(新建的)
(2)代码实现1-水果数据集
①导包
#第一部分:导包
# 导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torchvision import datasets, transforms
②定义数据转换方法
#第二部分:定义数据转换方法
transform = transforms.Compose([
transforms.Resize((128, 128)), # 调整图像大小为128x128
transforms.ToTensor(), # 将数据转换为张量
])
# 创建图像数据集
# ImageFolder类会自动遍历指定目录下的所有子目录
# 并将每个子目录中的图像文件视为同一类别的数据
dataset = ImageFolder('./dataset/fruit_101/', transform=transform)
③数据集的基本操作
#第三部分:数据集的基本操作
#长度
print(len(dataset))
#种类(文件夹的名称)
print(dataset.classes)
#打印每个类别对应的编号
print(dataset.class_to_idx)
④定义绘图函数并打印
#第四部分:定义绘图函数并打印
# 定义绘图函数,传入dataset即可
def plot(dataset, shuffle=True):
# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=shuffle)
# 取出一组数据
images, labels = next(iter(dataloader))
# 将通道维度(C)移动到最后一个维度,方便使用matplotlib绘图
images = np.transpose(images, (0, 2, 3, 1))
# 创建4x4的子图对象
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(8, 8))
# 遍历每个子图,绘制图像并添加子图标题
for i, ax in enumerate(axes.flat):
ax.imshow(images[i])
ax.axis('off') # 隐藏坐标轴
if hasattr(dataset, 'classes'): # 如果数据集集有预定义的类别名称,使用该名称作为子图标题
ax.set_title(dataset.classes[labels[i]], fontsize=12)
else: # 否则使用类别索引作为子图标题
ax.set_title(labels[i], fontsize=12)
plt.show()
plot(dataset)
⑤完整pycharm代码实现
#第一部分:导包
# 导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
from torchvision import datasets, transforms
#第二部分:定义数据转换方法
transform = transforms.Compose([
transforms.Resize((128, 128)), # 调整图像大小为128x128
transforms.ToTensor(), # 将数据转换为张量
])
# 创建图像数据集
# ImageFolder类会自动遍历指定目录下的所有子目录
# 并将每个子目录中的图像文件视为同一类别的数据
dataset = ImageFolder('./dataset/fruit_101/', transform=transform)
#第三部分:数据集的基本操作
#长度
print(len(dataset))
#种类(文件夹的名称)
print(dataset.classes)
#打印每个类别对应的编号
print(dataset.class_to_idx)
#第四部分:定义绘图函数并打印
# 定义绘图函数,传入dataset即可
def plot(dataset, shuffle=True):
# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=shuffle)
# 取出一组数据
images, labels = next(iter(dataloader))
# 将通道维度(C)移动到最后一个维度,方便使用matplotlib绘图
images = np.transpose(images, (0, 2, 3, 1))
# 创建4x4的子图对象
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(8, 8))
# 遍历每个子图,绘制图像并添加子图标题
for i, ax in enumerate(axes.flat):
ax.imshow(images[i])
ax.axis('off') # 隐藏坐标轴
if hasattr(dataset, 'classes'): # 如果数据集集有预定义的类别名称,使用该名称作为子图标题
ax.set_title(dataset.classes[labels[i]], fontsize=12)
else: # 否则使用类别索引作为子图标题
ax.set_title(labels[i], fontsize=12)
plt.show()
plot(dataset)
(3)代码实现2-花朵数据集
①导包
#第一部分:导包
# 导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
import os
from PIL import Image # pip install Pillow
from torch.utils.data import Dataset
②定义数据转换方法
#第二部分:定义数据转换方法
# 定义数据转换方法
transform = transforms.Compose([
transforms.Resize((128, 128)), # 调整图像大小为128x128
transforms.ToTensor(), # 将数据转换为张量
])
③定义flower类
#第三部分:定义flower类
class Flowers(Dataset):
def __init__(self, data_dir, transform=None):
self.image_paths = []
self.labels = []
self.transform = transform
# 遍历数据集目录, 获取所有图像文件的路径和标签
for filename in sorted(os.listdir(data_dir)):
image_path = os.path.join(data_dir, filename)
label = int(filename.split('_')[0])
self.image_paths.append(image_path)
self.labels.append(label)
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
# 加载图像数据和标签
image = Image.open(self.image_paths[idx])
label = self.labels[idx]
# 对图像数据进行转换
if self.transform:
image = self.transform(image)
# 将标签转换为PyTorch张量
label = torch.tensor(label, dtype=torch.long)
return image, label
④随机选择一个批次的 16 张图片进行绘制
# 定义绘图函数,传入dataset即可
def plot(dataset, shuffle=True):
# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=shuffle)
# 取出一组数据
images, labels = next(iter(dataloader))
# 将通道维度(C)移动到最后一个维度,方便使用matplotlib绘图
images = np.transpose(images, (0, 2, 3, 1))
# 创建4x4的子图对象
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(8, 8))
# 遍历每个子图,绘制图像并添加子图标题
for i, ax in enumerate(axes.flat):
ax.imshow(images[i])
ax.axis('off') # 隐藏坐标轴
if hasattr(dataset, 'classes'): # 如果数据集集有预定义的类别名称,使用该名称作为子图标题
ax.set_title(dataset.classes[labels[i]], fontsize=12)
else: # 否则使用类别索引作为子图标题
ax.set_title(labels[i], fontsize=12)
plt.show()
plot(dataset)
⑤输出一部分图片(前16张)
#第五部分:输出一部分图片(前16张)
#取出一部分图片
from torch.utils.data import Subset
dataset = Flowers('./dataset/flower_color/flowers/flowers', transform=transform)
subset = Subset(dataset, [i for i in range(16)])
plot(subset, False)
⑥完整pycharm代码
#第一部分:导包
# 导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
import os
from PIL import Image # pip install Pillow
from torch.utils.data import Dataset
#第二部分:定义数据转换方法
# 定义数据转换方法
transform = transforms.Compose([
transforms.Resize((128, 128)), # 调整图像大小为128x128
transforms.ToTensor(), # 将数据转换为张量
])
#第三部分:定义flower类
class Flowers(Dataset):
def __init__(self, data_dir, transform=None):
self.image_paths = []
self.labels = []
self.transform = transform
# 遍历数据集目录, 获取所有图像文件的路径和标签
for filename in sorted(os.listdir(data_dir)):
image_path = os.path.join(data_dir, filename)
label = int(filename.split('_')[0])
self.image_paths.append(image_path)
self.labels.append(label)
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
# 加载图像数据和标签
image = Image.open(self.image_paths[idx])
label = self.labels[idx]
# 对图像数据进行转换
if self.transform:
image = self.transform(image)
# 将标签转换为PyTorch张量
label = torch.tensor(label, dtype=torch.long)
return image, label
#第四部分:定义绘图函数并打印
# 创建图像数据集
# ImageFolder类会自动遍历指定目录下的所有子目录
# 并将每个子目录中的图像文件视为同一类别的数据
dataset = Flowers('./dataset/flower_color/flowers/flowers', transform=transform)
# 定义绘图函数,传入dataset即可
def plot(dataset, shuffle=True):
# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=shuffle)
# 取出一组数据
images, labels = next(iter(dataloader))
# 将通道维度(C)移动到最后一个维度,方便使用matplotlib绘图
images = np.transpose(images, (0, 2, 3, 1))
# 创建4x4的子图对象
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(8, 8))
# 遍历每个子图,绘制图像并添加子图标题
for i, ax in enumerate(axes.flat):
ax.imshow(images[i])
ax.axis('off') # 隐藏坐标轴
if hasattr(dataset, 'classes'): # 如果数据集集有预定义的类别名称,使用该名称作为子图标题
ax.set_title(dataset.classes[labels[i]], fontsize=12)
else: # 否则使用类别索引作为子图标题
ax.set_title(labels[i], fontsize=12)
plt.show()
#plot(dataset)
#第五部分:输出一部分图片(前16张)
#取出一部分图片
from torch.utils.data import Subset
dataset = Flowers('./dataset/flower_color/flowers/flowers', transform=transform)
subset = Subset(dataset, [i for i in range(16)])
plot(subset, False)
(4)上述两种加载数据的方法对比
第一种方法(不需要自定义类)
第一种方法直接使用了 ImageFolder
类,它来自 torchvision.datasets
,这是一个专门为图像分类任务设计的类,适用于数据集中包含多个类别,每个类别都有一个独立的子文件夹的情况。ImageFolder
会自动读取每个子文件夹名称作为类别标签,将其分配给该类别中的所有图像文件。因此,这种方法适用于以下数据组织结构:
在这个结构中,每个子文件夹 (apple
, banana
, orange
) 都代表一个类别,而 ImageFolder
会自动将每个子文件夹的名称映射为标签。由于 ImageFolder
提供了简便的目录读取和标签分配,不需要自定义数据集类。
第二种方法(自定义 Flowers
类)
第二种方法定义了一个自定义的 Flowers
类。需要自定义类的原因通常有以下几种情况:
-
文件命名格式不适用于
ImageFolder
:如果数据集中的图像没有按文件夹组织,而是通过文件名来区分类别,例如1_flower.jpg
,2_flower.jpg
等。此时,每个图像的类别信息来自文件名,而不是所在的子文件夹。 -
特殊的数据加载逻辑:有时候,数据集中的图像或标签格式不适合
ImageFolder
,需要自定义读取和标签提取逻辑。在代码中,自定义Flowers
类用于读取文件名中的数字部分作为标签。
例如,以下目录结构:
在这种情况下,ImageFolder
无法直接解析文件名中的类别,因此需要自定义 Flowers
类,通过文件名提取标签。代码中的 label = int(filename.split('_')[0])
就是这样做的,它将文件名的数字部分作为标签。
总结
- 第一种方法 适用于 结构化良好的数据集,即每个类别一个子文件夹,这种情况下
ImageFolder
就足够用了。 - 第二种方法 适用于 文件名中包含类别信息 的数据集,或 数据组织不适合
ImageFolder
的情况,需要手动从文件名或其他信息中提取标签。这时,自定义类可以提供灵活的数据读取和标签分配。