先上数据
img为GF二号影像,正常来讲有四个波段NIR和RGB,本文使用opencv读取图像,如果超过RGB三个波段的影像建议使用gdal。
label为一个0-15灰度值得图像,因为此数据是来做遥感的语义分割,所以讲label处理成这样,图像是黑色的是因为黑白影像的灰度值在计算机中一般取值为0-255,15的数值相对于255较小,所以全是黑色。
1.导入所需要的库
import os
import numpy as np
import cv2
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
2.定义我们所需要的函数
使用opencv读取图像我们需要图像的路径,所以定义一个读取图像路径的函数get_img_path
def get_img_path(img_path):
res = []
for file_name in os.listdir(img_path):
file_path = os.path.join(img_path, file_name)
res.append(file_path)
return res
之后我们使用transform模块定义图像和标签的预处理工作,在这里我仅仅将读取的图像转化成为了张量。
img_transform = transforms.Compose([transforms.ToTensor(),
])
class MaskToTensor(object):
def __call__(self, mask):
return torch.from_numpy(np.array(mask, dtype=np.int32)).long()
label_transform = MaskToTensor()
因为我的标签是0-15(15为类别数)的灰度图像所以定义了一个MaskToTensor类来转换我的标签图像,如果直接使用transforms.ToTensor他会自动将像素值归一化。
3.重写dataset
重写dataset需要我们自己定义三个属性,第一个是__init__,这里面包括了我们要输入进dataset的参数,比如img_path,label_path,img_transform, label_transform。
class RSDataset(Dataset):
def __init__(self, img_path, label_path, img_transform, label_transform):
self.img_path = img_path
self.label_path = label_path
self.img_transform = img_transform
self.label_transform = label_transform
第二个是__getitem__,这里我认为是最关键的一环,我们需要使用item这个索引来读取我们的数据(一定要用),在前面__init__已经将我们的img_path,label_path送入到了我们的类中,item就是在path中一个迭代的索引,我们无需在自己写for循环来历遍整个数据集,之后再讲图像返回到我们定义的img_transform,label_transform中。
def __getitem__(self, item):
img = cv2.imread(self.img_path[item],cv2.IMREAD_UNCHANGED)
label = cv2.imread(self.label_path[item], cv2.IMREAD_UNCHANGED)
return self.img_transform(img), self.label_transform(label)
第三个是__len__,这里我需要告诉程序运行到什么时候结束。
def __len__(self):
return len(self.img_path)
这样一个class Mydataset就重写完了
4.验证我们的读取的数据
可以使用pytorch中的Dataloader来读取Mydatase,然后看一下最后的输出
img_path = r'F:\data\GF2\1\img'
label_path = r'F:\data\GF2\1\label'
img_path = get_img_path(img_path)
label_path = get_img_path(label_path)
Dataloader = torch.utils.data.DataLoader(
RSDataset(img_path, label_path, img_transform, label_transform),
batch_size=2, shuffle=True, num_workers=0, pin_memory=True
)
for i, l in Dataloader:
print(i.shape, l.shape)
break
torch.Size([2, 4, 200, 200]) torch.Size([2, 200, 200])