图像转矢量
我的文件路径层级比较深（主函数里的路径拼接是按我自己的目录结构写的），你们可以直接调用主函数中用到的 raster2LineShp(img_path, strVectorFile) 函数：
import time, glob, threading, os
from tqdm import trange
from osgeo import gdal, ogr, osr
from image2graph import *
def imagexy2geo(dataset, col, row):
    """
    Convert a pixel position to georeferenced coordinates with GDAL's six-parameter
    affine geotransform (projected or geographic, depending on the dataset's CRS).

    :param dataset: object exposing ``GetGeoTransform()``, e.g. a GDAL dataset
    :param col: pixel column index
    :param row: pixel row index
    :return: ``(x, y)`` coordinates corresponding to (col, row)
    """
    origin_x, x_per_col, x_per_row, origin_y, y_per_col, y_per_row = dataset.GetGeoTransform()
    return (
        origin_x + col * x_per_col + row * x_per_row,
        origin_y + col * y_per_col + row * y_per_row,
    )
def raster2LineShp(img_path, strVectorFile):
    """
    Vectorise the skeleton graph of a raster into a line shapefile.

    The raster is converted to a graph by ``generateGraph``. Every undirected graph edge is
    written exactly once, as a two-point LineString feature. Its end points are converted from
    pixel positions to the raster's georeferenced coordinates with ``imagexy2geo``. The
    output layer reuses the raster's projection.

    :param img_path: path of the input raster (read by GDAL and by ``generateGraph``)
    :param strVectorFile: path of the ``.shp`` file to create
    :raises RuntimeError: if the raster cannot be opened, or the driver, data source or
        layer cannot be created, or a feature cannot be written
    """
    dataset = gdal.Open(img_path)
    if dataset is None:
        # gdal.Open returns None (rather than raising) when exceptions are not enabled
        raise RuntimeError("cannot open raster: %s" % img_path)
    graph = generateGraph(img_path)
    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "NO")  # intended to support Chinese paths
    gdal.SetConfigOption("SHAPE_ENCODING", "CP936")  # lets attribute fields hold Chinese text
    ogr.RegisterAll()
    strDriverName = "ESRI Shapefile"  # create the output as an ESRI shapefile
    oDriver = ogr.GetDriverByName(strDriverName)
    if oDriver is None:
        raise RuntimeError("%s 驱动不可用!" % strDriverName)
    oDS = oDriver.CreateDataSource(strVectorFile)
    if oDS is None:
        raise RuntimeError("创建文件【%s】失败!" % strVectorFile)
    try:
        # Reuse the raster's projection so no EPSG code has to be hard-coded.
        srs = osr.SpatialReference(wkt=dataset.GetProjection())
        # Line layer ("TestPolygon" is only the layer / attribute-table name).
        oLayer = oDS.CreateLayer("TestPolygon", srs, ogr.wkbMultiLineString, [])
        if oLayer is None:
            raise RuntimeError("图层创建失败!")
        oDefn = oLayer.GetLayerDefn()
        # generateGraph lists every edge under both of its end points; write each one once.
        written = set()
        for n, v in graph.items():
            for nei in v:
                edge_key = frozenset((n, nei))
                if edge_key in written:
                    continue
                written.add(edge_key)
                line = ogr.Geometry(ogr.wkbLineString)
                # Graph keys are (row, col) pixel positions; imagexy2geo expects (col, row).
                line.AddPoint_2D(*imagexy2geo(dataset, n[1], n[0]))
                line.AddPoint_2D(*imagexy2geo(dataset, nei[1], nei[0]))
                feature = ogr.Feature(oDefn)
                feature.SetGeometry(line)
                if oLayer.CreateFeature(feature) != 0:
                    raise RuntimeError("cannot write feature to %s" % strVectorFile)
                feature = None
    finally:
        # Flush and close the shapefile even if writing failed part-way.
        oDS.Destroy()
# Glob pattern for the second-level sub-folders of Thin_gt. The trailing separator restricts
# matches to directories, and each match comes back ending in a separator, so file names can be
# appended directly.
pre_root = r"C:\Users\Administrator\Desktop\Thin_gt\*\*//"
#################################################################################################
if __name__ == "__main__":
    from concurrent.futures import ThreadPoolExecutor

    data_sat_dir = glob.glob(pre_root)  # prediction folders
    start = time.time()
    jobs = []  # (source image path, future) pairs
    # The pool bounds how many conversions run at once. Leaving the with-block waits for all
    # of them, so the elapsed time printed below covers the actual work.
    with ThreadPoolExecutor() as pool:
        for next_dir in data_sat_dir:
            # Output folder: the same path with component 4 (the "Thin_gt" folder for the
            # pattern above) replaced by "Thin2"; only the first 7 components are kept.
            parts = next_dir.split("\\")[:7]
            parts[4] = "Thin2"
            save_skelet_pre_root = "\\".join(parts) + "\\"
            os.makedirs(save_skelet_pre_root, exist_ok=True)
            for img_name in os.listdir(next_dir):
                src = next_dir + img_name
                dst = save_skelet_pre_root + os.path.splitext(img_name)[0] + ".shp"
                jobs.append((src, pool.submit(raster2LineShp, src, dst)))
        for k in trange(len(jobs)):
            src, future = jobs[k]
            try:
                future.result()
            except Exception as exc:  # top-level boundary: report and keep converting the rest
                print("failed:", src, exc)
    print("time:", time.time() - start)
image2graph.py（上面脚本通过 from image2graph import * 导入的辅助模块）：
import skimage.morphology
from PIL import Image
import numpy
from math import sqrt
import cv2
def distance(a, b):
    """Return the Euclidean distance between the 2-D points *a* and *b*."""
    delta_x = a[0] - b[0]
    delta_y = a[1] - b[1]
    return sqrt(delta_x ** 2 + delta_y ** 2)
def point_line_distance(point, start, end):
    """
    Return the perpendicular distance from *point* to the infinite line through *start*
    and *end*.

    When *start* and *end* coincide, the line is degenerate and the plain point-to-point
    distance to *start* is returned instead.
    """
    if start == end:
        return distance(point, start)
    run = end[0] - start[0]
    rise = end[1] - start[1]
    # |cross product| of (end - start) and (start - point), i.e. twice the triangle area ...
    cross = abs(run * (start[1] - point[1]) - (start[0] - point[0]) * rise)
    # ... divided by the base length gives the height.
    return cross / sqrt(run ** 2 + rise ** 2)
def rdp(points, epsilon):
    """
    Simplify a polyline with the Ramer-Douglas-Peucker algorithm.

    The first and last points are always kept. In every span, the interior point farthest from
    the span's chord (measured with ``point_line_distance``) is kept, and the span is split
    there, whenever that distance is at least ``epsilon``. An explicit stack replaces
    recursion, so long paths cannot hit Python's recursion limit.

    :param points: sequence of 2-D points
    :param epsilon: distance tolerance
    :return: list of the kept points in their original order; inputs with fewer than three
        points are returned unchanged (as a list)
    """
    count = len(points)
    if count < 3:
        return list(points)
    keep = [False] * count
    keep[0] = keep[-1] = True
    spans = [(0, count - 1)]
    while spans:
        first, last = spans.pop()
        dmax = 0.0
        index = 0
        for i in range(first + 1, last):
            d = point_line_distance(points[i], points[first], points[last])
            if d > dmax:
                index = i
                dmax = d
        # index == 0 means no interior point lies off the chord. Without this guard, an
        # epsilon <= 0 would keep "splitting" at the span start forever.
        if index and dmax >= epsilon:
            keep[index] = True
            spans.append((first, index))
            spans.append((index, last))
    return [p for p, kept in zip(points, keep) if kept]
def generateGraph(in_fname, threshold=1, rdp_epsilon=2):
    """
    Extract a line graph from the skeleton of a (roughly binary) raster.

    Pixels with value >= ``threshold`` are foreground. The foreground is thinned with
    ``skimage.morphology.thin`` and the skeleton is traced pixel by pixel. Each traced run is
    simplified with ``rdp``. Graph vertices are the resulting breakpoints, run ends and
    junctions.

    Side effect: writes the debug images ``tmp0.png`` (thresholded mask) and ``tmp.png``
    (skeleton), both transposed, into the current working directory.

    :param in_fname: path of the input image; it is read as single-channel grayscale
    :param threshold: minimum pixel value counted as foreground (default 1)
    :param rdp_epsilon: simplification tolerance in pixels passed to ``rdp`` (default 2)
    :return: dict mapping each vertex ``(row, col)`` to the list of ``(row, col)`` vertices it
        shares an edge with; every edge is listed under both of its end points
    :raises OSError: if OpenCV cannot read ``in_fname``
    """
    im = _load_skeleton(in_fname, threshold)
    vertices, edges = _trace_skeleton(im, rdp_epsilon)
    return _edges_to_adjacency(vertices, edges)


def _load_skeleton(in_fname, threshold):
    """Read *in_fname* as grayscale and return its thinned foreground as a transposed uint8 0/1 array."""
    im = cv2.imread(in_fname, 0)
    if im is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError("cannot read image: {}".format(in_fname))
    im = numpy.array(im)
    if len(im.shape) == 3:
        print('warning: bad shape {}, using first channel only'.format(im.shape))
        im = im[:, :, 0]
    # Transpose: im[i, j] is now (column, row) of the source image; _edges_to_adjacency swaps back.
    im = numpy.swapaxes(im, 0, 1)
    im = (im >= threshold)
    Image.fromarray(im.astype('uint8') * 60).save("tmp0.png")  # debug: thresholded mask
    im = skimage.morphology.thin(im)
    im = im.astype('uint8')
    Image.fromarray(im * 255).save("tmp.png")  # debug: skeleton
    return im


def _trace_skeleton(im, rdp_epsilon):
    """
    Walk the skeleton *im* and return ``(vertices, edges)``.

    Vertices are placed at the ends of traced runs, at junctions, and at the interior
    breakpoints ``rdp`` keeps. ``vertices`` is a list of (i, j) positions in *im*'s index
    space. ``edges`` is a set of (src, dst) indices into ``vertices``; each undirected edge is
    stored once and self-loops are dropped. *im* is modified in place: visited pixels are set
    to 0.
    """
    vertices = []
    edges = set()

    def add_edge(src, dst):
        if (src, dst) in edges or (dst, src) in edges:
            return
        elif src == dst:
            return
        edges.add((src, dst))

    # pixel -> ids of already-created vertices that reach this pixel and still need
    # to be connected to whatever vertex ends up representing it
    point_to_neighbors = {}
    # pending branches: (id of the vertex the branch starts from, i, j of its first pixel)
    q = []
    while True:
        if len(q) > 0:
            lastid, i, j = q.pop()
            path = [vertices[lastid], (i, j)]
            if im[i, j] == 0:
                # already consumed by another walk
                continue
            point_to_neighbors[(i, j)].remove(lastid)
            if len(point_to_neighbors[(i, j)]) == 0:
                del point_to_neighbors[(i, j)]
        else:
            # no pending branch: start a new walk at the first remaining foreground pixel
            w = numpy.where(im > 0)
            if len(w[0]) == 0:
                break
            i, j = w[0][0], w[1][0]
            lastid = len(vertices)
            vertices.append((i, j))
            path = [(i, j)]
        while True:
            im[i, j] = 0
            # 8-connected foreground neighbours (the current pixel was just cleared)
            neighbors = []
            for oi in [-1, 0, 1]:
                for oj in [-1, 0, 1]:
                    ni = i + oi
                    nj = j + oj
                    if ni >= 0 and ni < im.shape[0] and nj >= 0 and nj < im.shape[1] and im[ni, nj] > 0:
                        neighbors.append((ni, nj))
            if len(neighbors) == 1 and (i, j) not in point_to_neighbors:
                # simple continuation of the current run
                ni, nj = neighbors[0]
                path.append((ni, nj))
                i, j = ni, nj
            else:
                # end of run: emit simplified interior vertices along the path
                if len(path) > 1:
                    path = rdp(path, rdp_epsilon)
                    if len(path) > 2:
                        for point in path[1:-1]:
                            curid = len(vertices)
                            vertices.append(point)
                            add_edge(lastid, curid)
                            lastid = curid
                neighbor_count = len(neighbors) + len(point_to_neighbors.get((i, j), []))
                # dead end or junction: the run's last pixel becomes a vertex
                if neighbor_count == 0 or neighbor_count >= 2:
                    curid = len(vertices)
                    vertices.append(path[-1])
                    add_edge(lastid, curid)
                    lastid = curid
                # queue every outgoing branch from this vertex
                for ni, nj in neighbors:
                    if (ni, nj) not in point_to_neighbors:
                        point_to_neighbors[(ni, nj)] = set()
                    point_to_neighbors[(ni, nj)].add(lastid)
                    q.append((lastid, ni, nj))
                # connect vertices that were waiting for this pixel
                for neighborid in point_to_neighbors.get((i, j), []):
                    add_edge(neighborid, lastid)
                break
    return vertices, edges


def _edges_to_adjacency(vertices, edges):
    """Turn edge index pairs into a symmetric adjacency dict keyed by (row, col), i.e. (j, i)."""
    adjacency = {}
    for src, dst in edges:
        a = (vertices[src][1], vertices[src][0])
        b = (vertices[dst][1], vertices[dst][0])
        if a == b:
            continue
        for key, other in ((a, b), (b, a)):
            linked = adjacency.setdefault(key, [])
            if other not in linked:
                linked.append(other)
    return adjacency