pytorch中如何重写dataset-以tif图像数据为例

先上数据

img为GF二号影像,正常来讲有四个波段NIR和RGB,本文使用opencv读取图像,如果超过RGB三个波段的影像建议使用gdal。

label为一个0-15灰度值得图像,因为此数据是来做遥感的语义分割,所以讲label处理成这样,图像是黑色的是因为黑白影像的灰度值在计算机中一般取值为0-255,15的数值相对于255较小,所以全是黑色。

1.导入所需要的库

import os
import numpy as np
import cv2
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset 
 

 2.定义我们所需要的函数

使用opencv读取图像我们需要图像的路径,所以定义一个读取图像路径的函数get_img_path

def get_img_path(img_path):
    res = []
    for file_name in os.listdir(img_path):
        file_path = os.path.join(img_path, file_name)
        res.append(file_path)

    return res

之后我们使用transform模块定义图像和标签的预处理工作,在这里我仅仅将读取的图像转化成为了张量。

img_transform = transforms.Compose([transforms.ToTensor(),
                                    ])
class MaskToTensor(object):
    def __call__(self, mask):
        return torch.from_numpy(np.array(mask, dtype=np.int32)).long()


label_transform = MaskToTensor()

 因为我的标签是0-15(15为类别数)的灰度图像所以定义了一个MaskToTensor类来转换我的标签图像,如果直接使用transforms.ToTensor他会自动将像素值归一化。

3.重写dataset

重写dataset需要我们自己定义三个属性,第一个是__init__,这里面包括了我们要输入进dataset的参数,比如img_path,label_path,img_transform, label_transform。

class RSDataset(Dataset):
    def __init__(self, img_path, label_path, img_transform, label_transform):
        self.img_path = img_path
        self.label_path = label_path
        self.img_transform = img_transform
        self.label_transform = label_transform

第二个是__getitem__,这里我认为是最关键的一环,我们需要使用item这个索引来读取我们的数据(一定要用),在前面__init__已经将我们的img_path,label_path送入到了我们的类中,item就是在path中一个迭代的索引,我们无需在自己写for循环来历遍整个数据集,之后再讲图像返回到我们定义的img_transform,label_transform中。

    def __getitem__(self, item):
        img = cv2.imread(self.img_path[item],cv2.IMREAD_UNCHANGED)
        label = cv2.imread(self.label_path[item], cv2.IMREAD_UNCHANGED)

        return self.img_transform(img), self.label_transform(label)

第三个是__len__,这里我需要告诉程序运行到什么时候结束。

 def __len__(self):
        return len(self.img_path)

这样一个class Mydataset就重写完了

4.验证我们的读取的数据

可以使用pytorch中的Dataloader来读取Mydatase,然后看一下最后的输出

img_path = r'F:\data\GF2\1\img'
label_path = r'F:\data\GF2\1\label'
img_path = get_img_path(img_path)
label_path = get_img_path(label_path)
Dataloader = torch.utils.data.DataLoader(
    RSDataset(img_path, label_path, img_transform, label_transform),
    batch_size=2, shuffle=True, num_workers=0, pin_memory=True
)
for i, l in Dataloader:
    print(i.shape, l.shape)
    break


torch.Size([2, 4, 200, 200]) torch.Size([2, 200, 200])

  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值