[PaddleSeg 源码阅读] PaddleSeg 自定义数据类

先感受一下怎么改:

Config 文件中的item与类的 __init__.py 参数一一对应

在 PaddleSeg 的 Config 文件中:

val_dataset:
  type: Dataset
  dataset_root: /root/share/program/save/data/portraint
  val_path: /root/share/program/save/data/portraint/txtfiles/test.txt
  num_classes: 2
  transforms:
    - type: Resize
      target_size: [224, 224]
    - type: Normalize
  mode: val

上边的 Config 参数就是传入 Dataset 类的 __init__ 中的参数

    def __init__(self,
                 transforms,
                 dataset_root,
                 num_classes,
                 mode='train',
                 train_path=None,
                 val_path=None,
                 test_path=None,
                 separator=' ',
                 ignore_index=255,
                 edge=False):
        self.dataset_root = dataset_root
        self.transforms = Compose(transforms)
        self.file_list = list()
        self.mode = mode.lower()
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.edge = edge
        
		......

如果这个 Dataset 不符合你的需求,则可以自定义一类,如 XXDataset

则将 Config 文件中的 Dataset type 部分改成 你自定义的类 XXDataset

该类需要在 paddleseg/datasets/__init__.py 中导入,可以直接在 __init__.py 中定义,也可以新建一个文件,然后在 __init__.py 中导入即可.

注意,在自定义数据类的开始,要放上这个装饰器,(要这样注册一下?)

@manager.DATASETS.add_component

(啊这,我我记得之前不用啊,或者说之前 PaddleClas 不用?

举个例子:

paddleseg/datasets/dataset.py 中:
在这里插入图片描述

config 文件:
在这里插入图片描述

__init__.py 中导入:
在这里插入图片描述

然后这里要注意一下:
在这里插入图片描述
如果不定义这两个属性,这里会报错,transforms 属性,按理说应该是参数传进来的,但是我这里是直接预处理的图片,所以给了 None 一般来说要传入的

所以我这里还得改一个地方,反正哪里报错,就改哪里就行

在这里插入图片描述

这里再补充一下,一般情况下,用传进来的 Transform 参数做预处理,不要像我这样:

class myDataset(Dataset):
    def __init__(self, data_root, 
                        img_name="images", 
                        ann_name="annotations"):
        
        self.img_path = os.path.join(data_root, img_name)
        self.ann_path = os.path.join(data_root, ann_name)
        
        # 务必传入相对路径,绝对路径 OpenCV 读不了
        self.imgs = [os.path.join(self.img_path, img) for img in os.listdir(self.img_path)]
        
        self.anns = [img.replace(".jpg", ".png") for img in self.imgs]
        self.anns = [img.replace(img_name, ann_name) for img in self.anns]
        

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, idx):
        img = cv2.imread(self.imgs[idx])
        ann = cv2.imread(self.anns[idx], 0)
        
        img = self.img_preprocess(img)
        ann = self.ann_preprocess(ann)
        
        return img, ann
    
    @staticmethod
    def img_preprocess(img):
        img = cv2.resize(img, (224, 224)).astype(np.float32)
        img = img.transpose(2, 0, 1)
        img = img[None]
        img /= 255
        return img
    
    @staticmethod
    def ann_preprocess(ann):
        ann = cv2.resize(ann, (224, 224))
        return ann

我这里专门定义了 self.img_preprocessself.ann_preprocess 去读取图片,并做预处理.

所以我上边的代码,会把 self.transformNone,然后传入的参数给 []

这样也行,反正也懒得在 transforms 中注册一个新的 transform

但是,PaddleSeg 在 transforms 中做了一点儿偷懒的工作,会帮咱自动读入图片和Label

在这里插入图片描述
来看一下,到底是 transforms 中哪个类会读取图片:

在 Compose 的 __call__ 函数中:
在这里插入图片描述
如果传入的参数是 str 则其会读入对应的图

(注意,这里读进来的图片是 0-255 的,是uint8的)

(注意,for循环下的操作,是将所有的Transform做完,然后才进行的 Transpose)
也就是说,在所有的 Transform 之后,其输出的是 CHW, 也就是通道数在最前面
另外,最后一步的转置是 np 的操作,也就是说 PaddleSeg 的 Transform 大部分是基于 numpy 实现的(只有读取图片时,会用到 PIL.Image 类)


最后看看到底有哪些装饰器:
在这里插入图片描述

manager.MODELS
manager.BACKBONES
manager.DATASETS
manager.TRANSFORMS
manager.LOSSES

大概,后边加上 .add_component 后即可添加新对象

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值