Python GDAL implementation of the ArcGIS zonal statistics (Zonal) tool


Reference: https://towardsdatascience.com/zonal-statistics-algorithm-with-python-in-4-steps-382a3b66648a
Note: the raster must be in the same coordinate system as the vector; the results are written to a CSV file.
(Figure: result)
I tweaked the code slightly; later on the results will also be joined back to the vector layer.
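Before running the script, the coordinate-system requirement above can be checked programmatically. Below is a minimal sketch (not part of the original script; the function name is my own) that compares the raster projection with the shapefile's spatial reference through the GDAL/OGR/OSR bindings; if they differ, one side should be reprojected first.

from osgeo import gdal, ogr, osr

def crs_match(fn_raster, fn_zones):
    # return True if the raster and the vector layer share the same CRS
    r_ds = gdal.Open(fn_raster)
    p_ds = ogr.Open(fn_zones)
    r_srs = osr.SpatialReference(wkt=r_ds.GetProjection())
    v_srs = p_ds.GetLayer().GetSpatialRef()
    return bool(r_srs.IsSame(v_srs))

# e.g. crs_match('./data/t_dem.tif', './data/grid1.shp')

With the coordinate systems confirmed to match, the zonal-statistics script itself follows.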

from osgeo import gdal, ogr
import os
import numpy as np
import csv
import time

# convert a geometry envelope (minX, maxX, minY, maxY) into raster offsets [row1, row2, col1, col2]
def boundingBoxToOffsets(bbox, geot):
    col1 = int((bbox[0] - geot[0]) / geot[1])
    col2 = int((bbox[1] - geot[0]) / geot[1]) + 1
    row1 = int((bbox[3] - geot[3]) / geot[5])
    row2 = int((bbox[2] - geot[3]) / geot[5]) + 1
    return [row1, row2, col1, col2]

# build the geotransform of a sub-window whose upper-left pixel sits at (row_offset, col_offset)
def geotFromOffsets(row_offset, col_offset, geot):
    new_geot = [
    geot[0] + (col_offset * geot[1]),
    geot[1],
    0.0,
    geot[3] + (row_offset * geot[5]),
    0.0,
    geot[5]
    ]
    return new_geot

# pack the statistics of one feature into a dict keyed by the output column names
def setFeatureStats(fid, min, max, mean, median, sd, sum, count, names=["min", "max", "mean", "median", "sd", "sum", "count", "id"]):
    featstats = {
    names[0]: min,
    names[1]: max,
    names[2]: mean,
    names[3]: median,
    names[4]: sd,
    names[5]: sum,
    names[6]: count,
    names[7]: fid,
    }
    return featstats

def zonal(fn_raster, fn_zones, fn_csv):
    mem_driver = ogr.GetDriverByName("Memory")
    mem_driver_gdal = gdal.GetDriverByName("MEM")
    shp_name = "temp"

    # fn_raster = "C:/pyqgis/raster/USGS_NED_13_n45w116_IMG.img"
    # fn_zones = "C:/temp/zonal_stats/zones.shp"

    r_ds = gdal.Open(fn_raster)
    p_ds = ogr.Open(fn_zones)

    lyr = p_ds.GetLayer()
    geot = r_ds.GetGeoTransform()
    nodata = r_ds.GetRasterBand(1).GetNoDataValue()

    zstats = []

    p_feat = lyr.GetNextFeature()
    niter = 0

    while p_feat:
        if p_feat.GetGeometryRef() is not None:
            if os.path.exists(shp_name):
                mem_driver.DeleteDataSource(shp_name)
            tp_ds = mem_driver.CreateDataSource(shp_name)
            tp_lyr = tp_ds.CreateLayer('polygons', None, ogr.wkbPolygon)
            tp_lyr.CreateFeature(p_feat.Clone())
            offsets = boundingBoxToOffsets(p_feat.GetGeometryRef().GetEnvelope(),
                                           geot)
            new_geot = geotFromOffsets(offsets[0], offsets[2], geot)

            tr_ds = mem_driver_gdal.Create(
                "",
                offsets[3] - offsets[2],
                offsets[1] - offsets[0],
                1,
                gdal.GDT_Byte)

            tr_ds.SetGeoTransform(new_geot)
            gdal.RasterizeLayer(tr_ds, [1], tp_lyr, burn_values=[1])
            tr_array = tr_ds.ReadAsArray()

            r_array = r_ds.GetRasterBand(1).ReadAsArray(
                offsets[2],
                offsets[0],
                offsets[3] - offsets[2],
                offsets[1] - offsets[0])

            fid = p_feat.GetFID()

            if r_array is not None:
                # mask out nodata pixels and pixels falling outside the polygon
                maskarray = np.ma.MaskedArray(
                    r_array,
                    mask=np.logical_or(r_array == nodata, np.logical_not(tr_array)))

                if maskarray is not None:
                    zstats.append(setFeatureStats(
                        fid,
                        maskarray.min(),
                        maskarray.max(),
                        maskarray.mean(),
                        np.ma.median(maskarray),
                        maskarray.std(),
                        maskarray.sum(),
                        maskarray.count()))
                else:
                    zstats.append(setFeatureStats(
                        fid,
                        nodata, nodata, nodata, nodata, nodata, nodata, nodata))
            else:
                zstats.append(setFeatureStats(
                    fid,
                    nodata, nodata, nodata, nodata, nodata, nodata, nodata))

            tp_ds = None
            tp_lyr = None
            tr_ds = None

        # fetch the next feature outside the geometry check so that a feature
        # with a null geometry cannot cause an infinite loop
        p_feat = lyr.GetNextFeature()

    # fn_csv = "C:/temp/zonal_stats/zstats.csv"
    col_names = zstats[0].keys()
    with open(fn_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, col_names)
        writer.writeheader()
        writer.writerows(zstats)

if __name__ == "__main__":
    time1 = time.time()
    fn_raster = './data/t_dem.tif'
    fn_zones = './data/grid1.shp'
    fn_csv = './data/zonsl.csv'
    zonal(fn_raster, fn_zones, fn_csv)
    time2 = time.time()
    print((time2-time1) / 3600.0)

If all we get is a CSV and the values are never written into the shapefile attribute table, the whole exercise loses much of its value. Below is what it looks like once that step is done.
(Figure: comparison)
(Figure: statistics result)
Here is the updated code:

from osgeo import gdal, ogr
import os
import numpy as np
import csv
import pandas as pd
import time

def boundingBoxToOffsets(bbox, geot):
    col1 = int((bbox[0] - geot[0]) / geot[1])
    col2 = int((bbox[1] - geot[0]) / geot[1]) + 1
    row1 = int((bbox[3] - geot[3]) / geot[5])
    row2 = int((bbox[2] - geot[3]) / geot[5]) + 1
    return [row1, row2, col1, col2]

def geotFromOffsets(row_offset, col_offset, geot):
    new_geot = [
    geot[0] + (col_offset * geot[1]),
    geot[1],
    0.0,
    geot[3] + (row_offset * geot[5]),
    0.0,
    geot[5]
    ]
    return new_geot

def setFeatureStats(fid, min, max, mean, median, sd, sum, count, names=["min", "max", "mean", "median", "sd", "sum", "count", "id"]):
    featstats = {
    names[0]: min,
    names[1]: max,
    names[2]: mean,
    names[3]: median,
    names[4]: sd,
    names[5]: sum,
    names[6]: count,
    names[7]: fid,
    }
    return featstats

def zonal(fn_raster, fn_zones, fn_csv):
    mem_driver = ogr.GetDriverByName("Memory")
    mem_driver_gdal = gdal.GetDriverByName("MEM")
    shp_name = "temp"

    # fn_raster = "C:/pyqgis/raster/USGS_NED_13_n45w116_IMG.img"
    # fn_zones = "C:/temp/zonal_stats/zones.shp"

    r_ds = gdal.Open(fn_raster)
    p_ds = ogr.Open(fn_zones)

    lyr = p_ds.GetLayer()
    geot = r_ds.GetGeoTransform()
    nodata = r_ds.GetRasterBand(1).GetNoDataValue()

    zstats = []

    p_feat = lyr.GetNextFeature()
    niter = 0

    while p_feat:
        if p_feat.GetGeometryRef() is not None:
            if os.path.exists(shp_name):
                mem_driver.DeleteDataSource(shp_name)
            tp_ds = mem_driver.CreateDataSource(shp_name)
            tp_lyr = tp_ds.CreateLayer('polygons', None, ogr.wkbPolygon)
            tp_lyr.CreateFeature(p_feat.Clone())
            offsets = boundingBoxToOffsets(p_feat.GetGeometryRef().GetEnvelope(),
                                           geot)
            new_geot = geotFromOffsets(offsets[0], offsets[2], geot)

            tr_ds = mem_driver_gdal.Create(
                "",
                offsets[3] - offsets[2],
                offsets[1] - offsets[0],
                1,
                gdal.GDT_Byte)

            tr_ds.SetGeoTransform(new_geot)
            gdal.RasterizeLayer(tr_ds, [1], tp_lyr, burn_values=[1])
            tr_array = tr_ds.ReadAsArray()

            r_array = r_ds.GetRasterBand(1).ReadAsArray(
                offsets[2],
                offsets[0],
                offsets[3] - offsets[2],
                offsets[1] - offsets[0])

            fid = p_feat.GetFID()

            if r_array is not None:
                # mask out nodata pixels and pixels falling outside the polygon
                maskarray = np.ma.MaskedArray(
                    r_array,
                    mask=np.logical_or(r_array == nodata, np.logical_not(tr_array)))

                if maskarray is not None:
                    zstats.append(setFeatureStats(
                        fid,
                        maskarray.min(),
                        maskarray.max(),
                        maskarray.mean(),
                        np.ma.median(maskarray),
                        maskarray.std(),
                        maskarray.sum(),
                        maskarray.count()))
                else:
                    zstats.append(setFeatureStats(
                        fid,
                        nodata, nodata, nodata, nodata, nodata, nodata, nodata))
            else:
                zstats.append(setFeatureStats(
                    fid,
                    nodata, nodata, nodata, nodata, nodata, nodata, nodata))

            tp_ds = None
            tp_lyr = None
            tr_ds = None

        # fetch the next feature outside the geometry check so that a feature
        # with a null geometry cannot cause an infinite loop
        p_feat = lyr.GetNextFeature()
            
    col_names = zstats[0].keys()
    with open(fn_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, col_names)
        writer.writeheader()
        writer.writerows(zstats)

# write the CSV statistics back into the shapefile attribute table
def shp_field_value(csv_file, shp):
    data = pd.read_csv(csv_file)
    driver = ogr.GetDriverByName('ESRI Shapefile')
    layer_source = driver.Open(shp, 1)
    lyr = layer_source.GetLayer()

    # create one real-valued field per statistic
    for field_name in ('min', 'max', 'mean', 'median', 'sd', 'sum', 'count'):
        lyr.CreateField(ogr.FieldDefn(field_name, ogr.OFTReal))

    # assumes the CSV rows are in the same order as the shapefile features,
    # which is how zonal() wrote them
    count = 0
    feature = lyr.GetNextFeature()
    while feature is not None:
        feature.SetField('min', float(data['min'][count]))
        feature.SetField('max', float(data['max'][count]))
        feature.SetField('mean', float(data['mean'][count]))
        feature.SetField('median', float(data['median'][count]))
        feature.SetField('sd', float(data['sd'][count]))
        feature.SetField('sum', float(data['sum'][count]))
        feature.SetField('count', float(data['count'][count]))
        lyr.SetFeature(feature)
        count += 1
        feature = lyr.GetNextFeature()

    # release the data source so the edits are flushed to disk
    layer_source = None

if __name__ == "__main__":
    time1 = time.time()
    fn_raster = './data/t_dem.tif'
    fn_zones = './data/grid1.shp'
    fn_csv = './data/zonsl.csv'
    zonal(fn_raster, fn_zones, fn_csv)
    shp_field_value(fn_csv, fn_zones)
    time2 = time.time()
    print((time2-time1) / 3600.0)

Elapsed time: 0.00384633739789327 h

Below is a comparison with the ArcGIS zonal statistics tool:

import os
import time
import arcpy
from arcpy import env

def zonal(raster, shp):
    # compute the zonal statistics table and join it back to the shapefile on FID
    attri_table = "zonalstat"
    arcpy.gp.ZonalStatisticsAsTable_sa(shp, "FID", raster, attri_table, "NODATA", "ALL")
    arcpy.JoinField_management(shp, "FID", attri_table, "FID")

if __name__ == "__main__":
    time1 = time.time()
    dem_ras = './data/t_dem.tif'
    shp = './data/grid.shp'
    temp_path = './data/temp/'

    arcpy.CheckOutExtension('Spatial')
    arcpy.env.overwriteOutput = True
    env.workspace = temp_path

    zonal(dem_ras, shp)
    time2 = time.time()
    print((time2 - time1) / 3600.0)

Elapsed time: 0.196481666631 h
Below is the attribute table produced by ArcGIS. Clearly ArcGIS adds several extra fields, including AREA, MAJORITY, MINORITY, etc.
(Figure: ArcGIS attribute table)
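Those extra categorical fields could, in principle, be added to the GDAL version as well. Here is a rough sketch (not part of the scripts above; the helper name and the cell_area argument are my own) that derives majority, minority, variety and area from the same kind of masked zone array that zonal() builds.

import numpy as np

def extra_stats(maskarray, cell_area):
    # hypothetical helper: categorical statistics similar to the extra ArcGIS fields,
    # computed from a masked zone array like the one built inside zonal()
    values = maskarray.compressed()  # keep only valid, in-polygon pixels
    if values.size == 0:
        return {}
    uniques, counts = np.unique(values, return_counts=True)
    return {
        "majority": uniques[np.argmax(counts)],
        "minority": uniques[np.argmin(counts)],
        "variety": len(uniques),
        "area": values.size * cell_area,  # pixel count times cell area (map units squared)
    }

# cell_area can be taken from the geotransform: abs(geot[1] * geot[5])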
ArcGIS reports more fields than the GDAL implementation, but its runtime still feels far too long: even if the GDAL version computed those extra fields, it should not take anywhere near this much time. Interested readers are welcome to try it. Also note that the two approaches do not produce exactly the same field values, although the differences are generally small. Below is the visual comparison of the mean values:
1. GDAL
(Figure: mean values, GDAL)

2. ArcGIS
(Figure: mean values, ArcGIS)
Looking closely, the GDAL result still has some flaws, for example the few outlier values in the lower-right corner.
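One possible (unverified) source of such edge discrepancies is which boundary pixels end up in the zone mask. By default gdal.RasterizeLayer burns only pixels whose centers fall inside the polygon; passing ALL_TOUCHED=TRUE also burns pixels merely touched by the boundary. The variant below shows how the rasterization call inside zonal() could be changed to experiment with this.

# variant of the rasterization step in zonal(): include every pixel touched
# by the polygon instead of only pixels whose centers fall inside it
gdal.RasterizeLayer(tr_ds, [1], tp_lyr, burn_values=[1],
                    options=["ALL_TOUCHED=TRUE"])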

Test data: https://download.csdn.net/download/qq_20373723/13716488
