Pytorch导入.mat数据集
前提:
数据集是图像和对应的标签,具体来说,我的数据集在和.py文件位于同一目录下的Training and Verification datasets文件夹中,文件下有Input of CNN1和Output of CNN1两个文件夹,顾名思义,Input of CNN1文件夹里是CNN1的输入,里面的文件是LeftImage2.mat.mat,LeftImage3.mat.mat一直到LeftImage1001.mat.mat。Output of CNN1文件夹里是CNN1的输出,里面的文件是D2.mat.mat,D3.mat.mat一直到D1001.mat.mat。还有M2.mat.mat,M3.mat.mat一直到M1001.mat.mat。
数据集的文件结构如下:
MyDatasetFolder
├── Input of CNN1
···├── LeftImage2.mat.mat
···├── LeftImage3.mat.mat
···├── --------
└── Output of CNN1
···├── D2.mat.mat
···├── --------
···├── M1.mat.mat
···├── --------
正文:
从最高层级开始,这段代码的目标是从一个由 .mat
文件组成的数据集中加载图像数据,然后将这些数据转换为 PyTorch 张量,以便它们可以被神经网络使用。为了实现这个目标,代码主要分为几个部分:
- 定义数据集类 (
MyDataset
):该类继承自torch.utils.data.Dataset
,是 PyTorch 中用于表示数据集的标准接口。在这个类中,需要定义__len__()
方法(返回数据集中的样本数)和__getitem__()
方法(返回给定索引的样本)。这使得可以像处理列表一样处理MyDataset
的实例。 - 定义转换类 (
ToTensor
和Normalize
):这些类是用来预处理数据的。在 PyTorch 中,通常会创建一个转换的流水线,它由多个转换类组成,每个类都实现了__call__()
方法,这样它们就可以像函数一样被调用。流水线的输入是一个样本(在这个例子中,是一个包含输入图像和两种输出图像的字典),流水线的输出是转换后的样本。 - 实例化数据集和数据加载器:通过实例化
MyDataset
类并传入转换流水线,创建了一个数据集对象,然后可以将这个对象传给torch.utils.data.DataLoader
来创建一个数据加载器对象。数据加载器可以将数据集中的样本分批(batch),打乱,以及并行加载。 - 加载并显示图像:最后,为了验证这个工作,从数据加载器中获取了一批数据,然后选择了一张输入图像和对应的两种输出图像进行显示。
具体内容:
loadmat_v73(file_name, variable_name)
函数:这个函数用于加载 .mat
文件(MATLAB 格式的文件),并返回文件中的某个变量。这个函数是为了让 Python 能够读取 MATLAB 文件而设计的,因为数据集显然是用 MATLAB 生成的。
**MyDataset
类:**这个类扩展了 PyTorch 的 Dataset
基类,该基类定义了访问数据集的基本接口。MyDataset
类在初始化时需要一个根目录(其中包含输入和输出图像的文件夹),以及一个可选的转换流水线。
在 __init__
方法中,我们获取了输入图像和输出图像的文件名列表。这里使用了 Python 的内置 sorted
函数和 os.listdir
函数,以确保图像文件按升序排列。这是因为在数据集中,图像文件的命名可能是按照它们的序号来命名的(例如,“Image1.mat”, “Image2.mat”, “Image3.mat”, 等等),所以需要确保它们是按正确的顺序加载的。
__len__
方法返回数据集中的样本数量。在这个例子中,它只是返回输入图像列表的长度。
__getitem__
方法返回给定索引的样本。首先,它检查索引是否是一个 PyTorch 张量,如果是的话,就将其转换为一个 Python 列表。然后,它获取对应索引的输入图像和输出图像的文件路径,加载图像,创建一个包含输入图像和输出图像的字典,然后如果存在转换流水线,就将其应用于样本。
ToTensor
类是一个转换类,它的作用是将 numpy 数组转换为 PyTorch 张量,然后将图像的像素值从 [0, 255] 缩放到 [0, 1],并添加一个通道维度。这是因为 PyTorch 的卷积层需要输入有一个通道维度,即使是灰度图像也需要通道维度。
Normalize
类是另一个转换类,它的作用是对图像进行归一化,减去均值并除以标准差。这是一个常见的预处理步骤,可以帮助神经网络更好地学习。
最后,我们实例化了 MyDataset
类,创建了一个数据加载器,然后从数据加载器中获取了一批数据,并显示了一张输入图像和对应的两种输出图像。
部分代码:
def loadmat_v73(file_name, variable_name):
with h5py.File(file_name, 'r') as file:
data = np.array(file[variable_name])
return data
class MyDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.input_images = sorted(os.listdir(os.path.join(root_dir, "Input of CNN1")),
key=lambda x: int(re.search(r'\d+', x).group()))
all_output_files = sorted(os.listdir(os.path.join(root_dir, "Output of CNN1")),
key=lambda x: int(re.search(r'\d+', x).group()))
self.output_images_D = [file for file in all_output_files if file.startswith('D')]
self.output_images_M = [file for file in all_output_files if file.startswith('M')]
def __len__(self):
return len(self.input_images)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
input_image_path = os.path.join(self.root_dir, "Input of CNN1", self.input_images[idx])
output_image_D_path = os.path.join(self.root_dir, "Output of CNN1", self.output_images_D[idx])
output_image_M_path = os.path.join(self.root_dir, "Output of CNN1", self.output_images_M[idx])
input_image = loadmat_v73(input_image_path, 'LeftImage')
output_image_D = loadmat_v73(output_image_D_path, 'D')
output_image_M = loadmat_v73(output_image_M_path, 'M')
sample = {'input': input_image, 'output_D': output_image_D, 'output_M': output_image_M}
if self.transform:
sample = self.transform(sample)
return sample
class Normalize(object):
"""Normalize a tensor image with mean and standard deviation."""
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, sample):
input_image, output_image_D, output_image_M = sample['input'], sample['output_D'], sample['output_M']
input_image = transforms.Normalize(self.mean, self.std)(input_image)
output_image_D = transforms.Normalize(self.mean, self.std)(output_image_D)
output_image_M = transforms.Normalize(self.mean, self.std)(output_image_M)
return {'input': input_image,
'output_D': output_image_D,
'output_M': output_image_M}