这里使用gdal库
下面是将tiff格式转为tensor格式的模块
将此模块放在一个python文件下
import torch
import os
import numpy as np
from osgeo import gdal
from torch.utils.data import Dataset
# 定义获取文件的方法
class WaterDataSet(Dataset):
def __init__(self, images_dir, labels_dir):
self.images = self.read_muitiband_images(images_dir) # 调用下面的方法加载遥感影像
self.labels = self.read_water_labels(labels_dir) # 调用下面的方法加载加载数据集
# 加载images 方法
def read_muitiband_images(self,images_dir):
images = [] # 为images 创建一个空数组
imgs = os.listdir(images_dir)
for img in imgs: # 遍历整个包含image的文文件夹
filetype = os.path.splitext(img)[-1] # filetype 返回文件的格式
if filetype == '.tif': # 判断是否为tiff格式的图像
img_path = os.path.join(images_dir, img) # 如果是tiff,将文件路径拼接
rsdl_data = gdal.Open(img_path) # 采用gdal.Open()的方法打开文件
# 将数据源的各个波段堆叠为一个数组,np.stack()用于堆叠数组, axis=0 表示按照第一个维度,即波段维度
#[……]中,rsdl_data 为需要处理的遥感影像;
# .GetRasterBand(i).ReadAsArray() for…… 遍历图像的每个波段,将波段数据以二维数组的形式输出
images.append(np.stack([rsdl_data .GetRasterBand(i).ReadAsArray() for i in range(1,7)],axis=0))
return images
# 加载labels的方法,与读取images类似
def read_water_labels(self,labels_dir):
labels = []
labs = os.listdir(labels_dir)
for lab in labs:
filetype = os.path.splitext(lab)[-1]
if filetype == '.tif':
lab_path = os.path.join(labels_dir, lab)
rsdl_data = gdal.Open(lab_path)
labels.append(np.stack(rsdl_data .GetRasterBand(1).ReadAsArray()))
return labels
# 返回数据集长度
def __len__(self):
return len(self.images)
# 实现定义数据的索引操作,并将图像返回给tensor张亮
def __getitem__(self, idx):
image = self.images[idx]
label = self.labels[idx]
return torch.tensor(image), torch.tensor(label)
读取该模块,导入文件
from reading import *
from torch.utils.data import DataLoade, Subset
# 获取数据集的影像和标签
images_dir = 'your images path'
labels_dir = 'your labels path'
# 将数据集使用WaterDataset()类实现对数据集格式的转换
dataset = WaterDataSet(images_dir, labels_dir)
len_deatset=len(dataset) # 获取数据集的长度
# 将数据集分为训练集和测试集
trainset = Subset(dataset, list(range(len_deatset//5*4)))
testset = Subset(dataset, list(range(len_deatset//5*4, len_deatset)))
# 每次使用一个样本进行数据更新,随机抓取数据
trainloader = DataLoader(trainset, batch_size=1, shuffle=True)
testloader = DataLoader(testset, batch_size=1, shuffle=True)
调式代码,可以看到结果