Clip a GeoTIFF image into tiles and clip the label shapefile to each tile

# -*- coding: utf-8 -*-
import os, json
import cv2
from osgeo import gdal
import numpy as np
from osgeo import ogr, gdal, osr
from shapely.geometry import box, shape
import geopandas as gpd
from shapely.geometry.polygon import Polygon

def read_img(filename):
    """Open a raster with GDAL and read all of its pixels.

    Args:
        filename: path of any raster GDAL can open.

    Returns:
        Tuple ``(width, height, projection_wkt, geotransform, data, dataset)``.
        ``data`` is ``dataset.ReadAsArray`` over the full extent (GDAL returns
        a 2-D array for single-band rasters and a (bands, rows, cols) array
        otherwise). ``dataset`` is the still-open GDAL dataset; the caller
        owns it and releases it by dropping the reference.

    Raises:
        OSError: if GDAL cannot open ``filename`` (``gdal.Open`` returned None).
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # Without this check the failure surfaces later as an opaque
        # AttributeError on ``None.RasterXSize``.
        raise OSError("GDAL could not open raster %r" % (filename,))

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    return im_width, im_height, im_proj, im_geotrans, im_data, dataset

def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a NumPy array to a new GeoTIFF.

    Args:
        filename: output path.
        im_proj: projection WKT (e.g. from ``Dataset.GetProjection``).
        im_geotrans: 6-element GDAL geotransform for the output's top-left pixel.
        im_data: 2-D (rows, cols) array for one band, or 3-D
            (bands, rows, cols) array for several bands.

    The GDAL pixel type follows the array dtype: int8/uint8 -> Byte,
    int16 -> Int16, uint16 -> UInt16, anything else -> Float32.

    Raises:
        OSError: if the GTiff driver cannot create ``filename``.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        # Keep signed data signed; writing it as UInt16 wraps negative values.
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
        band_arrays = im_data  # iterating a 3-D array yields one 2-D band at a time
    else:
        im_bands = 1
        im_height, im_width = im_data.shape
        band_arrays = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise OSError("GDAL could not create GeoTIFF %r" % (filename,))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Iterating per band also handles a (1, rows, cols) array, which must be
    # written as a 2-D slice rather than passed whole to WriteArray.
    for band_number, band_data in enumerate(band_arrays, start=1):
        dataset.GetRasterBand(band_number).WriteArray(band_data)

    # Flush and drop the handle so the file is complete on disk on return.
    dataset.FlushCache()
    dataset = None


def clip_window_offsets(length, size, stride):
    """Return start offsets of ``size``-pixel windows tiling an axis of ``length`` pixels.

    Windows start at 0, stride, 2*stride, ... for as long as they fit inside
    the axis. If that grid does not already end flush with the far edge, one
    more window starting at ``length - size`` is appended. This way the whole
    axis is covered and no offset appears twice.

    Args:
        length: number of pixels along the axis.
        size: window length in pixels.
        stride: step between consecutive window starts.

    Returns:
        List of offsets in increasing order. The list is empty when the axis is
        shorter than one window.

    Raises:
        ValueError: if ``size`` or ``stride`` is not positive (a non-positive
            stride would never advance).
    """
    if size <= 0 or stride <= 0:
        raise ValueError(
            "size and stride must be positive, got size=%r stride=%r" % (size, stride))
    last = length - size
    if last < 0:
        return []
    offsets = list(range(0, last + 1, stride))
    if offsets[-1] != last:
        offsets.append(last)
    return offsets


def _patch_geotransform(geo_transform, col_off, row_off):
    """Return the geotransform of a window whose top-left pixel is (col_off, row_off)."""
    gt = geo_transform
    origin_x = gt[0] + col_off * gt[1] + row_off * gt[2]
    origin_y = gt[3] + col_off * gt[4] + row_off * gt[5]
    return (origin_x, gt[1], gt[2], origin_y, gt[4], gt[5])


def gdal_image_clip(inpath, outpath, new_width=500, stride=200):
    """Cut every ``.tif`` file directly inside ``inpath`` into square patches.

    Each file is processed with ``gdal_image_clip_single`` using the same
    ``outpath``, ``new_width`` and ``stride``.
    """
    for name in os.listdir(inpath):
        if name.endswith('.tif'):
            print("dealing the ", name, " ...")
            gdal_image_clip_single(os.path.join(inpath, name), outpath, new_width, stride)


def gdal_image_clip_single(inpath, outpath, new_width, stride):
    """Cut one raster into ``new_width`` x ``new_width`` georeferenced GeoTIFF patches.

    Windows step by ``stride`` pixels in both directions. A final row or column
    of windows is aligned to the bottom or right edge, so the whole image is
    covered without writing the same window twice. Patches are written to
    ``outpath`` as ``<stem>_<k>.tif``, with k = 1, 2, ... in row-major order.
    Each patch keeps the source projection and gets a geotransform shifted to
    its own top-left pixel.

    Args:
        inpath: raster file to cut.
        outpath: existing directory that receives the patches.
        new_width: patch side length in pixels.
        stride: step between patch origins in pixels.

    Returns:
        Number of patches written. This is 0, with a message printed, when the
        raster is smaller than one patch.

    Raises:
        ValueError: if ``new_width`` or ``stride`` is not positive.
    """
    im_width, im_height, im_proj, geo_transform, im_data, _ = read_img(inpath)
    stem = os.path.splitext(os.path.basename(inpath))[0]

    row_offsets = clip_window_offsets(im_height, new_width, stride)
    col_offsets = clip_window_offsets(im_width, new_width, stride)
    if not row_offsets or not col_offsets:
        print("skipping", inpath, "- image is", im_width, "x", im_height,
              "pixels, smaller than the patch size", new_width)
        return 0

    count = 0
    for row in row_offsets:
        for col in col_offsets:
            count += 1
            # The Ellipsis keeps this valid for both 2-D (single-band) and
            # 3-D (bands, rows, cols) arrays.
            patch = im_data[..., row:row + new_width, col:col + new_width]
            patch_path = os.path.join(outpath, stem + '_' + str(count) + '.tif')
            write_img(patch_path, im_proj, _patch_geotransform(geo_transform, col, row), patch)
    return count


def raster_corners(geotransform, width, height):
    """Return the georeferenced corners of a ``width`` x ``height`` pixel raster.

    Corners are the outer edges of pixel space, in the order (col, row) =
    (0, height), (0, 0), (width, 0), (width, height). For a north-up raster
    that is lower-left, upper-left, upper-right, lower-right. The rotation
    terms ``geotransform[2]`` and ``geotransform[4]`` are applied, so rotated
    rasters get their true footprint.
    """
    gt = geotransform

    def to_geo(col, row):
        return (gt[0] + col * gt[1] + row * gt[2],
                gt[3] + col * gt[4] + row * gt[5])

    return [to_geo(0, height), to_geo(0, 0), to_geo(width, 0), to_geo(width, height)]


def get_mask(img_path, out_shp):
    """Write a one-feature polygon shapefile covering the footprint of a raster.

    Args:
        img_path: raster whose extent defines the polygon.
        out_shp: shapefile path to create. An existing shapefile at this path
            is deleted first so reruns do not fail.

    The layer is named "polygon" and uses the raster's projection. No SRS is
    set when the raster has no projection.

    Raises:
        OSError: if the raster cannot be opened or the shapefile cannot be created.
    """
    # Only metadata is needed, so avoid reading the pixel data.
    dataset = gdal.Open(img_path)
    if dataset is None:
        raise OSError("GDAL could not open raster %r" % (img_path,))
    corners = raster_corners(dataset.GetGeoTransform(),
                             dataset.RasterXSize, dataset.RasterYSize)
    proj_wkt = dataset.GetProjection()
    dataset = None

    driver = ogr.GetDriverByName("ESRI Shapefile")
    if os.path.exists(out_shp):
        driver.DeleteDataSource(out_shp)
    data_source = driver.CreateDataSource(out_shp)
    if data_source is None:
        raise OSError("OGR could not create shapefile %r" % (out_shp,))

    srs = osr.SpatialReference(wkt=proj_wkt) if proj_wkt else None
    layer = data_source.CreateLayer("polygon", srs, ogr.wkbPolygon)

    # Build the ring explicitly and close it by repeating the first corner.
    ring = ogr.Geometry(ogr.wkbLinearRing)
    for x, y in corners + corners[:1]:
        ring.AddPoint_2D(x, y)
    polygon = ogr.Geometry(ogr.wkbPolygon)
    polygon.AddGeometry(ring)

    feature = ogr.Feature(layer.GetLayerDefn())
    feature.SetGeometry(polygon)
    layer.CreateFeature(feature)
    feature = None
    data_source = None  # flush and close the shapefile


def _polygonal_part(geometry):
    """Return the areal part of ``geometry`` (Polygon/MultiPolygon), or None if it has none.

    Polygon and MultiPolygon are returned as-is. The polygon members of a
    GeometryCollection are gathered into a MultiPolygon. Points, lines and
    empty geometries yield None.
    """
    if geometry is None or geometry.IsEmpty():
        return None
    flat_type = ogr.GT_Flatten(geometry.GetGeometryType())
    if flat_type in (ogr.wkbPolygon, ogr.wkbMultiPolygon):
        return geometry
    if flat_type != ogr.wkbGeometryCollection:
        return None
    multi = ogr.Geometry(ogr.wkbMultiPolygon)
    for k in range(geometry.GetGeometryCount()):
        part = _polygonal_part(geometry.GetGeometryRef(k))
        if part is None:
            continue
        if ogr.GT_Flatten(part.GetGeometryType()) == ogr.wkbPolygon:
            multi.AddGeometry(part)
        else:
            for m in range(part.GetGeometryCount()):
                multi.AddGeometry(part.GetGeometryRef(m))
    return multi if multi.GetGeometryCount() else None


def maskClipLabel(ds_large, layer_large, clip_shapefile, output_shapefile):
    """Clip the features of ``layer_large`` to the polygon stored in ``clip_shapefile``.

    Example paths: labels "./MyProject/label.shp", clip polygon
    "./output/optic_1.shp", output "./clipShp/optic_1_new.shp".

    Args:
        ds_large: open data source owning ``layer_large``. It is not used
            here; the caller keeps it open and is responsible for closing it.
        layer_large: label layer to clip. A spatial filter is set on it
            temporarily and cleared before returning, so it can be reused
            across calls.
        clip_shapefile: shapefile whose feature with FID 0 is the clip polygon.
        output_shapefile: shapefile to create (an existing one is replaced).
            It is written in ``layer_large``'s SRS with the same attribute
            fields. Only features with a non-empty areal intersection are kept.

    Raises:
        OSError: if the clip shapefile cannot be opened or the output cannot
            be created.
        ValueError: if the clip shapefile has no geometry in feature 0.
    """
    ds_small = ogr.Open(clip_shapefile)
    if ds_small is None:
        raise OSError("OGR could not open clip shapefile %r" % (clip_shapefile,))
    layer_small = ds_small.GetLayer()
    feature_small = layer_small.GetFeature(0)
    if feature_small is None or feature_small.GetGeometryRef() is None:
        raise ValueError("clip shapefile %r has no geometry in feature 0" % (clip_shapefile,))
    # Clone so reprojecting does not modify the clip layer's own feature.
    clip_geometry = feature_small.GetGeometryRef().Clone()

    # Work in the label layer's SRS (the output SRS). Reproject the small
    # clip polygon into it, not the label features. Skip reprojection when
    # either SRS is unknown or they already match.
    spatial_ref_large = layer_large.GetSpatialRef()
    spatial_ref_small = layer_small.GetSpatialRef()
    if (spatial_ref_large is not None and spatial_ref_small is not None
            and not spatial_ref_small.IsSame(spatial_ref_large)):
        clip_geometry.Transform(osr.CoordinateTransformation(spatial_ref_small, spatial_ref_large))

    driver = ogr.GetDriverByName("ESRI Shapefile")
    if os.path.exists(output_shapefile):
        driver.DeleteDataSource(output_shapefile)
    ds_output = driver.CreateDataSource(output_shapefile)
    if ds_output is None:
        raise OSError("OGR could not create shapefile %r" % (output_shapefile,))
    layer_output = ds_output.CreateLayer("output_layer", spatial_ref_large, geom_type=ogr.wkbPolygon)

    # Copy the label layer's attribute schema.
    layer_def = layer_large.GetLayerDefn()
    field_count = layer_def.GetFieldCount()
    for i in range(field_count):
        layer_output.CreateField(layer_def.GetFieldDefn(i))

    # Prefilter so only features near the clip polygon are visited; the exact
    # test is still done with Intersects below.
    layer_large.SetSpatialFilter(clip_geometry)
    try:
        for feature_large in layer_large:
            geometry_large = feature_large.GetGeometryRef()
            if geometry_large is None or not geometry_large.Intersects(clip_geometry):
                continue
            # Features that only touch the clip edge intersect as lines or
            # points, which cannot be stored in a POLYGON shapefile.
            intersection = _polygonal_part(geometry_large.Intersection(clip_geometry))
            if intersection is None:
                continue

            feature_output = ogr.Feature(layer_output.GetLayerDefn())
            feature_output.SetGeometry(intersection)
            for i in range(field_count):
                feature_output.SetField(layer_def.GetFieldDefn(i).GetNameRef(), feature_large.GetField(i))
            layer_output.CreateFeature(feature_output)
    finally:
        layer_large.SetSpatialFilter(None)

    # Close what this function opened; ds_large belongs to the caller.
    ds_small = None
    ds_output = None


def _append_features(src_layer, dst_layer, transform=None):
    """Append every feature of ``src_layer`` to ``dst_layer``.

    Attributes are copied by field name. Geometries are cloned and, when
    ``transform`` is given, reprojected with it. Features without geometry
    are copied with a null geometry.
    """
    dst_defn = dst_layer.GetLayerDefn()
    for src_feature in src_layer:
        out_feature = ogr.Feature(dst_defn)
        geometry = src_feature.GetGeometryRef()
        if geometry is not None:
            # Clone so reprojection never touches the source feature.
            geometry = geometry.Clone()
            if transform is not None:
                geometry.Transform(transform)
            out_feature.SetGeometry(geometry)
        for i in range(src_feature.GetFieldCount()):
            out_feature.SetField(src_feature.GetFieldDefnRef(i).GetName(), src_feature.GetField(i))
        dst_layer.CreateFeature(out_feature)


def overlay_features(source_shp, target_shp, output_shp):
    """Write ``output_shp`` holding the features of ``target_shp`` followed by those of ``source_shp``.

    The output layer uses the target layer's name, geometry type and SRS. It
    has the target's fields plus any source fields whose names are not
    already present. Source geometries are reprojected into the target SRS
    when both layers have an SRS and the two differ. An existing
    ``output_shp`` is replaced.

    On failure to open an input or create the output, a message is printed
    and the function returns without writing anything.
    """
    # Open the source shapefile.
    source_ds = ogr.Open(source_shp)
    if source_ds is None:
        print("无法打开源shp文件")
        return

    # Open the target shapefile.
    target_ds = ogr.Open(target_shp)
    if target_ds is None:
        print("无法打开目标shp文件")
        source_ds = None
        return

    source_layer = source_ds.GetLayer()
    target_layer = target_ds.GetLayer()
    output_srs = target_layer.GetSpatialRef()
    source_srs = source_layer.GetSpatialRef()

    # Create the output shapefile, replacing any previous one.
    driver = ogr.GetDriverByName('ESRI Shapefile')
    if os.path.exists(output_shp):
        driver.DeleteDataSource(output_shp)
    output_ds = driver.CreateDataSource(output_shp)
    if output_ds is None:
        print("无法创建输出shp文件")
        return
    output_layer = output_ds.CreateLayer(target_layer.GetName(), geom_type=target_layer.GetGeomType(), srs=output_srs)

    # Target fields first, then source fields not already present.
    target_defn = target_layer.GetLayerDefn()
    for i in range(target_defn.GetFieldCount()):
        output_layer.CreateField(target_defn.GetFieldDefn(i))
    source_defn = source_layer.GetLayerDefn()
    for i in range(source_defn.GetFieldCount()):
        field_defn = source_defn.GetFieldDefn(i)
        if output_layer.GetLayerDefn().GetFieldIndex(field_defn.GetName()) == -1:
            output_layer.CreateField(field_defn)

    # A missing SRS on either side (e.g. no .prj file) makes a transformation
    # impossible; in that case coordinates are copied unchanged.
    transform = None
    if source_srs is not None and output_srs is not None and not source_srs.IsSame(output_srs):
        transform = osr.CoordinateTransformation(source_srs, output_srs)

    _append_features(target_layer, output_layer)
    _append_features(source_layer, output_layer, transform)

    # Close the data sources (flushes the output to disk).
    source_ds = None
    target_ds = None
    output_ds = None

if __name__ == "__main__":
    # Pipeline for each label shapefile "<site>-<region>-<mark>.shp":
    #   1. cut the matching "<site>_<region>_deno.tif" into 64x64 patches (stride 20);
    #   2. write a footprint polygon shapefile next to every patch;
    #   3. clip the label shapefile to each footprint, giving per-patch labels.

    tiff_folder = './convertdata/7bands-images_val/'
    shpf_folder = './convertdata/7shp_val/'

    # Alternative inputs for the training set:
    # tiff_folder = './Train-7bands-images/'
    # shpf_folder = './Train-shp-labels/'

    for shpfile in os.listdir(shpf_folder):
        if not shpfile.endswith(".shp"):
            continue

        label_shp_path = os.path.join(shpf_folder, shpfile)  # large label shapefile to be clipped

        shpfile_name, shpfile_ext = os.path.splitext(shpfile)
        print('shpfile_name:', shpfile_name)
        name_parts = shpfile_name.split('-')
        if len(name_parts) != 3:
            print('skipping', shpfile, ': expected a <site>-<region>-<mark> file name')
            continue
        sitename, regionn, mark_ = name_parts

        # Label files spell this site 'johula'; the image files use 'juhola'.
        if sitename == 'johula':
            sitename = 'juhola'

        tiffile_name = sitename + '_' + regionn + '_deno.tif'  # for tiff with 7 bands
        # tiffile_name = sitename + '-' + regionn + '-Drone.tif'  # for drone tiff
        print('tiffile_name:', tiffile_name)

        big_img_path = os.path.join(tiff_folder, tiffile_name)  # large source image
        if not os.path.exists(big_img_path):
            print('skipping', shpfile, ': image not found:', big_img_path)
            continue

        # Directory for the clipped patches and their footprint masks.
        clip_out_path = os.path.join("./convertdata/clip_image_val/", sitename, regionn)
        os.makedirs(clip_out_path, exist_ok=True)

        gdal_image_clip_single(big_img_path, clip_out_path, new_width=64, stride=20)

        # Write an (attribute-less) footprint polygon shapefile for every patch.
        for ff in os.listdir(clip_out_path):
            imgname, ext = os.path.splitext(ff)
            if ext == ".tif":
                img_path = os.path.join(clip_out_path, ff)
                out_shp = os.path.join(clip_out_path, imgname + '.shp')
                get_mask(img_path, out_shp)

        # Clip the label shapefile with each patch footprint; the results are the per-patch labels.
        save_clip_label = os.path.join("./convertdata/clip_shp_val/", sitename, regionn)
        os.makedirs(save_clip_label, exist_ok=True)

        all_files = os.listdir(clip_out_path)  # clipped raster patches and their footprint shapefiles

        # Open the large label shapefile once and reuse its layer for every patch.
        ds_large = ogr.Open(label_shp_path)
        if ds_large is None:
            print('skipping', shpfile, ': could not open label shapefile', label_shp_path)
            continue
        layer_large = ds_large.GetLayer()

        for mask_shp in all_files:
            if os.path.splitext(mask_shp)[1] != ".shp":
                continue
            mask_shp_path = os.path.join(clip_out_path, mask_shp)
            clip_shp_path = os.path.join(save_clip_label, mask_shp)  # per-patch label shapefile
            maskClipLabel(ds_large, layer_large, mask_shp_path, clip_shp_path)
            # To also merge the clipped labels onto the footprint mask:
            # overlay_features(clip_shp_path, mask_shp_path, save_overlay_shp)

        # Release the label data source.
        ds_large = None

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值