python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化

30 篇文章 30 订阅

根据我前述博客中对传统图像分割算法及图像块合并方法的实验探究,在此将这些方法用于遥感影像并尝试矢量化。
这个过程中我自己遇到了一个棘手的问题,在最后的结果那里有描述,希望知道的朋友帮忙解答一下,谢谢!
直接上代码:

# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import os
import cv2
from osgeo import ogr, osr, gdal
import numpy as np
from PIL import Image
from skimage import morphology, color, measure
from skimage.segmentation import felzenszwalb, slic, quickshift
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float
from skimage.future import graph
from skimage import data,filters
import matplotlib.pyplot as plt
from skimage.morphology import disk

def read_img(filename):
    """Open a raster with GDAL and return its size, georeferencing and pixels.

    Parameters
    ----------
    filename : str
        Path handed to ``gdal.Open``.

    Returns
    -------
    tuple
        ``(im_width, im_height, im_proj, im_geotrans, im_data)``: the
        dataset's ``RasterXSize`` and ``RasterYSize``, the strings/tuples
        returned by ``GetProjection()`` and ``GetGeoTransform()``, and the
        array returned by ``ReadAsArray`` for the full raster extent.

    Raises
    ------
    RuntimeError
        If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    # Without GDAL exceptions enabled, a failed open yields None; fail here
    # with a clear message instead of an AttributeError on the next line.
    if dataset is None:
        raise RuntimeError("Cannot open raster %s" % filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Drop the reference so GDAL closes the file.
    del dataset
    return im_width, im_height, im_proj, im_geotrans, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a numpy array to a GeoTIFF with the given georeferencing.

    Parameters
    ----------
    filename : str
        Output path passed to the GTiff driver.
    im_proj : str
        Projection passed to ``SetProjection``.
    im_geotrans : sequence
        Geotransform passed to ``SetGeoTransform``.
    im_data : numpy.ndarray
        Either a 2-D array (written as one band) or a 3-D array whose
        shape is unpacked as ``(bands, height, width)``.

    Raises
    ------
    RuntimeError
        If the GTiff driver cannot create ``filename``.
    """
    # Exact dtype-name lookup. The previous substring test ('int16' in name)
    # sent signed int16 data to GDT_UInt16, corrupting negative values.
    # int8 keeps its historical GDT_Byte mapping; everything not listed is
    # written as Float32, as before.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
        planes = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        planes = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("Cannot create raster %s" % filename)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    # Iterating over planes also handles a (1, H, W) array, which used to be
    # handed to WriteArray as a 3-D block.
    for band_index, plane in enumerate(planes, start=1):
        dataset.GetRasterBand(band_index).WriteArray(plane)

    # Flush and drop the reference so the file is fully written and closed.
    dataset.FlushCache()
    del dataset

def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch ``bands`` between two percentiles into a target range.

    The ``lower_percent`` percentile maps to ``img_min`` and the
    ``higher_percent`` percentile maps to ``img_max``; results outside the
    target range are clamped to it.

    Parameters
    ----------
    bands : array_like
        Input values; percentiles are taken over the whole array.
    img_min, img_max : number
        Bounds of the output range.
    lower_percent, higher_percent : number
        Percentiles (0-100) that define the input range.

    Returns
    -------
    numpy.ndarray
        ``float32`` array with the same shape as ``bands``.
    """
    bands = np.asarray(bands)
    low = np.percentile(bands, lower_percent)
    high = np.percentile(bands, higher_percent)
    span = high - low

    if span == 0:
        # Degenerate input range (e.g. a constant band). Dividing by zero
        # used to yield NaN for values equal to the percentile; map those to
        # img_min, and keep the old saturation for values above/below it.
        stretched = np.where(bands > low, img_max, img_min)
    else:
        stretched = img_min + (bands - low) * (img_max - img_min) / span
        stretched[stretched < img_min] = img_min
        stretched[stretched > img_max] = img_max

    return stretched.astype(np.float32)

def DoesDriverHandleExtension(drv, ext):
    """Return True if ``ext`` is one of the extensions the driver declares.

    Parameters
    ----------
    drv : driver object
        Must provide ``GetMetadataItem``; its ``DMD_EXTENSIONS`` item is read.
    ext : str
        Extension without the leading dot; compared case-insensitively.

    Returns
    -------
    bool
        False when the driver declares no extensions or ``ext`` is not one
        of them.
    """
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    if exts is None:
        return False
    # DMD_EXTENSIONS is a whitespace-separated list, so compare whole tokens.
    # A plain substring search wrongly matched e.g. "js" against "json".
    return ext.lower() in exts.lower().split()


def GetExtension(filename):
    """Return the extension of ``filename`` without its leading dot.

    An empty string is returned when ``os.path.splitext`` finds no extension.
    """
    _, suffix = os.path.splitext(filename)
    return suffix[1:] if suffix.startswith('.') else suffix


def GetOutputDriversFor(filename):
    """List short names of writable vector drivers that match ``filename``.

    A driver is considered when it advertises DCAP_CREATE or DCAP_CREATECOPY
    together with DCAP_VECTOR. It matches either through the file extension
    or, failing that, through its connection prefix.
    """
    ext = GetExtension(filename)
    lowered_name = filename.lower()
    matches = []

    for index in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(index)

        can_write = (drv.GetMetadataItem(gdal.DCAP_CREATE) is not None
                     or drv.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not can_write or drv.GetMetadataItem(gdal.DCAP_VECTOR) is None:
            continue

        # Extension match wins; the prefix is only consulted otherwise.
        if ext and DoesDriverHandleExtension(drv, ext):
            matches.append(drv.ShortName)
            continue

        prefix = drv.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
        if prefix is not None and lowered_name.startswith(prefix.lower()):
            matches.append(drv.ShortName)

    return matches


def GetOutputDriverFor(filename):
    """Pick the output driver short name for ``filename``.

    Returns the first candidate from ``GetOutputDriversFor`` (printing a
    notice when several match). With no candidate, a name without extension
    falls back to 'ESRI Shapefile'; any other name raises ``Exception``.
    """
    candidates = GetOutputDriversFor(filename)
    ext = GetExtension(filename)

    if candidates:
        if len(candidates) > 1:
            print("Several drivers matching %s extension. Using %s" % (ext or '', candidates[0]))
        return candidates[0]

    if ext:
        raise Exception("Cannot guess driver for %s" % filename)
    return 'ESRI Shapefile'


def _weight_mean_color(graph, src, dst, n):
    """Weight callback used while merging nodes of a mean-color graph.

    Parameters
    ----------
    graph : RAG
        Graph whose nodes carry a ``'mean color'`` attribute.
    src, dst : int
        The vertices being merged; only ``dst`` is read here, and its mean
        color is expected to be up to date already.
    n : int
        A neighbor of ``src`` or ``dst`` or both.

    Returns
    -------
    data : dict
        ``{'weight': w}`` where ``w`` is the norm (``np.linalg.norm``) of the
        difference between the mean colors of ``dst`` and ``n``.
    """
    dst_mean = graph.nodes[dst]['mean color']
    neighbor_mean = graph.nodes[n]['mean color']
    return {'weight': np.linalg.norm(dst_mean - neighbor_mean)}


def merge_mean_color(graph, src, dst):
    """Pre-merge callback that folds the color statistics of ``src`` into ``dst``.

    Parameters
    ----------
    graph : RAG
        Graph whose nodes carry ``'total color'``, ``'pixel count'`` and
        ``'mean color'`` attributes.
    src, dst : int
        The vertices in ``graph`` to be merged; ``dst`` is updated in place.
    """
    merged = graph.nodes[dst]
    absorbed = graph.nodes[src]

    # Accumulate the sums first, then refresh the mean from them.
    merged['total color'] += absorbed['total color']
    merged['pixel count'] += absorbed['pixel count']
    merged['mean color'] = merged['total color'] / merged['pixel count']

def BetterMedianFilter(src_arr, k = 3, padding = None):
    """Replace local extremes of a 2-D array by the median of their window.

    For every interior pixel, the square window of side ``2 * edge + 1``
    (``edge = int((k - 1) / 2)``) centred on it is examined. If the pixel
    equals the window maximum or minimum it is replaced by the window
    median; otherwise it is kept. Border pixels are copied unchanged.

    Parameters
    ----------
    src_arr : numpy.ndarray
        2-D input array.
    k : int
        Nominal window size.
    padding : object
        Only the falsy (no padding) mode is implemented.

    Returns
    -------
    numpy.ndarray or None
        A ``uint16`` array of the same shape, or ``None`` when ``k`` is too
        large for the array or when ``padding`` is truthy.
        NOTE(review): because the output is uint16, input values outside
        0..65535 cannot be stored faithfully.
    """
    height, width = src_arr.shape

    if padding:
        # No padded variant exists; keep the historical result (None).
        return None

    edge = int((k-1)/2)
    if height - 1 - edge <= edge or width - 1 - edge <= edge:
        print("The parameter k is to large.")
        return None

    new_arr = np.zeros((height, width), dtype = "uint16")
    # First row/column index treated as bottom/right border. The column
    # bound previously used `height`, which mis-filtered non-square images.
    border_row = height - 1 - edge
    border_col = width - 1 - edge

    for i in range(height):
        for j in range(width):
            if i < edge or i >= border_row or j < edge or j >= border_col:
                new_arr[i, j] = src_arr[i, j]
                continue

            window = src_arr[i - edge:i + edge + 1, j - edge:j + edge + 1]
            center = src_arr[i, j]
            # Only pixels that are the extreme of their neighbourhood are
            # smoothed; everything else passes through untouched.
            if center == np.max(window) or center == np.min(window):
                new_arr[i, j] = np.median(window)
            else:
                new_arr[i, j] = center

    return new_arr
 

if __name__ == '__main__':
    # Demo pipeline: quickshift segmentation -> boundary rasters -> RAG merge
    # -> label rasters -> polygonize the merged label raster to a shapefile.
    img_path = "./temp/test2.tif"
    temp_path = "./temp/"
    im_width, im_height, im_proj, im_geotrans, im_data = read_img(img_path)
    # Only the first three bands are used for segmentation.
    im_data = im_data[0:3]
    # The axis order is reversed here and reversed back with the same
    # transpose before every write_img call below.
    temp = im_data.transpose((2, 1, 0))
    segments_quick = quickshift(temp, kernel_size=3, max_dist=6, ratio=0.5)

    # --- initial segmentation: boundaries drawn on the image -------------
    mark0 = mark_boundaries(temp, segments_quick)
    save_path = temp_path + "qs_seg0.tif"
    re0 = mark0.transpose((2, 1, 0))
    write_img(save_path, im_proj, im_geotrans, re0)

    # First channel of the boundary image, truncated to uint8, is the grid.
    grid_path = temp_path + "qs_grid0.tif"
    grid0 = np.uint8(re0[0, ...])
    write_img(grid_path, im_proj, im_geotrans, grid0)

    # Thin the grid to its skeleton and binarise it.
    skeleton = morphology.skeletonize(grid0)
    border0 = np.multiply(grid0, skeleton)
    ret, border0 = cv2.threshold(border0, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border0.tif"
    write_img(border_path, im_proj, im_geotrans, border0)

    # --- merge regions on the mean-color region adjacency graph ----------
    g = graph.rag_mean_color(temp, segments_quick)
    labels2 = graph.merge_hierarchical(segments_quick, g, thresh=5,
                                       rag_copy=False,
                                       in_place_merge=True,
                                       merge_func=merge_mean_color,
                                       weight_func=_weight_mean_color)
    label_rgb2 = color.label2rgb(labels2, temp, kind='avg')

    # Merged label raster; this is the file polygonized further down.
    rgb_path = temp_path + "qs_label.tif"
    lb = labels2.transpose((1, 0))
    write_img(rgb_path, im_proj, im_geotrans, lb)

    label_smooth = temp_path + "qs_label_smooth.tif"
    # lb = filters.median(lb, disk(5))
    lb = BetterMedianFilter(lb)
    write_img(label_smooth, im_proj, im_geotrans, lb)

    # --- merged segmentation: same boundary/grid/skeleton outputs --------
    mark = mark_boundaries(label_rgb2, labels2)
    save_path = temp_path + "qs_seg.tif"
    re = mark.transpose((2, 1, 0))
    write_img(save_path, im_proj, im_geotrans, re)

    grid_path = temp_path + "qs_grid.tif"
    grid = np.uint8(re[0, ...])
    write_img(grid_path, im_proj, im_geotrans, grid)

    skeleton = morphology.skeletonize(grid)
    border = np.multiply(grid, skeleton)
    ret, border = cv2.threshold(border, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border.tif"
    write_img(border_path, im_proj, im_geotrans, border)

    # out_shp = temp_path + "temp.shp"
    # RasterToLineshp(border_path, out_shp, 1)

    # --- polygonize the merged label raster ------------------------------
    # Despite the variable names, this opens the label raster (rgb_path).
    border_driver = gdal.Open(rgb_path)
    border_band = border_driver.GetRasterBand(1)
    border_mask = border_band.GetMaskBand()

    dst_filename = temp_path + 'temp.shp'
    frmt = GetOutputDriverFor(dst_filename)
    drv = ogr.GetDriverByName(frmt)
    # Remove the output of a previous run; otherwise creating the data
    # source fails on the second execution.
    if os.path.exists(dst_filename):
        drv.DeleteDataSource(dst_filename)
    dst_ds = drv.CreateDataSource(dst_filename)

    dst_layername = 'out'
    srs = osr.SpatialReference(wkt=border_driver.GetProjection())
    dst_layer = dst_ds.CreateLayer(dst_layername, geom_type=ogr.wkbPolygon, srs=srs)
    # dst_layer = dst_ds.CreateLayer(dst_layername, geom_type=ogr.wkbLineString, srs=srs)

    # Field index 0 ('DN') receives the pixel value of each polygon.
    dst_fieldname = 'DN'
    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
    dst_layer.CreateField(fd)
    dst_field = 0

    options = [""]
    options.append('DATASET_FOR_GEOREF=' + rgb_path)
    prog_func = gdal.TermProgress_nocb
    gdal.Polygonize(border_band, border_mask, dst_layer, dst_field, options,
                    callback=prog_func)

    # Release the handles that were actually opened above (the previous
    # cleanup reset unrelated names and left these open). Dropping dst_ds
    # is what flushes the shapefile to disk.
    dst_layer = None
    dst_ds = None
    border_mask = None
    border_band = None
    border_driver = None

# enum WKBGeometryType {
# wkbPoint = 1,
# wkbLineString = 2,
# wkbPolygon = 3,
# wkbTriangle = 17
# wkbMultiPoint = 4,
# wkbMultiLineString = 5,
# wkbMultiPolygon = 6,
# wkbGeometryCollection = 7,
# wkbPolyhedralSurface = 15,
# wkbTIN = 16
# wkbPointZ = 1001,
# wkbLineStringZ = 1002,
# wkbPolygonZ = 1003,
# wkbTrianglez = 1017
# wkbMultiPointZ = 1004,
# wkbMultiLineStringZ = 1005,
# wkbMultiPolygonZ = 1006,
# wkbGeometryCollectionZ = 1007,
# wkbPolyhedralSurfaceZ = 1015,
# wkbTINZ = 1016
# wkbPointM = 2001,
# wkbLineStringM = 2002,
# wkbPolygonM = 2003,
# wkbTriangleM = 2017
# wkbMultiPointM = 2004,
# wkbMultiLineStringM = 2005,
# wkbMultiPolygonM = 2006,
# wkbGeometryCollectionM = 2007,
# wkbPolyhedralSurfaceM = 2015,
# wkbTINM = 2016
# wkbPointZM = 3001,
# wkbLineStringZM = 3002,
# wkbPolygonZM = 3003,
# wkbTriangleZM = 3017
# wkbMultiPointZM = 3004,
# wkbMultiLineStringZM = 3005,
# wkbMultiPolygonZM = 3006,
# wkbGeometryCollectionZM = 3007,
# wkbPolyhedralSurfaceZM = 3015,
# wkbTinZM = 3016,
# }

对应的结果图如下:
原图:
原图
粗分割结果(代码中的qs_seg0.tif)
粗分割结果
粗分割格网(代码中的qs_grid0.tif)
粗分割格网
粗分割格网骨架(代码中的qs_border0.tif),格网的结果不是单线的,这里取了中心线。
粗分割格网骨架
合并后的分割结果(代码中的qs_seg.tif):
合并后的粗分割结果
合并后的格网结果(代码中的qs_grid.tif)
合并后的格网结果
合并后的格网骨架结果(代码中的qs_border.tif):
合并后的格网骨架结果
下面是矢量化以后的最终结果,这是代码中的qs_label.tif经过矢量化以后得到的结果,这里说明一下,之所以不用栅格线来直接转矢量线是因为我在GDAL里面并没有找到直接转化的方法,目前的方法强行转的话只能得到双线,完全不对,找了很久也没找到解决办法只能折中一下先得到面了,后面再面转线,看到的朋友如果知道的话烦请告知一下用什么办法可以直接把栅格线转为矢量线,要求脱离arcgis哈。
矢量化以后的结果

TO DO:
1.矢量面转线
2.线简化
3.线平滑
做完更新,感兴趣的朋友可以关注一下。

后续:
目前矢量面转矢量线肯定是没问题的,但是有个大问题就是矢量线的平滑对我来说还有一定难度,想不到具体高效的方式,唯一想到的方式就是将图层里的每一个节点找到,在节点位置不变的情况下取出节点之间的线条逐个平滑再放回到图层中,这样做有点慢,并且实现起来也比较复杂感觉,所以再次折中,我直接进行面的平滑,平滑完了再转线看看有没有可能对结果有帮助。
虽然不做线平滑了,下面还是先给出面转线的代码:

# -*- coding: utf-8 -*-
import os
import gdal
from osgeo import ogr,osr
import numpy as np

def Test_Poly2Line(input_poly,output_line):
    """Convert the polygons of a shapefile into a line shapefile.

    The first sub-geometry (``GetGeometryRef(0)``) of every feature geometry
    in layer 0 of ``input_poly`` is collected into a single geometry
    collection, which is written as one feature of a multi-line layer in
    ``output_line``. An existing ``output_line`` is deleted first.

    Parameters
    ----------
    input_poly : str
        Path of the source polygon shapefile.
    output_line : str
        Path of the line shapefile to create.

    Raises
    ------
    RuntimeError
        If the source cannot be opened or the output cannot be created.
    """
    ogr.RegisterAll()

    driver = ogr.GetDriverByName('ESRI Shapefile')
    # The source is only read, so open it read-only.
    source_ds = driver.Open(input_poly, 0)
    if source_ds is None:
        raise RuntimeError("Cannot open %s" % input_poly)
    source_layer = source_ds.GetLayer(0)

    # polygon2geometryCollection
    geomcol = ogr.Geometry(ogr.wkbGeometryCollection)
    for feat in source_layer:
        geom = feat.GetGeometryRef()
        if geom is None:
            # Features without geometry contribute nothing.
            continue
        ring = geom.GetGeometryRef(0)
        if ring is not None:
            geomcol.AddGeometry(ring)

    # geometryCollection2shp
    if os.path.exists(output_line):
        driver.DeleteDataSource(output_line)
    outDataSource = driver.CreateDataSource(output_line)
    if outDataSource is None:
        raise RuntimeError("Cannot create %s" % output_line)
    # Carry the source spatial reference over so the output stays
    # georeferenced (previously no SRS was written).
    outLayer = outDataSource.CreateLayer(output_line,
                                         srs=source_layer.GetSpatialRef(),
                                         geom_type=ogr.wkbMultiLineString)
    featureDefn = outLayer.GetLayerDefn()
    outFeature = ogr.Feature(featureDefn)
    outFeature.SetGeometry(geomcol)
    outLayer.CreateFeature(outFeature)

    # Release in dependency order so the output is flushed and both
    # datasets are closed deterministically.
    outFeature = None
    outLayer = None
    outDataSource = None
    source_layer = None
    source_ds = None


if __name__ == "__main__":
    # Convert the polygon shapefile produced earlier into a line shapefile.
    Test_Poly2Line("E:/geo_test/temp/temp.shp",
                   "E:/geo_test/temp/temp2line.shp")

结果如下,可以看到这个结果和面完全保持一致,毕竟是gdal源码哈哈。
面转线

下面说一下在面未转为线的时候就平滑,在下面的位置加入了中值滤波,红线是针对8bit图的,一般不用,直接用打开的这个BetterMedianFilter就行了,参考链接https://blog.csdn.net/baidu_41902768/article/details/94451787
滤波

这是栅格面平滑后转化为面矢量的结果
面平滑
这是和之前没有进行平滑的结果的叠加对比,变化是有的,但是这里有一个大问题,就是锯齿状太严重。
对比

  • 12
    点赞
  • 69
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 22
    评论
对于Python遥感影像重采样,可以使用GDAL(Geospatial Data Abstraction Library)库来实现。GDAL是一个开源的地理信息系统(GIS)库,它提供了许多用于处理栅格数据的功能,包括重采样。

下面是一个简单的示例代码,演示如何使用GDAL库进行遥感影像重采样:

```python
from osgeo import gdal

def resample_image(input_path, output_path, pixel_size):
    # 打开输入影像
    input_ds = gdal.Open(input_path)
    # 获取输入影像的投影和仿射变换参数
    projection = input_ds.GetProjection()
    geotransform = input_ds.GetGeoTransform()
    # 获取输入影像的宽度和高度
    width = input_ds.RasterXSize
    height = input_ds.RasterYSize
    # 计算重采样后的宽度和高度
    new_width = int(width / pixel_size)
    new_height = int(height / pixel_size)
    # 创建输出影像
    driver = gdal.GetDriverByName('GTiff')
    output_ds = driver.Create(output_path, new_width, new_height, 1, gdal.GDT_Float32)
    # 设置输出影像的投影和仿射变换参数
    output_ds.SetProjection(projection)
    output_ds.SetGeoTransform((geotransform[0], pixel_size, 0, geotransform[3], 0, -pixel_size))
    # 执行重采样
    gdal.ReprojectImage(input_ds, output_ds, None, None, gdal.GRA_Bilinear)
    # 关闭数据集
    input_ds = None
    output_ds = None

# 使用示例
input_path = 'input_image.tif'
output_path = 'resampled_image.tif'
pixel_size = 10  # 新的像素大小(单位:米)
resample_image(input_path, output_path, pixel_size)
```

在上面的示例中,`input_path`是输入影像的路径,`output_path`是重采样后的输出影像的路径,`pixel_size`是新的像素大小,用于指定重采样后每个像素的大小(单位:米)。代码将使用双线性插值进行重采样操作,并将结果保存为GeoTIFF格式的影像文件。

请注意,执行此代码需要安装GDAL库。你可以使用pip安装它:`pip install gdal`。

希望这个示例对你有帮助!如果你有任何其他问题,请随时提问。
评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

如雾如电

随缘

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值