# -*- coding: utf-8 -*-
import os, json
import cv2
from osgeo import gdal
import numpy as np
from osgeo import ogr, gdal, osr
from shapely.geometry import box, shape
import geopandas as gpd
from shapely.geometry.polygon import Polygon
def read_img(filename):
    """Open a raster with GDAL and read all of its pixels into memory.

    Args:
        filename: path of a raster file GDAL can open.

    Returns:
        A tuple ``(width, height, projection_wkt, geotransform, data, dataset)``.
        ``data`` is ``ReadAsArray`` over the full extent; the callers in this
        file index it as ``(bands, rows, cols)`` (a single-band raster
        presumably comes back 2-D -- verify for your GDAL version).
        ``dataset`` is the still-open GDAL handle; drop every reference to it
        to close the file.

    Raises:
        IOError: if GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # gdal.Open returns None instead of raising (unless gdal.UseExceptions()
        # is active); fail here with the path rather than with an AttributeError.
        raise IOError("GDAL could not open raster: %s" % filename)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    return im_width, im_height, im_proj, im_geotrans, im_data, dataset
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a numpy array to a GeoTIFF file.

    Args:
        filename: output path handed to the GTiff driver's ``Create``.
        im_proj: projection WKT string set on the output.
        im_geotrans: 6-element GDAL geotransform set on the output.
        im_data: a 2-D ``(rows, cols)`` array for a single band, or a 3-D
            ``(bands, rows, cols)`` array.

    The GDAL pixel type follows the array dtype (unsigned/signed 8/16/32-bit
    integers keep their signedness); any other dtype (floats, 64-bit ints,
    bool, ...) is stored as Float32.

    Raises:
        IOError: if the GTiff driver cannot create ``filename``.
    """
    # Signedness must be preserved: storing int16 as UInt16 (or int8 as Byte)
    # would wrap every negative value.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        # GDT_Int8 only exists in GDAL >= 3.7; widen losslessly otherwise.
        'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Int16),
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape
    # View every input as (bands, rows, cols) so a (1, rows, cols) array is
    # written band by band instead of handing a 3-D array to WriteArray.
    bands = im_data.reshape(im_bands, im_height, im_width)

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError("GDAL could not create raster: %s" % filename)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(bands[band_index])
    # Flush and close explicitly instead of relying on garbage collection.
    dataset.FlushCache()
    dataset = None
def gdal_image_clip(inpath, outpath, new_width=500, stride=200):
    """Tile every ``.tif`` file of a directory into square GeoTIFF patches.

    Each matching file is handed to ``gdal_image_clip_single``, which writes
    its patches as ``<stem>_<n>.tif`` into ``outpath``.

    Args:
        inpath: directory scanned (non-recursively) for names ending in ``.tif``.
        outpath: existing directory that receives the patches.
        new_width: side length of each square patch, in pixels.
        stride: step between neighbouring patch origins, in pixels.
    """
    for name in sorted(os.listdir(inpath)):
        if not name.endswith('.tif'):
            continue
        print("dealing the ", name, " ...")
        # Single source of truth for the tiling logic.
        gdal_image_clip_single(os.path.join(inpath, name), outpath, new_width, stride)
def _tile_offsets(size, tile, stride):
    """Return start offsets of ``tile``-long windows covering ``0..size``.

    Windows start every ``stride`` pixels from 0. One extra window flush with
    the far edge is appended when the regular ones do not already end there,
    so the whole extent is covered without emitting the same window twice.
    Returns an empty list when ``size < tile``.

    Raises:
        ValueError: if ``tile`` or ``stride`` is not positive.
    """
    if tile <= 0 or stride <= 0:
        raise ValueError("patch size and stride must be positive, got %r and %r"
                         % (tile, stride))
    last = size - tile
    if last < 0:
        return []
    offsets = list(range(0, last + 1, stride))
    if offsets[-1] != last:
        offsets.append(last)
    return offsets


def gdal_image_clip_single(inpath, outpath, new_width, stride):
    """Tile one raster into square GeoTIFF patches of ``new_width`` pixels.

    Patch origins are placed every ``stride`` pixels along both axes, plus one
    extra row/column flush with the bottom/right edge to cover any remainder
    (a window identical to the previous one is not repeated). Patches are
    written row by row, left to right, as ``<outpath>/<stem>_<n>.tif`` with
    ``n`` counting from 1, using the source projection and a geotransform
    whose origin is moved to the patch's upper-left pixel corner.

    A raster smaller than ``new_width`` in either dimension is skipped with a
    message and produces no patches.

    Args:
        inpath: path of the raster to tile.
        outpath: existing directory that receives the patches.
        new_width: side length of each square patch, in pixels (> 0).
        stride: step between neighbouring patch origins, in pixels (> 0).

    Raises:
        ValueError: if ``new_width`` or ``stride`` is not positive.
    """
    im_width, im_height, im_proj, geo_transform, im_data, dataset = read_img(inpath)
    del dataset  # pixels are already in memory; release the GDAL handle
    stem = os.path.splitext(os.path.basename(inpath))[0]

    row_offsets = _tile_offsets(im_height, new_width, stride)
    col_offsets = _tile_offsets(im_width, new_width, stride)
    if not row_offsets or not col_offsets:
        print("skipping", inpath, ": smaller than the", new_width, "px patch size")
        return

    gt0, gt1, gt2, gt3, gt4, gt5 = geo_transform
    num_ = 0
    for ymin in row_offsets:
        for xmin in col_offsets:
            num_ += 1
            # '...' keeps any leading band axis, so 2-D single-band arrays work too.
            patch = im_data[..., ymin:ymin + new_width, xmin:xmin + new_width]
            # Affine geotransform applied to the patch's upper-left pixel corner.
            x_origin = gt0 + xmin * gt1 + ymin * gt2
            y_origin = gt3 + xmin * gt4 + ymin * gt5
            new_geo_transform = (x_origin, gt1, gt2, y_origin, gt4, gt5)
            patch_path = os.path.join(outpath, stem + '_' + str(num_) + '.tif')
            write_img(patch_path, im_proj, new_geo_transform, patch)
def get_mask(img_path, out_shp):
    """Write the footprint of a raster as a one-polygon ESRI Shapefile.

    The polygon's vertices are the four image corners mapped through the
    raster's geotransform, so rotated/sheared rasters get their true outline.
    The layer ("polygon") takes the raster's projection. Output left over from
    a previous run at ``out_shp`` is deleted first.

    Args:
        img_path: raster whose extent is exported.
        out_shp: path of the ``.shp`` file to create.

    Raises:
        IOError: if the shapefile cannot be created.
    """
    im_width, im_height, im_proj, gt, _im_data, _dataset = read_img(img_path)

    def to_geo(col, row):
        # Apply the affine geotransform to a pixel-corner coordinate.
        return (gt[0] + col * gt[1] + row * gt[2],
                gt[3] + col * gt[4] + row * gt[5])

    # Vertex order: lower-left, upper-left, upper-right, lower-right.
    ring = ogr.Geometry(ogr.wkbLinearRing)
    for col, row in ((0, im_height), (0, 0), (im_width, 0), (im_width, im_height)):
        ring.AddPoint_2D(*to_geo(col, row))
    footprint = ogr.Geometry(ogr.wkbPolygon)
    footprint.AddGeometry(ring)
    footprint.CloseRings()  # repeat the first vertex; an unclosed ring is invalid

    driver = ogr.GetDriverByName("ESRI Shapefile")
    if os.path.exists(out_shp):
        # Remove output from a previous run so the layer is created fresh.
        driver.DeleteDataSource(out_shp)
    data_source = driver.CreateDataSource(out_shp)
    if data_source is None:
        raise IOError("could not create shapefile: %s" % out_shp)
    srs = osr.SpatialReference(wkt=im_proj)
    layer = data_source.CreateLayer("polygon", srs, ogr.wkbPolygon)
    feature = ogr.Feature(layer.GetLayerDefn())
    feature.SetGeometry(footprint)
    layer.CreateFeature(feature)
    feature = None
    data_source = None  # flush and close the shapefile
def _polygonal_part(geom):
    """Return the areal part of ``geom`` as a Polygon/MultiPolygon, or None.

    Intersecting two polygons can degenerate to points/lines (features that
    only touch) or to a GeometryCollection mixing both; only the polygonal
    pieces fit the wkbPolygon output layer used by ``maskClipLabel``.
    """
    if geom is None or geom.IsEmpty():
        return None
    geom_type = ogr.GT_Flatten(geom.GetGeometryType())
    if geom_type in (ogr.wkbPolygon, ogr.wkbMultiPolygon):
        return geom
    if geom_type != ogr.wkbGeometryCollection:
        return None
    multi = ogr.Geometry(ogr.wkbMultiPolygon)
    for k in range(geom.GetGeometryCount()):
        part = geom.GetGeometryRef(k)
        part_type = ogr.GT_Flatten(part.GetGeometryType())
        if part_type == ogr.wkbPolygon:
            multi.AddGeometry(part)
        elif part_type == ogr.wkbMultiPolygon:
            for m in range(part.GetGeometryCount()):
                multi.AddGeometry(part.GetGeometryRef(m))
    return multi if multi.GetGeometryCount() else None


def maskClipLabel(ds_large, layer_large, clip_shapefile, output_shapefile):
    """Clip a label layer to the polygon stored in ``clip_shapefile``.

    Every feature of ``layer_large`` that intersects the clip polygon is
    written to a new ESRI Shapefile with its geometry cut to the clip polygon
    (only areal results are kept) and all attribute fields copied. All work
    happens in ``layer_large``'s CRS; the clip polygon is reprojected into it
    when both layers declare a CRS and the two differ.

    Args:
        ds_large: data source owning ``layer_large``. Not used directly; the
            caller must keep it open while ``layer_large`` is in use.
        layer_large: OGR layer with the labels. Its spatial filter is set
            temporarily and cleared again before returning.
        clip_shapefile: shapefile whose feature with FID 0 is the clip polygon.
        output_shapefile: path of the shapefile to create (replaced if present).

    Raises:
        IOError: if the clip shapefile cannot be read or the output cannot be
            created.
    """
    ds_small = ogr.Open(clip_shapefile)
    if ds_small is None:
        raise IOError("could not open clip shapefile: %s" % clip_shapefile)
    layer_small = ds_small.GetLayer()
    feature_small = layer_small.GetFeature(0)
    if feature_small is None or feature_small.GetGeometryRef() is None:
        raise IOError("clip shapefile has no geometry: %s" % clip_shapefile)
    # Clone so the polygon outlives ds_small and can be reprojected freely.
    clip_geom = feature_small.GetGeometryRef().Clone()

    spatial_ref_large = layer_large.GetSpatialRef()
    spatial_ref_small = layer_small.GetSpatialRef()
    # Transform the clip polygon small -> large so the intersection test, the
    # clipping and the output layer all share the label layer's CRS.
    if (spatial_ref_small is not None and spatial_ref_large is not None
            and not spatial_ref_small.IsSame(spatial_ref_large)):
        clip_geom.Transform(osr.CoordinateTransformation(spatial_ref_small,
                                                         spatial_ref_large))
    feature_small = None
    ds_small = None

    driver = ogr.GetDriverByName("ESRI Shapefile")
    if os.path.exists(output_shapefile):
        # Remove output from a previous run so the layer is created fresh.
        driver.DeleteDataSource(output_shapefile)
    ds_output = driver.CreateDataSource(output_shapefile)
    if ds_output is None:
        raise IOError("could not create shapefile: %s" % output_shapefile)
    layer_output = ds_output.CreateLayer("output_layer", spatial_ref_large,
                                         geom_type=ogr.wkbPolygon)

    # Copy the label schema to the output layer.
    layer_def = layer_large.GetLayerDefn()
    field_names = []
    for i in range(layer_def.GetFieldCount()):
        field_def = layer_def.GetFieldDefn(i)
        layer_output.CreateField(field_def)
        field_names.append(field_def.GetNameRef())

    # Let OGR skip features far from the clip polygon; the exact Intersects
    # test below still decides. The filter is cleared in `finally` because the
    # caller reuses layer_large for every tile.
    layer_large.SetSpatialFilter(clip_geom)
    layer_large.ResetReading()
    try:
        for feature_large in layer_large:
            geometry_large = feature_large.GetGeometryRef()
            if geometry_large is None or not geometry_large.Intersects(clip_geom):
                continue
            clipped = _polygonal_part(geometry_large.Intersection(clip_geom))
            if clipped is None:
                continue  # features that only touch the tile
            feature_output = ogr.Feature(layer_output.GetLayerDefn())
            feature_output.SetGeometry(clipped)
            for i, field_name in enumerate(field_names):
                feature_output.SetField(field_name, feature_large.GetField(i))
            layer_output.CreateFeature(feature_output)
    finally:
        layer_large.SetSpatialFilter(None)
        layer_large.ResetReading()
    ds_output = None  # flush and close the output shapefile
def overlay_features(source_shp, target_shp, output_shp):
    """Write the features of two shapefiles into one new shapefile.

    The output layer takes ``target_shp``'s layer name, geometry type and
    spatial reference. Its schema is every target field plus each source
    field whose name is not already present. Target features are copied
    first, then source features, reprojected into the target CRS when both
    layers declare a CRS and the two differ. Inputs are never modified.

    Args:
        source_shp: shapefile whose features are appended second.
        target_shp: shapefile defining the output schema, geometry type and CRS.
        output_shp: path of the shapefile to create (replaced if present).

    Prints a message and returns without writing when an input cannot be
    opened or the output cannot be created.
    """
    source_ds = ogr.Open(source_shp)
    if source_ds is None:
        print("无法打开源shp文件")
        return
    target_ds = ogr.Open(target_shp)
    if target_ds is None:
        print("无法打开目标shp文件")
        source_ds = None
        return
    source_layer = source_ds.GetLayer()
    target_layer = target_ds.GetLayer()
    output_srs = target_layer.GetSpatialRef()

    driver = ogr.GetDriverByName('ESRI Shapefile')
    if os.path.exists(output_shp):
        # Remove output from a previous run so the layer is created fresh.
        driver.DeleteDataSource(output_shp)
    output_ds = driver.CreateDataSource(output_shp)
    if output_ds is None:
        print("could not create output shapefile: %s" % output_shp)
        return
    output_layer = output_ds.CreateLayer(target_layer.GetName(),
                                         geom_type=target_layer.GetGeomType(),
                                         srs=output_srs)

    target_defn = target_layer.GetLayerDefn()
    for i in range(target_defn.GetFieldCount()):
        output_layer.CreateField(target_defn.GetFieldDefn(i))
    source_defn = source_layer.GetLayerDefn()
    for i in range(source_defn.GetFieldCount()):
        field_defn = source_defn.GetFieldDefn(i)
        # A same-named target field already exists; source values go into it.
        if output_layer.GetLayerDefn().GetFieldIndex(field_defn.GetName()) == -1:
            output_layer.CreateField(field_defn)

    # A CoordinateTransformation cannot be built from a missing SRS, and an
    # identical pair needs none.
    source_srs = source_layer.GetSpatialRef()
    transform = None
    if (source_srs is not None and output_srs is not None
            and not source_srs.IsSame(output_srs)):
        transform = osr.CoordinateTransformation(source_srs, output_srs)

    def copy_features(layer, coord_transform):
        # Copy every feature of `layer` into output_layer, fields matched by name.
        layer.ResetReading()
        for feature in layer:
            geometry = feature.GetGeometryRef()
            if geometry is not None and coord_transform is not None:
                geometry = geometry.Clone()  # leave the input feature untouched
                geometry.Transform(coord_transform)
            output_feature = ogr.Feature(output_layer.GetLayerDefn())
            output_feature.SetGeometry(geometry)
            for i in range(feature.GetFieldCount()):
                output_feature.SetField(feature.GetFieldDefnRef(i).GetName(),
                                        feature.GetField(i))
            output_layer.CreateFeature(output_feature)

    copy_features(target_layer, None)
    copy_features(source_layer, transform)

    source_ds = None
    target_ds = None
    output_ds = None  # flush and close the output shapefile
def main(tiff_folder='./convertdata/7bands-images_val/',
         shpf_folder='./convertdata/7shp_val/',
         clip_image_root='./convertdata/clip_image_val/',
         clip_label_root='./convertdata/clip_shp_val/',
         patch_size=64, stride=20):
    """Build image patches and matching per-patch label shapefiles.

    For each ``<site>-<region>-<mark>.shp`` label file in ``shpf_folder``:

    1. tile ``<site>_<region>_deno.tif`` from ``tiff_folder`` into
       ``<clip_image_root>/<site>/<region>/`` (``patch_size``-pixel squares,
       origins every ``stride`` pixels);
    2. write each patch's footprint polygon as ``<patch>.shp`` next to it;
    3. clip the label shapefile with every footprint into
       ``<clip_label_root>/<site>/<region>/<patch>.shp``.

    Label files with an unexpected name, a missing image or an unreadable
    label shapefile are reported and skipped.
    """
    for shpfile in sorted(os.listdir(shpf_folder)):
        if shpfile[-4:] != ".shp":
            continue
        label_shp_path = os.path.join(shpf_folder, shpfile)  # big label shp to be clipped
        shpfile_name, _ = os.path.splitext(shpfile)
        print('shpfile_name:', shpfile_name)
        parts = shpfile_name.split('-')
        if len(parts) != 3:
            print('skipping', shpfile, ': expected <site>-<region>-<mark>.shp')
            continue
        sitename, regionn, _mark = parts
        if sitename == 'johula':
            # Label files say 'johula' where the image files say 'juhola'.
            sitename = 'juhola'
        tiffile_name = sitename + '_' + regionn + '_deno.tif'  # tiff with 7 bands
        # For drone tiffs: sitename + '-' + regionn + '-Drone.tif'
        print('tiffile_name:', tiffile_name)
        big_img_path = os.path.join(tiff_folder, tiffile_name)
        if not os.path.exists(big_img_path):
            print('skipping', shpfile, ': image not found:', big_img_path)
            continue

        # Clip the big image and write an (attribute-less) footprint mask per patch.
        clip_out_path = os.path.join(clip_image_root, sitename, regionn)
        os.makedirs(clip_out_path, exist_ok=True)
        gdal_image_clip_single(big_img_path, clip_out_path,
                               new_width=patch_size, stride=stride)
        for ff in os.listdir(clip_out_path):
            imgname, ext = os.path.splitext(ff)
            if ext == ".tif":
                get_mask(os.path.join(clip_out_path, ff),
                         os.path.join(clip_out_path, imgname + '.shp'))

        # Clip the label shapefile with each patch mask -> labels of that patch.
        save_clip_label = os.path.join(clip_label_root, sitename, regionn)
        os.makedirs(save_clip_label, exist_ok=True)
        ds_large = ogr.Open(label_shp_path)
        if ds_large is None:
            print('skipping', shpfile, ': cannot open label shapefile')
            continue
        layer_large = ds_large.GetLayer()
        # The folder holds both the clipped rasters and their mask shapefiles.
        for mask_shp in os.listdir(clip_out_path):
            if os.path.splitext(mask_shp)[1] == ".shp":
                maskClipLabel(ds_large, layer_large,
                              os.path.join(clip_out_path, mask_shp),
                              os.path.join(save_clip_label, mask_shp))
        ds_large = None  # close the label shapefile


if __name__ == "__main__":
    main()
# Clip image and shp
# (web-page footer residue: "latest recommended article published 2024-11-15 23:33:54")