根据我之前几篇博客中对图像分割算法及图像块合并方法的实验探究,这里将这些方法应用于遥感影像,并尝试对结果进行矢量化。
在这个过程中我遇到了一个棘手的问题,具体描述见文末的结果部分,希望了解的朋友帮忙解答一下,谢谢!
直接上代码:
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import os
import cv2
from osgeo import ogr, osr, gdal
import numpy as np
from PIL import Image
from skimage import morphology, color, measure
from skimage.segmentation import felzenszwalb, slic, quickshift
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float
from skimage.future import graph
from skimage import data,filters
import matplotlib.pyplot as plt
from skimage.morphology import disk
def read_img(filename):
    """Read a whole raster file with GDAL.

    Parameters
    ----------
    filename : str
        Path of a raster that GDAL can open.

    Returns
    -------
    tuple
        ``(width, height, projection_wkt, geotransform, data)`` where ``data``
        is the array returned by ``Dataset.ReadAsArray`` for the full extent
        (``(rows, cols)`` for one band, ``(bands, rows, cols)`` otherwise).

    Raises
    ------
    IOError
        If GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    # Without gdal.UseExceptions(), gdal.Open signals failure by returning
    # None; fail here with a clear message instead of an AttributeError below.
    if dataset is None:
        raise IOError("Cannot open raster: %s" % filename)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    # Dropping the last reference closes the dataset.
    del dataset
    return im_width, im_height, im_proj, im_geotrans, im_data
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a 2-D or 3-D array to a GeoTIFF file.

    Parameters
    ----------
    filename : str
        Output path; the GTiff driver is always used.
    im_proj : str
        Projection WKT passed to ``SetProjection``.
    im_geotrans : sequence
        Geotransform passed to ``SetGeoTransform`` (as returned by
        ``Dataset.GetGeoTransform``).
    im_data : numpy.ndarray
        ``(rows, cols)`` for a single band or ``(bands, rows, cols)``.

    Raises
    ------
    IOError
        If the GTiff driver cannot create ``filename``.

    Notes
    -----
    Data type mapping: ``int8``/``uint8`` -> Byte, ``int16`` -> Int16,
    ``uint16`` -> UInt16, anything else (int32/int64 labels, bool, floats)
    -> Float32.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        # Covers int8 and uint8; signed int8 values outside 0..127 would wrap.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        # Signed data needs a signed band type, otherwise negatives wrap.
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError("Cannot create raster: %s" % filename)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_data.ndim == 3:
        # Index per band even when there is only one band, so a
        # (1, rows, cols) array is written as a 2-D band.
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    else:
        dataset.GetRasterBand(1).WriteArray(im_data)
    # Dropping the last reference flushes and closes the file.
    del dataset
def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly rescale an array into ``[img_min, img_max]`` using percentiles.

    The ``lower_percent`` percentile of ``bands`` maps to ``img_min`` and the
    ``higher_percent`` percentile maps to ``img_max``; values outside that
    window are saturated to the range ends.

    Parameters
    ----------
    bands : array_like
        Input values of any shape (2-D in the original use).
    img_min, img_max : float
        Output range.
    lower_percent, higher_percent : float
        Percentiles (0-100) defining the input window.

    Returns
    -------
    numpy.ndarray
        float32 array with the same shape as ``bands``.
    """
    data = np.asarray(bands)
    low = np.percentile(data, lower_percent)
    high = np.percentile(data, higher_percent)
    if high == low:
        # Degenerate window (e.g. a constant image): the linear map would
        # divide by zero. Saturate like the general case instead of
        # producing NaN: values above the window -> img_max, rest -> img_min.
        return np.where(data > high, img_max, img_min).astype(np.float32)
    scaled = img_min + (data - low) * (img_max - img_min) / (high - low)
    # Lower bound first, then upper bound (same order as the masking it replaces).
    scaled = np.minimum(np.maximum(scaled, img_min), img_max)
    return scaled.astype(np.float32)
def DoesDriverHandleExtension(drv, ext):
    """Return True if GDAL/OGR driver ``drv`` lists ``ext`` as an extension.

    Parameters
    ----------
    drv : gdal.Driver
        Driver whose ``DMD_EXTENSIONS`` metadata is inspected.
    ext : str
        Extension without the leading dot; compared case-insensitively.
    """
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    if exts is None:
        return False
    # DMD_EXTENSIONS is a space-separated list. Compare whole tokens so that,
    # e.g., "sql" does not falsely match a driver listing "sqlite".
    return ext.lower() in exts.lower().split()
def GetExtension(filename):
    """Return the extension of ``filename`` without its leading dot.

    An empty string is returned when the name has no extension.
    """
    _, suffix = os.path.splitext(filename)
    # splitext yields either '' or a suffix beginning with '.'.
    return suffix[1:] if suffix.startswith('.') else suffix
def GetOutputDriversFor(filename):
    """List short names of vector drivers able to create ``filename``.

    A driver qualifies when it can create (directly or by copy) vector
    datasets and either handles the file's extension or declares a
    connection prefix that ``filename`` starts with (case-insensitive).
    """
    ext = GetExtension(filename)
    lowered_name = filename.lower()
    matches = []
    for index in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(index)
        can_create = (drv.GetMetadataItem(gdal.DCAP_CREATE) is not None
                      or drv.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not can_create or drv.GetMetadataItem(gdal.DCAP_VECTOR) is None:
            continue
        if ext and DoesDriverHandleExtension(drv, ext):
            matches.append(drv.ShortName)
            continue
        # Fall back to connection-string style names (e.g. "PG:...").
        prefix = drv.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
        if prefix is not None and lowered_name.startswith(prefix.lower()):
            matches.append(drv.ShortName)
    return matches
def GetOutputDriverFor(filename):
    """Choose one OGR driver short name for writing ``filename``.

    Returns the first matching driver (printing a notice when several
    match). With no match, an extension-less name defaults to
    ``'ESRI Shapefile'``, while an unknown extension raises ``Exception``.
    """
    candidates = GetOutputDriversFor(filename)
    ext = GetExtension(filename)
    if candidates:
        if len(candidates) > 1:
            print("Several drivers matching %s extension. Using %s" % (ext if ext else '', candidates[0]))
        return candidates[0]
    if ext:
        raise Exception("Cannot guess driver for %s" % filename)
    return 'ESRI Shapefile'
def _weight_mean_color(graph, src, dst, n):
"""Callback to handle merging nodes by recomputing mean color.
The method expects that the mean color of `dst` is already computed.
Parameters
----------
graph : RAG
The graph under consideration.
src, dst : int
The vertices in `graph` to be merged.
n : int
A neighbor of `src` or `dst` or both.
Returns
-------
data : dict
A dictionary with the `"weight"` attribute set as the absolute
difference of the mean color between node `dst` and `n`.
"""
diff = graph.nodes[dst]['mean color'] - graph.nodes[n]['mean color']
diff = np.linalg.norm(diff)
return {'weight': diff}
def merge_mean_color(graph, src, dst):
    """Merge callback for ``merge_hierarchical``: fold ``src`` into ``dst``.

    Accumulates ``src``'s ``'total color'`` and ``'pixel count'`` into
    ``dst`` and recomputes ``dst``'s ``'mean color'`` from them.

    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in ``graph`` to be merged.
    """
    target = graph.nodes[dst]
    source = graph.nodes[src]
    target['total color'] += source['total color']
    target['pixel count'] += source['pixel count']
    target['mean color'] = target['total color'] / target['pixel count']
def BetterMedianFilter(src_arr, k = 3, padding = None):
    """Selective median filter for a 2-D array.

    A pixel is replaced by the median of its window only when it equals the
    window's maximum or minimum; all other pixels, and pixels whose full
    window would extend past the array edge, are copied unchanged.

    Parameters
    ----------
    src_arr : numpy.ndarray
        2-D input array (e.g. a label raster).
    k : int
        Window size. The half-width is ``int((k - 1) / 2)``, so the window
        side is always odd (an even ``k`` behaves like ``k - 1``).
    padding : optional
        Not implemented; a truthy value raises ``NotImplementedError``.

    Returns
    -------
    numpy.ndarray or None
        Filtered array with the same shape and dtype as ``src_arr``, or
        ``None`` (after printing a message) when either dimension is not
        larger than the window side.
    """
    height, width = src_arr.shape
    if padding:
        raise NotImplementedError("padding is not supported")
    edge = int((k - 1) / 2)
    if height - 1 - edge <= edge or width - 1 - edge <= edge:
        print("The parameter k is to large.")
        return None
    # Border pixels are kept as-is. copy() also preserves the input dtype,
    # so label values above 65535 are no longer truncated to uint16.
    new_arr = src_arr.copy()
    # Interior pixels are those whose whole window fits inside the array
    # (rows edge..height-1-edge, cols edge..width-1-edge inclusive).
    for i in range(edge, height - edge):
        for j in range(edge, width - edge):
            centre = src_arr[i, j]
            window = src_arr[i - edge:i + edge + 1, j - edge:j + edge + 1]
            if centre == window.max() or centre == window.min():
                # The window holds an odd number of pixels, so the median
                # is one of its values and converts back exactly.
                new_arr[i, j] = np.median(window)
    return new_arr
if __name__ == '__main__':
    img_path = "./temp/test2.tif"
    temp_path = "./temp/"
    im_width, im_height, im_proj, im_geotrans, im_data = read_img(img_path)
    # Only the first three bands are segmented.
    im_data = im_data[0:3]
    # (bands, rows, cols) -> (cols, rows, bands): channels last for skimage.
    # Rows and cols are swapped as well, so every array written back below is
    # transposed with the same permutation to restore GDAL's layout.
    temp = im_data.transpose((2, 1, 0))

    # ---- Step 1: coarse over-segmentation (quickshift) ----
    segments_quick = quickshift(temp, kernel_size=3, max_dist=6, ratio=0.5)
    mark0 = mark_boundaries(temp, segments_quick)
    save_path = temp_path + "qs_seg0.tif"
    re0 = mark0.transpose((2, 1, 0))
    write_img(save_path, im_proj, im_geotrans, re0)
    # Boundary grid taken from band 0 of the marked image.
    # NOTE(review): assumes mark_boundaries returns a [0, 1] float image whose
    # boundary pixels are 1 in band 0, so uint8 truncation keeps the boundaries.
    grid_path = temp_path + "qs_grid0.tif"
    grid0 = np.uint8(re0[0, ...])
    write_img(grid_path, im_proj, im_geotrans, grid0)
    # The grid lines are not single-pixel wide; keep only their centre line.
    skeleton = morphology.skeletonize(grid0)
    border0 = np.multiply(grid0, skeleton)
    _, border0 = cv2.threshold(border0, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border0.tif"
    write_img(border_path, im_proj, im_geotrans, border0)

    # ---- Step 2: merge similar adjacent segments on a region adjacency graph ----
    g = graph.rag_mean_color(temp, segments_quick)
    labels2 = graph.merge_hierarchical(segments_quick, g, thresh=5,
                                       rag_copy=False,
                                       in_place_merge=True,
                                       merge_func=merge_mean_color,
                                       weight_func=_weight_mean_color)
    label_rgb2 = color.label2rgb(labels2, temp, kind='avg')
    rgb_path = temp_path + "qs_label.tif"
    lb = labels2.transpose((1, 0))
    write_img(rgb_path, im_proj, im_geotrans, lb)

    # Smoothed copy of the label raster (written, not polygonized below).
    label_smooth = temp_path + "qs_label_smooth.tif"
    # lb = filters.median(lb, disk(5))
    lb_smooth = BetterMedianFilter(lb)
    # BetterMedianFilter returns None when the image is too small for its window.
    if lb_smooth is not None:
        write_img(label_smooth, im_proj, im_geotrans, lb_smooth)

    mark = mark_boundaries(label_rgb2, labels2)
    save_path = temp_path + "qs_seg.tif"
    re = mark.transpose((2, 1, 0))
    write_img(save_path, im_proj, im_geotrans, re)
    grid_path = temp_path + "qs_grid.tif"
    grid = np.uint8(re[0, ...])
    write_img(grid_path, im_proj, im_geotrans, grid)
    skeleton = morphology.skeletonize(grid)
    border = np.multiply(grid, skeleton)
    _, border = cv2.threshold(border, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border.tif"
    write_img(border_path, im_proj, im_geotrans, border)
    # out_shp = temp_path + "temp.shp"
    # RasterToLineshp(border_path, out_shp, 1)  # helper not defined in this file

    # ---- Step 3: polygonize the merged label raster (qs_label.tif) ----
    # Converting the border raster straight to lines gave double lines, so
    # regions are polygonized instead.
    label_ds = gdal.Open(rgb_path)
    if label_ds is None:
        raise IOError("Cannot open %s" % rgb_path)
    label_band = label_ds.GetRasterBand(1)
    label_mask = label_band.GetMaskBand()
    dst_filename = temp_path + 'temp.shp'
    frmt = GetOutputDriverFor(dst_filename)
    drv = ogr.GetDriverByName(frmt)
    # Start every run from a clean output file.
    if os.path.exists(dst_filename):
        drv.DeleteDataSource(dst_filename)
    dst_ds = drv.CreateDataSource(dst_filename)
    if dst_ds is None:
        raise IOError("Cannot create %s" % dst_filename)
    srs = osr.SpatialReference(wkt=label_ds.GetProjection())
    dst_layer = dst_ds.CreateLayer('out', geom_type=ogr.wkbPolygon, srs=srs)
    # dst_layer = dst_ds.CreateLayer('out', geom_type=ogr.wkbLineString, srs=srs)
    dst_layer.CreateField(ogr.FieldDefn('DN', ogr.OFTInteger))
    # Field index 0 ('DN') receives the pixel value (segment label) of each polygon.
    options = ['DATASET_FOR_GEOREF=' + rgb_path]
    gdal.Polygonize(label_band, label_mask, dst_layer, 0, options,
                    callback=gdal.TermProgress_nocb)
    # Release all references so GDAL/OGR flush and close the files.
    dst_layer = None
    dst_ds = None
    label_mask = None
    label_band = None
    label_ds = None
# enum WKBGeometryType {
#     wkbPoint = 1,
#     wkbLineString = 2,
#     wkbPolygon = 3,
#     wkbTriangle = 17,
#     wkbMultiPoint = 4,
#     wkbMultiLineString = 5,
#     wkbMultiPolygon = 6,
#     wkbGeometryCollection = 7,
#     wkbPolyhedralSurface = 15,
#     wkbTIN = 16,
#     wkbPointZ = 1001,
#     wkbLineStringZ = 1002,
#     wkbPolygonZ = 1003,
#     wkbTriangleZ = 1017,
#     wkbMultiPointZ = 1004,
#     wkbMultiLineStringZ = 1005,
#     wkbMultiPolygonZ = 1006,
#     wkbGeometryCollectionZ = 1007,
#     wkbPolyhedralSurfaceZ = 1015,
#     wkbTINZ = 1016,
#     wkbPointM = 2001,
#     wkbLineStringM = 2002,
#     wkbPolygonM = 2003,
#     wkbTriangleM = 2017,
#     wkbMultiPointM = 2004,
#     wkbMultiLineStringM = 2005,
#     wkbMultiPolygonM = 2006,
#     wkbGeometryCollectionM = 2007,
#     wkbPolyhedralSurfaceM = 2015,
#     wkbTINM = 2016,
#     wkbPointZM = 3001,
#     wkbLineStringZM = 3002,
#     wkbPolygonZM = 3003,
#     wkbTriangleZM = 3017,
#     wkbMultiPointZM = 3004,
#     wkbMultiLineStringZM = 3005,
#     wkbMultiPolygonZM = 3006,
#     wkbGeometryCollectionZM = 3007,
#     wkbPolyhedralSurfaceZM = 3015,
#     wkbTINZM = 3016
# }
对应的结果图如下:
原图:
粗分割结果(代码中的qs_seg0.tif)
粗分割格网(代码中的qs_grid0.tif)
粗分割格网骨架(代码中的qs_border0.tif),格网的结果不是单线的,这里取了中心线。
合并后的分割结果(代码中的qs_seg.tif):
合并后的格网结果(代码中的qs_grid.tif)
合并后的格网骨架结果(代码中的qs_border.tif):
下面是矢量化后的最终结果,即代码中的qs_label.tif经过矢量化后得到的结果。这里说明一下:之所以不直接把栅格线转为矢量线,是因为我在GDAL中没有找到直接转换的方法,用现有方法强行转换只能得到双线,完全不对。找了很久也没找到解决办法,只能折中一下,先得到面,之后再把面转为线。看到的朋友如果知道用什么办法可以直接把栅格线转为矢量线(要求脱离ArcGIS),烦请告知。
TO DO:
1.矢量面转线
2.线简化
3.线平滑
做完更新,感兴趣的朋友可以关注一下。
后续:
目前矢量面转矢量线已经没有问题了,但矢量线的平滑对我来说还有一定难度,想不到高效的实现方式。唯一想到的办法是找出图层中的每一个节点,在保持节点位置不变的情况下,取出节点之间的线段逐条平滑,再放回图层中。但感觉这样做比较慢,实现起来也比较复杂,所以再次折中:直接对面进行平滑,平滑完后再转为线,看看能否改善结果。
虽然不做线平滑了,下面还是先给出面转线的代码:
# -*- coding: utf-8 -*-
import os
import gdal
from osgeo import ogr,osr
import numpy as np
def Test_Poly2Line(input_poly, output_line):
    """Write the boundaries of all polygons in a shapefile as one MultiLineString.

    Every ring (outer and inner) of every Polygon / MultiPolygon feature in
    ``input_poly`` is copied into a LineString, and all of them are stored as
    a single MultiLineString feature in a new shapefile ``output_line``.
    Edges shared by adjacent polygons are not deduplicated, so they appear
    once per polygon.

    Parameters
    ----------
    input_poly : str
        Path of the source polygon shapefile (opened read-only).
    output_line : str
        Path of the line shapefile to create; an existing one is deleted first.

    Raises
    ------
    IOError
        If the input cannot be opened or the output cannot be created.
    """
    ogr.RegisterAll()
    driver = ogr.GetDriverByName('ESRI Shapefile')
    # The input is only read, so open it read-only (0) rather than for update.
    source_ds = driver.Open(input_poly, 0)
    if source_ds is None:
        raise IOError("Cannot open %s" % input_poly)
    source_layer = source_ds.GetLayer(0)

    # polygon rings -> one MultiLineString
    lines = ogr.Geometry(ogr.wkbMultiLineString)
    for feat in source_layer:
        geom = feat.GetGeometryRef()
        if geom is None:
            continue  # skip features without geometry
        if geom.GetGeometryName() == 'MULTIPOLYGON':
            polygons = [geom.GetGeometryRef(i) for i in range(geom.GetGeometryCount())]
        else:
            polygons = [geom]
        for poly in polygons:
            for r in range(poly.GetGeometryCount()):
                ring = poly.GetGeometryRef(r)
                # Copy the ring's vertices into a plain LineString, the member
                # type a MultiLineString expects (instead of a LinearRing).
                line = ogr.Geometry(ogr.wkbLineString)
                for p in range(ring.GetPointCount()):
                    line.AddPoint_2D(ring.GetX(p), ring.GetY(p))
                lines.AddGeometry(line)

    # MultiLineString -> shapefile
    if os.path.exists(output_line):
        driver.DeleteDataSource(output_line)
    out_ds = driver.CreateDataSource(output_line)
    if out_ds is None:
        raise IOError("Cannot create %s" % output_line)
    layer_name = os.path.splitext(os.path.basename(output_line))[0]
    # Carry the input's spatial reference (if any) over to the output layer.
    out_layer = out_ds.CreateLayer(layer_name, srs=source_layer.GetSpatialRef(),
                                   geom_type=ogr.wkbMultiLineString)
    out_feature = ogr.Feature(out_layer.GetLayerDefn())
    out_feature.SetGeometry(lines)
    out_layer.CreateFeature(out_feature)
    out_feature = None
    # Release the datasets so the output is flushed and both files are closed.
    out_ds = None
    source_ds = None
if __name__ == "__main__":
    # Hard-coded input polygon layer and output line layer of the experiment.
    polygon_shp, line_shp = "E:/geo_test/temp/temp.shp", "E:/geo_test/temp/temp2line.shp"
    Test_Poly2Line(polygon_shp, line_shp)
结果如下,可以看到这个结果和面完全保持一致,毕竟是gdal源码哈哈。
下面说一下在面转为线之前先做平滑的做法:在下面的位置加入了中值滤波。图中红线标出的部分是针对8bit图像的,一般不用,直接使用启用的BetterMedianFilter即可,参考链接:https://blog.csdn.net/baidu_41902768/article/details/94451787
这是栅格面平滑后转化为面矢量的结果
这是与之前未平滑结果的叠加对比:变化是有的,但仍有一个大问题,就是锯齿太严重。