mmsegmentation——RS_Inference

遥感影像批量预测与拼接


mmseg支持利用多线程进行遥感影像的滑动预测
在这里插入图片描述

函数调用示例

from mmseg.apis import init_model, inference_model,RSInferencer,RSImage
import torch
import time
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

config_path = r'config.py'
checkpoint_path = r'model.pth'
image_path = r'image.tif'
output = r'predict.tif'
'''
RSImage
Remote sensing image class.
Args:
     img (str or gdal.Dataset): Image file path or gdal.Dataset.
'''
data = RSImage(image_path)
batch_size = 8
model = init_model(config =config_path,checkpoint = checkpoint_path)
predicter = RSInferencer(model=model,batch_size=batch_size,thread=2)
# RSInferencer.from_config_path(config_path=config_path,checkpoint_path=checkpoint_path,thread=1,device=device)
start = time.time()
"""Run inference with multi-threading.
Args:
     image (RSImage): The image to inference.
     window_size (Tuple[int, int]): The size of the sliding window.
     strides (Tuple[int, int], optional): The stride of the sliding
         window. Defaults to (0, 0).
     output_path (Optional[str], optional): The path to save the
         segmentation map. Defaults to None.
"""
predicter.run(image = data,window_size = (384,384),strides = (128,128),output_path = output)
end = time.time()
all_time = end - start
print(f"The process cost : {all_time}")

- 注意 -
在这里插入图片描述

滑动预测优化(这部分代码存在问题)

遥感影像滑动预测过程,在边缘区域预测的准性较低,因此往往导致影像在拼接过程中出现影像接边问题;优化方式保留滑动预测结果与相邻影像中心区域,不准确的边缘区域舍弃。
在RS_Inference中,只需要为预测结果添加一个保留中心区域掩膜即可。
具体操作为在RSImage类中,重写write方法;我们将重写后的方法命名为write_mask

	def write_mask(self, data: Optional[np.ndarray], grid: Optional[List] = 	  None,mask_edge:Optional[List] = None):
        """Write image data.

        Args:
            grid (Optional[List], optional): Grid to write. Defaults to None.
            data (Optional[np.ndarray], optional): Data to write.
                Defaults to None.
            *grids:[
                    x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off,
                    x_crop_size, y_crop_size
                ]*
        Raises:
            ValueError: Either grid or data must be provided.
        """
        if grid is not None:
            assert len(grid) == 8, 'grid must be a list of 8 elements'
            for band in self.band_list:
                # print(f'{grid}\n')
                ###   根据格网参数设计掩膜  ###
                mask = np.zeros((grid[2],grid[3]))
                mask[grid[4]:grid[2]-grid[4],grid[5]:grid[3]-grid[5]] = 1
                data = data*mask
                band.WriteArray(
                    data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]],
                    grid[0] + grid[4], grid[1] + grid[5])
        elif data is not None:
            for i in range(self.channel):
                self.band_list[i].WriteArray(data[..., i])
        else:
            raise ValueError('Either grid or data must be provided.')
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云朵不吃雨

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值