Pytorch自己定义Dataloader加载高光谱数据集
为了方便进行对比,这里简单说一下 图像分类中的数据集加载
ImageFolder是pytorch框架已经编好的加载数据集方法,可以直接拿来用。
但是如果我们的数据集为深度图、高光谱遥感图....这类非传统RGB图,我们就需要定义自己的加载数据集方法。
#准备好训练集
train_dateset = ImageFolder(image_path + '/train', transform=data_transform["train"])
train_loader = DataLoader(train_dateset, batch_size=batch_size, shuffle=True, num_workers=0)
)
1、 加载高光谱数据集
注意这里的dataset是模块,dataset.Dataset才是类,加载数据集需要继承Dataset这个类
首先的生成数据集只会调用 __init__
和__len__
方法。
一般就是__init__方法中传入图像和GT,以及预处理方法(将数据转换成tensor格式等等...)
import scipy.io as sio
import torch
from torch.utils.data import dataset
class my_dataset(dataset.Dataset):
def __init__(self, root, transform=None, target_transform=None):
super(my_dataset, self).__init__()
#预处理
self.transform = transform
self.target_transform = target_transform
#加载Samson数据集的mat文件(字典文件)
hsi_data = sio.loadmat(root)
#原始HSI图像[156, 9025]
training_data = hsi_data['Y']
#丰度标签[3, 9025]
labels = hsi_data['A']
self.train_data = torch.reshape(torch.from_numpy(training_data), (156, 95, 95))
self.labels = torch.reshape(torch.from_numpy(labels), (3, 95, 95))
def __getitem__(self, index):
img, target = self.train_data[index], self.labels[index]
if self.transform is not None:
img = torch.tensor(img)
if self.target_transform is not None:
target = torch.tensor(target)
return img, target
def __len__(self):
return len(self.train_data)
# return 1
但是__len__
方法有一些讲究。比如我这里之间将原始mat文件里的图像和GT读出来,全部reshape
成[channel, width, height]。这个时候返回的train_dataset 的数量取决于__len__
方法。
这里以Samson数据集为例,如果想要将整个[156, 95, 95]的图片当成一个图片进行后续的训练,那么需要在__len__
方法直接return1。
但是如果我们需要比如逐像素进行输入,就不能够直接return1。每个像素可以看作是一个1-D的向量,把每一个像素当成一个输入数据的话,是不是就是原始图像中[156, 9025]进行转置–>[9025, 156],逐个去取9025的每一个?
self.train_data = training_data.T #[9025, 156]
self.labels = labels.T #[9025, 3]
__len__
中 return len(self.train_data)
可以看到train_dataset是一个9025个像素组成的数据集。
label
是9025个(3,), img
是9025个(156,),相当于取每一个像素作为一个训练数据。后面根据batch_size的大小划分这9025个数据就行。
// 接着就是my_dataset类实例化对象,train_dataset 就是我们自己高光谱图像的数据集了。
train_dataset = my_dataset(root="samson_dataset.mat")
2、 送进Dataloder加载
DataLoader按照 batch_size=20
将train_dataset 划分。比如这里如果逐像素进行输入,一共9025个1-D向量数据,每一个batch训练20个数据,一轮epoch要训练完,需要迭代9025/20=452次。
source_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=20)
开始迭代训练
for i, (x, y) in enumerate(source_dataloader):
output = net(x).to(device)
n += 1
print(n)
这里按照batch_size开始迭代,就会访问上面创建数据集中的__getitem__
方法了。index其实就是逐个索引访问每个像素,每batch_size=20个为一次。
所以上面的代码中的n=9025/20=4025.
def __getitem__(self, index):
img, target = self.train_data[index], self.labels[index]
if self.transform is not None:
img = torch.tensor(img)
if self.target_transform is not None:
target = torch.tensor(target)
return img, target
// 高光谱,我目前涉及的是无监督的解混方法,这里batch_size就直接设为1了。
将数据集送进Dataloder,加载我们的数据集。
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset, batch_size=1, shuffle=False
)
循环遍历train_loader,这里因为输入为1个156x95x95的Samson数据集,idx只有一个0的索引,返回图像和GT。
i遍历train_loader,是一个包含两个元素的元组,i[0]就是原始输入HSI了,第一个维度是batch,i[1]就是GT啦。