使用python-GDAL读取、输出遥感影像
from osgeo import gdal,osr
import numpy as np
def read_tif(path_in, need_info = True):
"""
input:
path_in:读取tif路径
need_info:是否返回坐标等信息
need_info = False 只返回数组
need_info = True 返回数组,坐标信息等
"""
rs_data = gdal.Open(path_in)
im_col = rs_data.RasterXSize
im_row = rs_data.RasterYSize
im_bands = rs_data.RasterCount
img_array = rs_data.ReadAsArray(0, 0, im_col, im_row).astype(np.uint8)
if need_info:
if im_bands > 1:
img_array = np.transpose(img_array, (1, 2, 0))
im_geotrans = rs_data.GetGeoTransform() # 读取仿射变换信息 六参数
im_proj = rs_data.GetProjection() # 读取栅格数据投影
left = im_geotrans[0] # 左上角横坐标
up = im_geotrans[3] # 左上角纵坐标
right = left + im_geotrans[1] * im_col + im_geotrans[2] * im_row # 右下角横坐标
bottom = up + im_geotrans[5] * im_row + im_geotrans[4] * im_col # 右下角纵坐标
extent = (left, right, bottom, up) # 左上角横坐标,右下角横坐标,左上角纵坐标,右下角纵坐标
# EPSG编码 例如 GCS_WGS_1984:4326 GCS_China_Geodetic_Coordinate_System_2000:4490
epsg_code = osr.SpatialReference(wkt = im_proj).GetAttrValue('AUTHORITY', 1)
# im_proj:投影信息 im_geotrans:坐标信息 epsg_code:EPSG编码 im_row:行 im_col:列 im_bands:通道
img_info = {'geoproj': im_proj, 'geotrans': im_geotrans,
'geosrs': epsg_code, 'row': im_row, 'col': im_col,
'bands': im_bands}
return img_array, img_info
else:
if im_bands > 1:
img_array = np.transpose(img_array, (1, 2, 0))
return img_array
def write_tif(im_data, ref_data_path, path_out):
"""
input:
im_data: H*W或者H*W*C 的数组
ref_data_path: 参考影像路径
path_out:输出tif路径
"""
im_data = np.squeeze(im_data)
if 'int8' in im_data.dtype.name:
datatype = gdal.GDT_Byte
elif 'int16' in im_data.dtype.name:
datatype = gdal.GDT_UInt16
else:
datatype = gdal.GDT_Float32
if len(im_data.shape) >= 3:
im_data = np.transpose(im_data, (2, 0, 1))
im_bands, im_height, im_width = im_data.shape
else:
im_bands, (im_height, im_width) = 1, im_data.shape
rs_data = gdal.Open(ref_data_path) # 打开参考图像
im_geotrans = rs_data.GetGeoTransform() # 获取参考图像坐标信息
im_proj = rs_data.GetProjection() # 获取栅格数据的投影
epsg_code = osr.SpatialReference(wkt=im_proj).GetAttrValue('AUTHORITY', 1) #获取EPSG编码
driver = gdal.GetDriverByName("GTiff") #
dataset = driver.Create(path_out, im_width, im_height, im_bands, datatype)
dataset.SetGeoTransform(im_geotrans) #设置仿射变换信息
srs = osr.SpatialReference()
srs.ImportFromEPSG(int(epsg_code))
dataset.SetProjection(srs.ExportToWkt())
if im_bands > 1:
for i in range(im_bands):
dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
del dataset
else:
dataset.GetRasterBand(1).WriteArray(im_data)
del dataset
if __name__ == '__main__':
# 读取tif
img_array, img_info = read_tif(path_in, need_info = True) # 需要获取参考信息
img_array = read_tif(path_in, need_info = False) # 需要获取参考信息
# 输出tif im_data:待输出数组 ref_data_path:参考影像路径 path_out:结果输出路径
write_tif(im_data, ref_data_path, path_out)