##读取影像信息,新建函数
def read_image(input_path):
try:
gdal.SetConfigOption("GDAL_file_name_IS_UTF8", "YES") #用于处理gdal中文乱码
dataset = gdal.Open(input_path, 0)
if dataset is None:
print('could not open')
else:
im_width = dataset.RasterXSize # 栅格数据的宽度(栅格矩阵的列数)
im_height = dataset.RasterYSize # 栅格数据的高度(栅格矩阵的行数)
im_data = dataset.ReadAsArray(0, 0, im_width, im_height).astype(np.float32) # 将数据写成数组
im_geotrans = dataset.GetGeoTransform() # 栅格数据的六参数(仿射矩阵)
im_proj = dataset.GetProjection() # 栅格数据的投影(地图投影信息)
nodata = dataset.GetRasterBand(1).GetNoDataValue()
del dataset # 关闭对象
return im_width, im_height, im_data, im_geotrans, im_proj, nodata
except BaseException as e: #抛出异常的处理
print(str(e))
#写出影像数据
def write_img(im_data, im_geotrans, im_proj, nodata, output_path):
try:
gdal.SetConfigOption("GDAL_file_name_IS_UTF8", "YES")
# 判断栅格数据的数据类型
if 'int8' in im_data.dtype.name:
datatype = gdal.GDT_Byte
elif 'int16' in im_data.dtype.name:
datatype = gdal.GDT_UInt16
else:
datatype = gdal.GDT_Float32
# 判读数组维数
if len(im_data.shape) <= 2:
im_bands, (im_height, im_width) = 1, im_data.shape
else:
im_bands, im_height, im_width = im_data.shape
# 创建文件
driver = gdal.GetDriverByName("GTiff") # 数据类型必须有,因为要计算需要多大内存空间
dataset = driver.Create(output_path, im_width, im_height, im_bands, datatype)
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
dataset.SetProjection(im_proj) # 写入投影
if im_bands == 1:
dataset.GetRasterBand(1).SetNoDataValue(nodata)
dataset.GetRasterBand(1).WriteArray(im_data) # 写入数组数据
else:
for i in range(im_bands):
dataset.GetRasterBand(i + 1).SetNoDataValue(nodata)
dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
del dataset
except BaseException as e:
print(str(e))
#计算植被指数:
def GetNDVI(nir_data, red_data):
try:
denominator = np.array(nir_data + red_data, dtype=np.float32)
numerator = np.array(nir_data - red_data, dtype=np.float32)
nodata = np.full((nir_data.shape[0], nir_data.shape[1]), -999.0, dtype=np.float32)
ndvi = np.divide(numerator, denominator, out=nodata, where=denominator != 0.0)
# mask = np.greater(nir_data + red_data, 0.0)
# ndvi = np.choose(mask,(-999.0,(nir_data-red_data)*1.0/(nir_data+red_data)))
return ndvi
except BaseException as e:
print(str(e))
if __name__ == '__main__':
# 图像的输入路径以及输出路径
input_path = 'G:/*/*/*/0011_4band20170115.tif'
output_path = 'G:/*/*/*t/0011_4band20170115_NDVI.tif'
# 读取波段数据
im_width, im_height, im_data, im_geotrans, im_proj, nodata = read_image(input_path)
red_data = im_data[2]
nir_data = im_data[3]
ndvi_data = GetNDVI(nir_data, red_data)
# ndvi_result = np.where(ndvi_data > 1, -999.0, ndvi_data)
nodata = -999
# 写出NDVI数据
write_img(ndvi_data, im_geotrans, im_proj, nodata, output_path)
print("NDVI计算完成")