要生成矢量框,需要将图像坐标(行列号)转换为地理坐标或投影坐标。以下代码把影像划分为 1000×1000 像素的区域,选出像素值总和最大的前 50 个区域,并为每个区域生成对应的矢量框(Shapefile);关键在于坐标转换函数 imagexy2geo(即原文中以红色字体标出的部分)。
# -*- coding: utf-8 -*-
import os
from osgeo import ogr, osr
import gdal
import heapq
import numba
import numpy as np
def read_img(filename):
    """Open a raster with GDAL and read all of its pixels into memory.

    :param filename: path to a GDAL-readable raster file.
    :return: tuple (im_proj, im_geotrans, im_width, im_height, im_data):
        projection string from GetProjection(), six-element geotransform
        from GetGeoTransform(), raster width and height in pixels, and the
        ndarray returned by ReadAsArray() for the full extent.
    :raises IOError: if GDAL cannot open *filename*.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning
        # None; fail here with a clear message instead of an AttributeError below.
        raise IOError('cannot open raster: ' + str(filename))
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    # Read the whole raster (all bands) starting at pixel (0, 0).
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    del dataset  # release the GDAL handle
    return im_proj, im_geotrans, im_width, im_height, im_data
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a NumPy array to a GeoTIFF with the given georeferencing.

    :param filename: output path; written with GDAL's "GTiff" driver.
    :param im_proj: projection string (e.g. the WKT from GetProjection()).
    :param im_geotrans: six-element GDAL geotransform.
    :param im_data: 2-D array (rows, cols) for a single band, or 3-D array
        (bands, rows, cols) for one or more bands.
    :raises IOError: if the GTiff driver cannot create *filename*.
    """
    # Map the NumPy dtype onto the matching GDAL pixel type. The earlier
    # substring test sent int16 to UInt16 (negative values wrapped) and
    # int32/uint32 to Float32 (precision loss). int8 keeps GDT_Byte as before
    # because GDT_Int8 only exists in recent GDAL versions. Everything else
    # (floats, int64, bool, ...) is stored as Float32, as before.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
        bands = im_data  # iterating a 3-D array yields the 2-D band slices
    else:
        im_bands = 1
        im_height, im_width = im_data.shape
        bands = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError('cannot create raster: ' + str(filename))
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    # Always write 2-D slices; this also handles a 3-D array with one band,
    # which WriteArray would reject if passed whole.
    for index, band in enumerate(bands, start=1):
        dataset.GetRasterBand(index).WriteArray(band)
    del dataset  # closing the dataset flushes it to disk
@numba.jit
def conv2(X, k):
    """Slide kernel *k* over *X* and return the "valid"-mode window sums.

    Every placement of *k* that lies fully inside *X* produces one output
    cell equal to sum(window * k). The kernel is not flipped, so strictly
    this is a cross-correlation; it matches a true convolution only for
    symmetric kernels.

    :param X: 2-D input array.
    :param k: 2-D kernel array.
    :return: float64 array of shape
        (X_rows - k_rows + 1, X_cols - k_cols + 1).
    """
    n_rows, n_cols = X.shape
    kh, kw = k.shape
    out_h = n_rows - kh + 1
    out_w = n_cols - kw + 1
    result = np.empty((out_h, out_w))
    for top in range(out_h):
        for left in range(out_w):
            window = X[top:top + kh, left:left + kw]
            result[top, left] = np.sum(window * k)
    return result
def imagexy2geo(dataset, row, col):
    """Convert an image position (row, col) into map coordinates (x, y).

    Applies GDAL's six-parameter affine geotransform GT of *dataset*:
        x = GT[0] + col * GT[1] + row * GT[2]
        y = GT[3] + col * GT[4] + row * GT[5]
    The result is in whatever CRS the dataset uses (projected or geographic).
    Under GDAL's usual convention, integer (row, col) give the top-left
    corner of that pixel.

    :param dataset: object exposing GetGeoTransform(), e.g. a GDAL dataset.
    :param row: pixel row index (line number).
    :param col: pixel column index.
    :return: tuple (x, y) of map coordinates.
    """
    origin_x, pixel_w, rot_x, origin_y, rot_y, pixel_h = dataset.GetGeoTransform()
    x = origin_x + col * pixel_w + row * rot_x
    y = origin_y + col * rot_y + row * pixel_h
    return x, y
def _tile_scores(dataset, tile_size):
    """Return {(i, j): score} for every full tile_size x tile_size tile.

    i is the tile column index and j the tile row index. The score is the sum
    of all pixel values (over all bands) inside the tile. Partial tiles at the
    right/bottom edges are skipped, so only complete boxes are scored.
    """
    # Integer division: under Python 3, "/" would give a float and range() fails.
    n_cols = dataset.RasterXSize // tile_size
    n_rows = dataset.RasterYSize // tile_size
    scores = {}
    for i in range(n_cols):
        for j in range(n_rows):
            tile = dataset.ReadAsArray(i * tile_size, j * tile_size,
                                       tile_size, tile_size)
            # Plain sum; multiplying by a ones kernel first added nothing.
            scores[(i, j)] = float(np.sum(tile, dtype=np.float64))
    return scores


def _write_box_shapefile(driver, shp_path, srs, corners):
    """Write one closed polygon through the (x, y) *corners* to *shp_path*."""
    if os.path.exists(shp_path):
        # The shapefile driver refuses to create over an existing file.
        driver.DeleteDataSource(shp_path)
    data_source = driver.CreateDataSource(shp_path)
    if data_source is None:
        raise IOError('cannot create shapefile: ' + shp_path)
    layer = data_source.CreateLayer("polygon", srs, ogr.wkbPolygon)
    ring = ogr.Geometry(ogr.wkbLinearRing)
    for x, y in corners:
        ring.AddPoint_2D(x, y)  # 2-D points to match the wkbPolygon layer
    polygon = ogr.Geometry(ogr.wkbPolygon)
    polygon.AddGeometry(ring)
    polygon.CloseRings()
    feature = ogr.Feature(layer.GetLayerDefn())
    feature.SetGeometry(polygon)
    layer.CreateFeature(feature)
    feature = None
    data_source = None  # dereference to flush and close the shapefile


if __name__ == '__main__':
    ogr.RegisterAll()
    img_path = 'E:/wsl/pre/xm_15rgb.tif'
    temp = 'E:/wsl/pre/xm_15rgb_temp.tif'
    out = 'E:/wsl/shp/'
    tile_size = 1000  # edge length of each candidate box, in pixels
    top_n = 50        # number of best-scoring boxes to export

    im_proj, im_geotrans, im_width, im_height, im_data = read_img(img_path)
    # Zero out values outside [50, 250] so they do not contribute to scores.
    im_data[im_data < 50] = 0
    im_data[im_data > 250] = 0
    write_img(temp, im_proj, im_geotrans, im_data)
    print('save temp')
    del im_data  # tiles are re-read from the temp file below

    dataset = gdal.Open(temp)
    if dataset is None:
        raise IOError('cannot open raster: ' + temp)

    scores = _tile_scores(dataset, tile_size)
    # Same result as sorted(..., reverse=True)[:top_n], without a full sort.
    best = heapq.nlargest(top_n, scores, key=scores.__getitem__)

    # imagexy2geo returns coordinates in the raster's own CRS, so tag the
    # shapefiles with that CRS; fall back to WGS84 only if none is defined.
    srs = osr.SpatialReference()
    proj_wkt = dataset.GetProjection()
    if proj_wkt:
        srs.ImportFromWkt(proj_wkt)
    else:
        srs.ImportFromEPSG(4326)

    if not os.path.isdir(out):
        os.makedirs(out)
    driver = ogr.GetDriverByName("ESRI Shapefile")

    for count, (i, j) in enumerate(best, start=1):
        col = i * tile_size
        row = j * tile_size
        # Corners in ring order: top-left, top-right, bottom-right, bottom-left.
        pixel_corners = [
            (row, col),
            (row, col + tile_size),
            (row + tile_size, col + tile_size),
            (row + tile_size, col),
        ]
        corners = [imagexy2geo(dataset, r, c) for r, c in pixel_corners]
        shp_path = os.path.join(out, str(count) + '.shp')
        _write_box_shapefile(driver, shp_path, srs, corners)

    dataset = None  # release the GDAL handle on the temp raster