python 16bit转8bit的方法

30 篇文章 30 订阅
import os
import numpy as np
import gdal

def read_img(filename):
    """Open *filename* with GDAL and load its complete raster.

    Returns a 5-tuple ``(projection, geotransform, width, height, pixels)``
    where *pixels* is the result of ``Dataset.ReadAsArray`` over the whole
    image (presumably (rows, cols) for one band, (bands, rows, cols)
    otherwise -- matches how write_img consumes it).
    """
    src = gdal.Open(filename)

    cols, rows = src.RasterXSize, src.RasterYSize
    geotransform = src.GetGeoTransform()
    projection = src.GetProjection()
    pixels = src.ReadAsArray(0, 0, cols, rows)

    # Drop the handle so GDAL closes the file before we return.
    del src
    return projection, geotransform, cols, rows, pixels


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write *im_data* to a GeoTIFF at *filename*.

    :param filename: output path
    :param im_proj: projection string passed to ``SetProjection``
    :param im_geotrans: 6-element GDAL geotransform
    :param im_data: array of shape (height, width) for one band or
        (bands, height, width) for several bands
    :raises RuntimeError: if GDAL cannot create the output file

    The GDAL pixel type follows the array dtype: uint8/int8 -> Byte,
    uint16 -> UInt16, int16 -> Int16, anything else -> Float32.
    """
    dtype_name = im_data.dtype.name
    if dtype_name in ('uint8', 'int8'):
        # NOTE: signed int8 is still stored as Byte (GDT_Int8 only exists in
        # newer GDAL releases), so negative int8 values do not survive.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must keep a signed type, otherwise negatives are lost.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create {!r}".format(filename))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Dispatch on the array rank, not the band count: a (1, H, W) array must
    # still be indexed per band because WriteArray only accepts 2-D arrays.
    if im_data.ndim == 2:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    dataset.FlushCache()
    del dataset


def write_img_(filename, im_proj, im_geotrans, im_data):
    """Write *im_data* to a GeoTIFF at *filename* as 8-bit (GDAL Byte) bands.

    :param filename: output path
    :param im_proj: projection string passed to ``SetProjection``
    :param im_geotrans: 6-element GDAL geotransform
    :param im_data: array of shape (height, width) for one band or
        (bands, height, width) for several bands
    :raises RuntimeError: if GDAL cannot create the output file

    The output pixel type is always Byte, whatever ``im_data.dtype`` is;
    GDAL converts the values on write, so anything outside 0..255 cannot be
    represented -- callers are expected to have scaled the data already.
    """
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, gdal.GDT_Byte)
    if dataset is None:
        raise RuntimeError("GDAL could not create {!r}".format(filename))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Dispatch on the array rank: a (1, H, W) array still needs im_data[0],
    # because WriteArray only accepts 2-D arrays.
    if im_data.ndim == 2:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    dataset.FlushCache()
    del dataset


def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch *bands* into the range [img_min, img_max].

    The input values at the ``lower_percent`` and ``higher_percent``
    percentiles (computed over the whole array) are mapped to ``img_min``
    and ``img_max``; results outside the target range are clipped to it.

    :param bands: pixel values; any array-like of any shape (2-D arrays
        behave exactly as before)
    :param img_min: value the lower percentile is mapped to
    :param img_max: value the upper percentile is mapped to
    :param lower_percent: lower percentile (0-100) of the stretch window
    :param higher_percent: upper percentile (0-100) of the stretch window
    :return: float32 array with the same shape as *bands*
    """
    data = np.asarray(bands)
    a = img_min
    b = img_max
    c = np.percentile(data, lower_percent)
    d = np.percentile(data, higher_percent)

    if d == c:
        # Degenerate window (e.g. a constant band): the linear formula would
        # divide by zero and yield NaN. Send values above the window to the
        # top of the range and everything else to the bottom.
        t = np.where(data > d, b, a)
    else:
        t = a + (data - c) * (b - a) / (d - c)
        # Clamp below first, then above (same order as the original masks).
        t = np.minimum(np.maximum(t, a), b)

    return np.asarray(t, dtype=np.float32)

def getTifSize(tif):
    """Describe the raster stored at path *tif*.

    Returns ``(width, height, band_count, geotransform, projection)`` as
    reported by GDAL for the opened dataset.
    """
    ds = gdal.Open(tif)
    return (
        ds.RasterXSize,
        ds.RasterYSize,
        ds.RasterCount,
        ds.GetGeoTransform(),
        ds.GetProjection(),
    )


def _tile_base_name(tif_path):
    """Return the prefix used for tile file names (file name up to the first '.')."""
    return os.path.split(tif_path)[1].split(".")[0]


def _tile_path(tempPath, realName, i, j):
    """Return the path of tile (row *i*, column *j*) inside *tempPath*.

    Row and column are separated by underscores so that, e.g., (1, 11) and
    (11, 1) get distinct names; plain concatenation of str(i) and str(j)
    maps both to the same file.
    """
    return os.path.join(tempPath, "{}_{}_{}.tif".format(realName, i, j))


def _tile_windows(width, height, size):
    """Yield ``(i, j, startX, startY, partWidth, partHeight)`` for every tile.

    Tiles are *size* x *size* pixels in row-major order; tiles on the right
    and bottom edges are truncated to fit inside the image.
    """
    for i, startY in enumerate(range(0, height, size)):
        for j, startX in enumerate(range(0, width, size)):
            yield (i, j, startX, startY,
                   min(size, width - startX), min(size, height - startY))


def _tile_geotransform(geoTrans, startX, startY):
    """Shift a GDAL geotransform so that its origin is pixel (startX, startY)."""
    gt = tuple(geoTrans)
    return (gt[0] + startX * gt[1] + startY * gt[2], gt[1], gt[2],
            gt[3] + startX * gt[4] + startY * gt[5], gt[4], gt[5])


def _global_min_max(dataset, width, height, size):
    """Return the minimum and maximum pixel value over all bands.

    The image is read tile by tile so the full raster never has to be held
    in memory at once. Returns (0, 0) for an empty raster.
    """
    lo = hi = None
    for _, _, startX, startY, w, h in _tile_windows(width, height, size):
        block = dataset.ReadAsArray(startX, startY, w, h)
        block_min, block_max = np.min(block), np.max(block)
        lo = block_min if lo is None else min(lo, block_min)
        hi = block_max if hi is None else max(hi, block_max)
    if lo is None:
        return 0, 0
    return lo, hi


def _to_8bit(data, min_value, max_value):
    """Map [min_value, max_value] linearly onto 0..255 and round to uint8."""
    value_range = float(max_value) - float(min_value)
    if value_range == 0:
        # Constant image: the linear mapping would divide by zero.
        return np.zeros(np.shape(data), dtype=np.uint8)
    scaled = 255 * ((np.asarray(data, dtype=np.float64) - float(min_value)) / value_range)
    return np.array(np.rint(scaled), dtype=np.uint8)


def partDivisionForBoundary(tif1, divisionSize, tempPath):
    """Split raster *tif1* into 8-bit GeoTIFF tiles written to *tempPath*.

    All bands share one linear mapping: the minimum and maximum pixel value
    over the whole image (all bands together) become 0 and 255. Each tile is
    *divisionSize* x *divisionSize* pixels (edge tiles are smaller), has
    every band of the source, and carries a geotransform shifted to its own
    position. Tile files that already exist are skipped, not overwritten.
    *tempPath* is created if it does not exist.

    :return: 1 (kept for backward compatibility)
    :raises RuntimeError: if GDAL cannot open the source or create a tile
    """
    width, height, bands, geoTrans, proj = getTifSize(tif1)
    realName = _tile_base_name(tif1)

    src = gdal.Open(tif1)
    if src is None:
        raise RuntimeError("GDAL could not open {!r}".format(tif1))
    os.makedirs(tempPath, exist_ok=True)

    min_16bit, max_16bit = _global_min_max(src, width, height, divisionSize)

    driver = gdal.GetDriverByName("GTiff")
    for i, j, startX, startY, partW, partH in _tile_windows(width, height, divisionSize):
        outPath = _tile_path(tempPath, realName, i, j)
        if os.path.exists(outPath):
            continue

        data1 = src.ReadAsArray(startX, startY, partW, partH)
        image_8bit = _to_8bit(data1, min_16bit, max_16bit)
        if image_8bit.ndim == 2:
            # A single-band read comes back without a band axis; add one so
            # the per-band write loop below handles every band count.
            image_8bit = image_8bit[np.newaxis, ...]

        outTif = driver.Create(outPath, partW, partH, bands, gdal.GDT_Byte)
        if outTif is None:
            raise RuntimeError("GDAL could not create {!r}".format(outPath))
        outTif.SetGeoTransform(_tile_geotransform(geoTrans, startX, startY))
        outTif.SetProjection(proj)
        for k in range(bands):
            outTif.GetRasterBand(k + 1).WriteArray(image_8bit[k])
        outTif.FlushCache()
        outTif = None  # close this tile before starting the next one

    src = None
    return 1


def partStretch(tif1, divisionSize, outStratchPath, tempPath):
    """Mosaic the tiles produced by partDivisionForBoundary into one GeoTIFF.

    The output at *outStratchPath* has the width, height, band count,
    geotransform and projection of *tif1* and Byte pixels. *divisionSize*
    and *tempPath* must match the values used for partDivisionForBoundary.

    :raises RuntimeError: if the output cannot be created or a tile is
        missing/unreadable
    """
    width, height, bands, geoTrans, proj = getTifSize(tif1)
    realName = _tile_base_name(tif1)

    driver = gdal.GetDriverByName("GTiff")
    outTif = driver.Create(outStratchPath, width, height, bands, gdal.GDT_Byte)
    if outTif is None:
        raise RuntimeError("GDAL could not create {!r}".format(outStratchPath))
    outTif.SetGeoTransform(geoTrans)
    outTif.SetProjection(proj)

    for i, j, startX, startY, partW, partH in _tile_windows(width, height, divisionSize):
        partTifPath = _tile_path(tempPath, realName, i, j)
        divisionImg = gdal.Open(partTifPath)
        if divisionImg is None:
            raise RuntimeError("missing or unreadable tile {!r}".format(partTifPath))
        for k in range(bands):
            data1 = divisionImg.GetRasterBand(k + 1).ReadAsArray(0, 0, partW, partH)
            # Write the tile at its offset inside the full-size output band.
            outTif.GetRasterBand(k + 1).WriteArray(data1, startX, startY)
        divisionImg = None

    outTif.FlushCache()
    outTif = None

if __name__ == "__main__":
    # 16-bit source raster, 8-bit mosaic to produce, and the scratch
    # directory that holds the intermediate 8-bit tiles.
    ylbit_path, bbit_path = 'D:/16bit8bit/big_test.tiff', 'D:/16bit8bit/8bit2.tif'
    temp = 'D:/16bit8bit/temp/'
    tile_size = 2000

    # Step 1: write 8-bit tiles; step 2: stitch them back together.
    partDivisionForBoundary(ylbit_path, tile_size, temp)
    partStretch(ylbit_path, tile_size, bbit_path, temp)

参考链接:
https://blog.csdn.net/u014311125/article/details/93746867?utm_medium=distribute.pc_aggpage_search_result.none-task-blog-2allsobaiduend~default-1-93746867.nonecase&utm_term=16bit%E8%BD%AC8bit%20python&spm=1000.2123.3001.4430
这个链接里的代码得到的结果也是正常的,效果感觉还要更好一点,但运行速度比较慢。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author  : 
@Email   : 
@Time    : 14:36
@Site    :
@File    : CompressImage.py
@Software: PyCharm
"""

"""
将16位遥感图像压缩至8位,并保持色彩一致
"""

import  gdal
import os
import glob
import numpy as np

def read_tiff(input_file):
    """
    Read every band of a raster into memory.

    :param input_file: path of the input image
    :return: (band data as a float64 array of shape (bands, rows, cols),
              affine geotransform, projection, number of rows,
              number of columns, number of bands)
    :raises RuntimeError: if GDAL cannot open ``input_file``
    """

    dataset = gdal.Open(input_file)
    if dataset is None:
        raise RuntimeError("GDAL could not open {!r}".format(input_file))
    rows = dataset.RasterYSize
    cols = dataset.RasterXSize

    geo = dataset.GetGeoTransform()
    proj = dataset.GetProjection()

    couts = dataset.RasterCount

    array_data = np.zeros((couts, rows, cols))

    for i in range(couts):
        band = dataset.GetRasterBand(i + 1)
        array_data[i, :, :] = band.ReadAsArray()

    dataset = None  # close the file

    # Report the real band count (not a fixed 3) so callers process every
    # band and do not index past the end of 1- or 2-band rasters.
    return array_data, geo, proj, rows, cols, couts

def compress(origin_16, output_8):
    """Convert a 16-bit raster to an 8-bit GeoTIFF using a 2%-98% stretch.

    For every band, cumulativehistogram() returns the grey levels at the 2%
    and 98% points of the band's cumulative histogram. Values are clamped to
    that window and scaled linearly onto 0..255, then written with
    write_tiff().

    :param origin_16: path of the input raster
    :param output_8: path of the 8-bit GeoTIFF to create
    """
    array_data, geo, proj, rows, cols, couts = read_tiff(origin_16)

    compress_data = np.zeros((couts, rows, cols))

    for i in range(couts):
        band = array_data[i, :, :]
        band_max = np.max(band)
        band_min = np.min(band)

        cutmin, cutmax = cumulativehistogram(band, rows, cols, band_min, band_max)

        compress_scale = (cutmax - cutmin) / 255

        # Vectorised form of the former per-pixel loop: clamp below to
        # cutmin, then above to cutmax (same order as before).
        clipped = np.minimum(np.maximum(band, cutmin), cutmax)
        if compress_scale > 0:
            compress_data[i, :, :] = (clipped - cutmin) / compress_scale
        # Otherwise the stretch window is empty (e.g. a constant band);
        # leave the band at 0 instead of dividing by zero.

    write_tiff(output_8, compress_data, rows, cols, couts, geo, proj)

def write_tiff(output_file, array_data, rows, cols, counts, geo, proj):
    """Write a (counts, rows, cols) array to a Byte GeoTIFF.

    :param output_file: path of the GeoTIFF to create
    :param array_data: array indexed as array_data[band, row, col]
    :param rows: raster height in pixels
    :param cols: raster width in pixels
    :param counts: number of bands to write
    :param geo: 6-element GDAL geotransform
    :param proj: projection string passed to ``SetProjection``
    :raises RuntimeError: if GDAL cannot create ``output_file``
    """
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(output_file, cols, rows, counts, gdal.GDT_Byte)
    if dataset is None:
        raise RuntimeError("GDAL could not create {!r}".format(output_file))

    dataset.SetGeoTransform(geo)
    dataset.SetProjection(proj)

    for i in range(counts):
        band = dataset.GetRasterBand(i + 1)
        band.WriteArray(array_data[i, :, :])

    # Flush and release explicitly so the file is complete when we return.
    dataset.FlushCache()
    dataset = None


def cumulativehistogram(array_data, rows, cols, band_min, band_max):
    """
    Find the 2% and 98% cut points of a band's cumulative histogram.

    Pixels of ``array_data[:rows, :cols]`` are binned into unit-width bins
    starting at ``band_min`` (bin index = int(value - band_min)). A cut point
    is the grey level ``bin_index + band_min`` of the last bin whose
    cumulative count reaches the threshold while the count before that bin
    is still at or below it.

    :param array_data: 2-D band values
    :param rows: number of rows to use
    :param cols: number of columns to use
    :param band_min: minimum value of the band
    :param band_max: maximum value of the band
    :return: (cutmin, cutmax) for the 2% and 98% thresholds
    """
    gray_level = int(band_max - band_min + 1)

    window = np.asarray(array_data)[:rows, :cols]
    # Offsets are >= 0, so truncating with astype matches int() truncation.
    offsets = (window - band_min).astype(np.int64).ravel()
    gray_array = np.bincount(offsets, minlength=max(gray_level, 0))
    counts = offsets.size

    cumulative = np.cumsum(gray_array)
    # Cumulative count *before* each bin; 0 before the first bin, so a
    # threshold reached inside bin 0 now yields band_min instead of 0.
    previous = np.concatenate(([0], cumulative[:-1]))

    def _cut(threshold):
        hits = np.flatnonzero((cumulative >= threshold) & (previous <= threshold))
        if hits.size == 0:
            return band_min
        return int(hits[-1]) + band_min

    cutmin = _cut(counts * 0.02)
    cutmax = _cut(counts * 0.98)
    return cutmin, cutmax

if __name__ == '__main__':
    # Stretch the 16-bit test image to 8 bits using the 2%/98% histogram cut.
    compress("D:/16bit8bit/big_test.tiff", "D:/16bit8bit/8bit3.tif")

原图
在这里插入图片描述

方法1:
在这里插入图片描述
方法2:参考链接中的代码
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

如雾如电

随缘

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值