目录
1.数据准备部分(图像分类问题)(就是一个img对应一个label)
2.使用torch.utils.data.Dataset 继承类的方法
1.数据准备部分(图像分类问题)(就是一个img对应一个label)
数据的准备主要分为两类,一类是直接利用torchvision.datasets.XXX,导入pytorch里最常用的一些数据集。如CIFAR10等。这类数据集一般需要下载后才能使用。
1.引入pyTorch中自带数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10("./cifar10",train=False,transform=torchvision.transforms.ToTensor(),download=True)
train_data_size = len(train_data)
print(f"训练数据集的长度为{train_data_size}")
test_data_size = len(test_data)
print(f"测试数据集的长度为{test_data_size}")
程序最后利用了len函数打印数据集长度,优化了程序。
获取了数据集以后,再利用Dataloader函数读写数据集,才能放入神经网络使用。
#利用dataloader加载数据集
train_dataloder = DataLoader(train_data,batch_size=64)
test_dataloder = DataLoader(test_data,batch_size=64)
2.第二类数据集的准备是应对自己准备的数据集。
1.使用ImageFolder、Dataloader 函数
这两个函数均来自于torchvision.datasets中。对于这两个函数的理解,ImageFloder作为数据读取器,并且能在读取数据时对数据进行初始化。
实现代码如下:
————————————————定义数据目录————————————————
train_dir = "../data/hotdog/train"
test_dir = "../data/hotdog/test"
————————————————定义预处理的参数————————————
# 将图像调整为224×224尺寸并归一化
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
train_augs = transforms.Compose([
transforms.RandomResizedCrop(size=224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
test_augs = transforms.Compose([
transforms.Resize(size=256),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
————————————————使用ImageFolder函数——————————
train_set = datasets.ImageFolder(train_dir, transform=train_augs) #两个输入,一个是图像目录
test_set = datasets.ImageFolder(test_dir, transform=test_augs)#一个是预处理参数,上面都定义了。
batch_size = 32
train_iter = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_iter = DataLoader(test_set, batch_size=batch_size)
ImageFolder会将目录中的文件夹名自动转化成序列,当DataLoader载入时,标签自动就是整数序列了。比如说目录里的文件夹是郁金香,百合。生成的targets对应就是0,1。可以通过print(dataset.class_to_idx)这个代码查看标签对应关系:
train_set = datasets.ImageFolder(train_dir, transform=train_augs) #两个输入,一个是图像目录
test_set = datasets.ImageFolder(test_dir, transform=test_augs)#一个是预处理参数,上面都定义了。
print(f"标签对应关系为:{train_set.class_to_idx}")
输出结果:标签对应关系为:{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
Dataloader构建可迭代的数据装载器,具体的实现步骤见下链接,详细阐述了Dataloader的运行机制。系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)
2.使用torch.utils.data.Dataset 继承类的方法
通过继承torch.utils.data.Dataset实现用户自定义读取数据集。必须要重写__init__ ,以及__getitem__和__len__ 魔术方法(在执行程序的时候会自动执行)。
重写原因是父类Dataset类中,__getitem__定义为
def __getitem__(self, index) -> T_co:
raise NotImplementedError
__init__:一般包含根目录、文件名。师兄还写了一个transform,应该是能提供和ImageFloder里transform一样功能。
__len__:返回整个数据集的数量,这一类方法可以在后面的Dataloader中被调用。使用Dataloader加载数据时,如果不知道数据集的长度,可能一些功能无法使用(我猜比如drop_lost,如果不知道数据集多长,就会不知道最后要取多少个)
__getitem__:根据索引读取数据,对数据进行预处理,返回数据对(img,label)。这个魔术方式是该数据集定义方法的核心方法。在这里完成对图像路径的拼接,并组合成一个列表。
自己的理解:运行Dataloader时,会自动运行__getitem__这个魔术方法,至于里面的图片怎么打开,标签怎么返回都可以自己自由发挥。其次这个参量idx就是Dataloader迭代的东西,按以下代码写。
实现步骤如下:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import os
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
root_dir = "data_set/flower_data/train"
daisy_label_dir = "daisy" # 文件夹名字
dandelion_label_dir = "dandelion"
roses_label_dir = "roses"
sunflowers_label_dir = "sunflowers"
tulips_label_dir = "tulips"
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)#拼接路径
self.img_path = os.listdir(self.path)#把路径下的图片做成一个列表
def __getitem__(self, idx):
————————————————获取根据索引获取图像及其标签————————————————
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
—————————————————————————————————————————————————————————
————————————————————————预处理————————————————————————————
img = transforms.Compose([
transforms.RandomResizedCrop(size=224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])(img)
——————————————————————————————————————————————————————————
——————————————————写标签怎么返回————————————————————————————
'''
下面的操作就返回了标签索引型的target 可以直接往随即交叉熵里放了
'''
if self.label_dir == daisy_label_dir:
label = 0
if self.label_dir == dandelion_label_dir:
label = 1
if self.label_dir == roses_label_dir:
label = 2
if self.label_dir == sunflowers_label_dir:
label = 3
if self.label_dir == tulips_label_dir:
label = 4
return img, label
——————————————————————————————————————————————————————————
def __len__(self):
return len(self.img_path)
——————————————————定义实例化对象————————————————————
daisy_dataset = MyData(root_dir, daisy_label_dir)
dandelion_dataset = MyData(root_dir, dandelion_label_dir)
roses_dataset = MyData(root_dir, roses_label_dir)
sunflowers_dataset = MyData(root_dir, sunflowers_label_dir)
tulips_dataset = MyData(root_dir, tulips_label_dir)
———————————————————————————————————————————————————
—————————————把他们加一起作为训练数据集————————————————
train_dataset = dandelion_dataset + daisy_dataset + roses_dataset + tulips_dataset + sunflowers_dataset
—————————————————————————————————————————————————————
train_iter = DataLoader(train_dataset,batch_size=32,shuffle=True)
通过对继承Dataset这类方法的研究学习,发现使用这类方法可以对数据集的抓取以及处理有更高的灵活性,不止可以应对图像分类问题,也可以用于回归问题等的数据集引入。完成对图像分类问题的Dataset方法引入后,需专门再写一篇如何使用继承Dataset方法引入各种数据集的总结。
优化方法:对数据集命名不做要求:
import torch.utils.data as data
from PIL import Image
import os
from torchvision import transforms
class train_data(data.Dataset):
def __init__(self,train_path,label_path):
self.train_path = train_path
self.label_path = label_path
self.data = []
file_list = os.listdir(self.train_path)
for i in file_list:
label_name = os.path.join(self.label_path, i)
data_name = os.path.join(self.train_path, i)
self.data.append([data_name, label_name])
self.trans = transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor()
])
def __getitem__(self, index):
data_path, label_path = self.data[index]
data = Image.open(data_path)
label = Image.open(label_path)
data = self.trans(data)
label = self.trans(label)
return data,label
def __len__(self):
return len(os.listdir(self.train_path))
2.搭建神经训练网络
1.利用pyTorch中现成的神经网络
pyTorch给我们提供了很多现成的神经网络使用,例如vgg网络、AlexNet等。