遥感影像三波段获取其中一个波段代码

本文介绍了一个使用Python和GDAL库的方法,用于从遥感影像中获取特定波段的数据,并将其转换为单波段文件,同时处理了栅格数据读取、创建和窗口滑动处理的技术细节。
摘要由CSDN通过智能技术生成

遥感影像三波段获取其中一个波段代码

# -*- coding: utf-8 -*-
from osgeo import gdal
from osgeo import ogr, osr
import os, sys,shutil
import numpy as np
import datetime

class Three2one:

    def __init__(self):

        # 为了支持中文路径,请添加下面这句代码
        gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
        # 为了使属性表字段支持中文,请添加下面这句
        gdal.SetConfigOption("SHAPE_ENCODING", "CP936")

        ogr.RegisterAll()
        gdal.AllRegister()

        return

     # 读取栅格************
    def read_raster(self, RasterPath):
        '''
        :param RasterPath: input the raster
        :return:  data--- the values of the raster
                  SpacialRef -- include the geotransform, projection and the NoDate of the raster
        '''
        dataset = gdal.Open(RasterPath, gdal.GA_ReadOnly)

        if not dataset:
            print('打开文件失败')
        XSize = dataset.RasterXSize  # 影像列数
        YSize = dataset.RasterYSize  # 影像行数
        band_num = dataset.RasterCount  # 波段数
        datatype = dataset.GetRasterBand(1).DataType

        geotransform = dataset.GetGeoTransform()  # 仿射矩阵
        projection = dataset.GetProjection()  # 投影信息

        #  获取波段及数据
        if band_num == 1:
            band = dataset.GetRasterBand(1)
            data_type = band.DataType
            data = band.ReadAsArray(0, 0, XSize, YSize)
            NoDate = band.GetNoDataValue()
        else:
            data = []
            for i in range(band_num):
                band = dataset.GetRasterBand(i + 1)
                dt = band.ReadAsArray(0, 0, XSize, YSize)
                data.append(list(dt))
                if i == 0:
                    NoDate = band.GetNoDataValue()
            data = np.array(data)

        SpacialRef = [geotransform, projection, NoDate]

        dataset = None

        return data, SpacialRef

    # 创建栅格************
    def Creat_raster(self, RasterCreatPath, ArrayDate, SpacialRef, ctmap=None, DriverName="GTiff"):

        # NoData_value = float('nan')
        if 'int8' in ArrayDate.dtype.name:
            datatype = gdal.GDT_Byte
        elif 'int16' in ArrayDate.dtype.name:
            datatype = gdal.GDT_UInt16
        elif 'int' in ArrayDate.dtype.name:
            datatype = gdal.GDT_Int32
        else:
            datatype = gdal.GDT_Float32

        if len(ArrayDate.shape) == 3:
            Bandnum, YSize, XSize = ArrayDate.shape
        else:
            Bandnum = 1
            [YSize, XSize] = ArrayDate.shape

        geotransform = SpacialRef[0]
        projection = SpacialRef[1]
        NoData_value = SpacialRef[2]

        if NoData_value is None:
            NoData_value = float('nan')

        # 设置 LZW 压缩
        # options = ['COMPRESS=LZW']
        # dst_ds = driver.CreateCopy(outtif, src_ds, 0, options=options)

        driver = gdal.GetDriverByName(DriverName)
        dataset = driver.Create(RasterCreatPath, XSize, YSize, Bandnum, datatype)

        if ctmap is not None:
            color = gdal.ColorTable()
            for i in range(len(ctmap)):
                color.SetColorEntry(ctmap[i][0], (ctmap[i][1], ctmap[i][2], ctmap[i][3], ctmap[i][4]))

        if DriverName == "GTiff":
            if geotransform is not None:
                dataset.SetGeoTransform(geotransform)  # 写入仿射变换参数

            if projection is not None:
                dataset.SetProjection(projection)

        if len(ArrayDate.shape) == 3:
            for i in range(Bandnum):
                dataset.GetRasterBand(i + 1).WriteArray(ArrayDate[i])  # 写入数组数据
                dataset.GetRasterBand(i + 1).SetNoDataValue(NoData_value)
                dataset.GetRasterBand(i + 1).ComputeStatistics(True)
                print('波段', i, '写入完成')
            dataset = None
        else:
            band = dataset.GetRasterBand(1)
            band.WriteArray(ArrayDate)  # 写入数组数据
            band.SetNoDataValue(NoData_value)  # 设置无值区域
            if DriverName == "GTiff":
                band.ComputeStatistics(True)
            if ctmap is not None:
                band.SetRasterColorTable(color)

            # 建立输出图像的金字塔
            # gdal.SetConfigOption('HFA_USE_RRD', 'YES')
            # dataset.BuildOverviews(overviewlist=[2, 4, 8, 16])  # 4层
            dataset = None

        return

    def run(self, InputRaster, OutputRaster):

        data, SpacialRef = self.read_raster(InputRaster)
        dd_out = data[0]
        self.Creat_raster(OutputRaster, dd_out.astype('uint8'), SpacialRef)

        print('successful')

        return

    def move_index(self, height, width, block, overlap_rate):

        # 滑窗大小
        slide_window_size = block

        whsize = height if height <= width else width

        if slide_window_size > whsize:
            slide_window_size = whsize

        # 滑框的重叠率
        overlap_pixel = int(slide_window_size * (1 - overlap_rate))
        # overlap_pixel = int(slide_window_size * overlap_rate)

        # ------------------------------------------------------------------#
        #                处理图像各个维度尺寸过小的情况。
        # ------------------------------------------------------------------#
        if height - slide_window_size < 0:  # 判断y是否超边界,为真则表示超边界
            y_idx = [0]
            nYBK = [height]
        else:
            y_idx = [x for x in range(0, height - slide_window_size + 1, overlap_pixel)]
            nYBK = [slide_window_size for x in range(0, height - slide_window_size + 1, overlap_pixel)]
            if y_idx[-1] + slide_window_size > height:

                if overlap_rate == 0.0:
                    y_idx[-1] = y_idx[-2] + slide_window_size
                    nYBK[-1] = height - y_idx[-2]
                else:
                    y_idx[-1] = height - slide_window_size

            else:
                if overlap_rate == 0.0:
                    y_idx.append(y_idx[-1] + slide_window_size)
                    nYBK.append(height - y_idx[-1])
                else:
                    y_idx.append(height - slide_window_size)
                    nYBK.append(slide_window_size)


        if width - slide_window_size < 0:  # 判断x是否超边界,为真则表示超边界
            x_idx = [0]
            nXBK = [width]
        else:
            x_idx = [y for y in range(0, width - slide_window_size + 1, overlap_pixel)]
            nXBK = [slide_window_size for y in range(0, width - slide_window_size + 1, overlap_pixel)]
            if x_idx[-1] + slide_window_size > width:

                if overlap_rate == 0.0:
                    x_idx[-1] = x_idx[-2] + slide_window_size
                    nXBK[-1] = width - x_idx[-2]
                else:
                    x_idx[-1] = width - slide_window_size

            else:
                if overlap_rate == 0.0:
                    x_idx.append(x_idx[-1] + slide_window_size)
                    nXBK.append(width - x_idx[-1])
                else:
                    x_idx.append(width - slide_window_size)
                    nXBK.append(slide_window_size)


        return x_idx, y_idx,nXBK, nYBK

    def run_block(self, InputRaster, OutputRaster):

        dataset = gdal.Open(InputRaster, gdal.GA_ReadOnly)

        if not dataset:
            print('打开文件失败')
        XSize = dataset.RasterXSize  # 影像列数
        YSize = dataset.RasterYSize  # 影像行数
        band_num = dataset.RasterCount  # 波段数
        datatype = dataset.GetRasterBand(1).DataType
        nodata = dataset.GetRasterBand(1).GetNoDataValue()

        geotransform = dataset.GetGeoTransform()  # 仿射矩阵
        projection = dataset.GetProjection()  # 投影信息

        driver = gdal.GetDriverByName("GTiff")
        dataset_new = driver.Create(OutputRaster, XSize, YSize, 1, datatype)

        if geotransform is not None:
            dataset_new.SetGeoTransform(geotransform)  # 写入仿射变换参数

        if projection is not None:
            dataset_new.SetProjection(projection)

        block = 1024
        overlap_rate = 0.0
        x_idx, y_idx, nXBK, nYBK = self.move_index(YSize, XSize, block, overlap_rate)

        band = dataset.GetRasterBand(1)
        band_out = dataset_new.GetRasterBand(1)
        for iy, y_start in enumerate(y_idx):
            for ix, x_start in enumerate(x_idx):
                data = band.ReadAsArray(x_start, y_start, nXBK[ix], nYBK[iy])
                band_out.WriteArray(data, x_start, y_start)
            print('第%d行处理完成。'%y_start)
        if nodata is not None:
            band_out.SetNoDataValue(nodata)  # 设置无值区域

        dataset_new = None
        dataset = None
        return

if __name__ == '__main__':

    InputRaster = r'./变化pred.tif'
    OutputRaster = r'./变化pred_single.tif'

    Three2one().run_block(InputRaster,OutputRaster)

  • 7
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值