遥感影像批量预测与拼接
mmseg支持利用多线程进行遥感影像的滑动预测
函数调用示例
from mmseg.apis import init_model, inference_model,RSInferencer,RSImage
import torch
import time
# Usage example: sliding-window inference over one remote sensing image.

# Fall back to the CPU when no CUDA device is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

config_path = r'config.py'
checkpoint_path = r'model.pth'
image_path = r'image.tif'
output = r'predict.tif'

# RSImage: remote sensing image class.
# Args:
#     img (str or gdal.Dataset): Image file path or gdal.Dataset.
data = RSImage(image_path)
batch_size = 8

# Pass the selected device explicitly. Previously `device` was computed but
# never used, so the CPU fallback above had no effect on model loading.
model = init_model(config=config_path, checkpoint=checkpoint_path,
                   device=device)
predicter = RSInferencer(model=model, batch_size=batch_size, thread=2)
# Alternative: build the inferencer directly from the config/checkpoint paths.
# RSInferencer.from_config_path(config_path=config_path,checkpoint_path=checkpoint_path,thread=1,device=device)

# perf_counter is monotonic, which makes it the right clock for elapsed time.
start = time.perf_counter()
# RSInferencer.run: run inference with multi-threading.
# Args:
#     image (RSImage): The image to inference.
#     window_size (Tuple[int, int]): The size of the sliding window.
#     strides (Tuple[int, int], optional): The stride of the sliding
#         window. Defaults to (0, 0).
#     output_path (Optional[str], optional): The path to save the
#         segmentation map. Defaults to None.
predicter.run(image=data, window_size=(384, 384), strides=(128, 128),
              output_path=output)
end = time.perf_counter()
all_time = end - start
print(f"The process cost : {all_time}")
- 注意 -
滑动预测优化(这部分代码存在问题)
在遥感影像的滑动预测过程中，每个窗口边缘区域的预测准确性较低，因此影像在拼接时往往会出现接边问题；优化方式是只保留每个滑动窗口预测结果的中心区域，舍弃不准确的边缘区域，相邻窗口的中心区域拼接后即可覆盖整幅影像。
在RS_Inference中，只需要为预测结果添加一个仅保留中心区域的掩膜即可。
具体操作为在RSImage类中,重写write方法;我们将重写后的方法命名为write_mask
def write_mask(self, data: Optional[np.ndarray], grid: Optional[List] = None,
               mask_edge: Optional[List] = None):
    """Write image data, zeroing the border of a tile before writing it.

    When ``grid`` is given, ``data`` is treated as one predicted tile laid
    out as ``(rows, cols) == (y_size, x_size)``. A binary mask keeps the
    region ``[y_crop_off:y_size - y_crop_off, x_crop_off:x_size - x_crop_off]``
    and zeroes everything else; the crop window described by ``grid`` is then
    written to every band in ``self.band_list``.

    When ``grid`` is ``None``, channel ``i`` of ``data`` is written to band
    ``i`` without any masking.

    Args:
        data (Optional[np.ndarray]): Data to write.
        grid (Optional[List], optional): Grid to write, laid out as
            ``[x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off,
            x_crop_size, y_crop_size]``. Defaults to None.
        mask_edge (Optional[List], optional): Accepted for interface
            compatibility; currently not used. Defaults to None.

    Raises:
        ValueError: If ``data`` is None (with or without ``grid``).
    """
    if grid is not None:
        assert len(grid) == 8, 'grid must be a list of 8 elements'
        if data is None:
            raise ValueError('data must be provided when grid is given.')
        (x_offset, y_offset, x_size, y_size,
         x_crop_off, y_crop_off, x_crop_size, y_crop_size) = grid

        # NumPy indexes arrays as (row, col) == (y, x), matching the slice
        # passed to WriteArray below. The mask therefore has to be built as
        # (y_size, x_size); building it as (x_size, y_size) breaks on
        # non-square windows and masks the wrong axis otherwise.
        # Using data.dtype keeps the product in the tile's own dtype.
        mask = np.zeros((y_size, x_size), dtype=data.dtype)
        mask[y_crop_off:y_size - y_crop_off,
             x_crop_off:x_size - x_crop_off] = 1

        # NOTE(review): the mask is symmetric around the tile centre, so any
        # part of the crop window that extends past it is written as 0 --
        # confirm this is intended for tiles touching the image border.
        tile = (data * mask)[y_crop_off:y_crop_off + y_crop_size,
                             x_crop_off:x_crop_off + x_crop_size]

        # The masked tile is loop-invariant: compute once, write to each band.
        for band in self.band_list:
            band.WriteArray(tile, x_offset + x_crop_off,
                            y_offset + y_crop_off)
    elif data is not None:
        for i in range(self.channel):
            self.band_list[i].WriteArray(data[..., i])
    else:
        raise ValueError('Either grid or data must be provided.')