pytorch中如何重写dataset-以tif图像数据为例

先上数据

img为GF二号影像,正常来讲有四个波段NIR和RGB,本文使用opencv读取图像,如果超过RGB三个波段的影像建议使用gdal。

label为一个0-15灰度值得图像,因为此数据是来做遥感的语义分割,所以讲label处理成这样,图像是黑色的是因为黑白影像的灰度值在计算机中一般取值为0-255,15的数值相对于255较小,所以全是黑色。

1.导入所需要的库

import os
import numpy as np
import cv2
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset 
 

 2.定义我们所需要的函数

使用opencv读取图像我们需要图像的路径,所以定义一个读取图像路径的函数get_img_path

def get_img_path(img_path):
    res = []
    for file_name in os.listdir(img_path):
        file_path = os.path.join(img_path, file_name)
        res.append(file_path)

    return res

之后我们使用transform模块定义图像和标签的预处理工作,在这里我仅仅将读取的图像转化成为了张量。

img_transform = transforms.Compose([transforms.ToTensor(),
                                    ])
class MaskToTensor(object):
    def __call__(self, mask):
        return torch.from_numpy(np.array(mask, dtype=np.int32)).long()


label_transform = MaskToTensor()

 因为我的标签是0-15(15为类别数)的灰度图像所以定义了一个MaskToTensor类来转换我的标签图像,如果直接使用transforms.ToTensor他会自动将像素值归一化。

3.重写dataset

重写dataset需要我们自己定义三个属性,第一个是__init__,这里面包括了我们要输入进dataset的参数,比如img_path,label_path,img_transform, label_transform。

class RSDataset(Dataset):
    def __init__(self, img_path, label_path, img_transform, label_transform):
        self.img_path = img_path
        self.label_path = label_path
        self.img_transform = img_transform
        self.label_transform = label_transform

第二个是__getitem__,这里我认为是最关键的一环,我们需要使用item这个索引来读取我们的数据(一定要用),在前面__init__已经将我们的img_path,label_path送入到了我们的类中,item就是在path中一个迭代的索引,我们无需在自己写for循环来历遍整个数据集,之后再讲图像返回到我们定义的img_transform,label_transform中。

    def __getitem__(self, item):
        img = cv2.imread(self.img_path[item],cv2.IMREAD_UNCHANGED)
        label = cv2.imread(self.label_path[item], cv2.IMREAD_UNCHANGED)

        return self.img_transform(img), self.label_transform(label)

第三个是__len__,这里我需要告诉程序运行到什么时候结束。

 def __len__(self):
        return len(self.img_path)

这样一个class Mydataset就重写完了

4.验证我们的读取的数据

可以使用pytorch中的Dataloader来读取Mydatase,然后看一下最后的输出

img_path = r'F:\data\GF2\1\img'
label_path = r'F:\data\GF2\1\label'
img_path = get_img_path(img_path)
label_path = get_img_path(label_path)
Dataloader = torch.utils.data.DataLoader(
    RSDataset(img_path, label_path, img_transform, label_transform),
    batch_size=2, shuffle=True, num_workers=0, pin_memory=True
)
for i, l in Dataloader:
    print(i.shape, l.shape)
    break


torch.Size([2, 4, 200, 200]) torch.Size([2, 200, 200])

  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Pytorch处理数据需要进行以下几个步骤: 1. 将文文本转换为数字序列,即进行分词和编码。可以使用jieba分词库对文文本进行分词,然后使用torchtext.vocab.Vocab将分词后的单词转换为数字。 2. 构建Dataset对象。可以使用torch.utils.data.Dataset来构建自己的数据集,需要实现__init__、__getitem__和__len__三个方法。 3. 将Dataset对象转换为DataLoader对象。可以使用torch.utils.data.DataLoaderDataset对象转换为DataLoader对象,以便进行批处理和数据增强等操作。 下面给出一个简单的文文本分的例子: ```python import jieba import torch from torch.utils.data import Dataset, DataLoader from torchtext.vocab import Vocab class ChineseTextDataset(Dataset): def __init__(self, data_path, vocab_path): self.data = [] self.vocab = Vocab.load(vocab_path) with open(data_path, "r", encoding="utf-8") as f: for line in f.readlines(): text, label = line.strip().split("\t") words = jieba.lcut(text) seq = torch.tensor([self.vocab.stoi[w] for w in words]) self.data.append((seq, int(label))) def __getitem__(self, idx): return self.data[idx] def __len__(self): return len(self.data) dataset = ChineseTextDataset("data.txt", "vocab.pkl") dataloader = DataLoader(dataset, batch_size=32, shuffle=True) ``` 其,data.txt是文文本和标签的数据文件,每行为一个样本,以tab分隔;vocab.pkl是使用torchtext.vocab.Vocab生成的词表文件。该例子使用jieba分词库对文文本进行分词,然后将分词后的单词转换为数字,并使用torch.utils.data.Dataset构建自己的数据集。最后,使用torch.utils.data.DataLoaderDataset对象转换为DataLoader对象,以便进行批处理和数据增强等操作。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值