import os
import numpy as np
import gdal
def read_img(filename):
    """Read a whole raster into memory together with its georeferencing.

    :param filename: path of a raster that GDAL can open
    :return: tuple ``(projection, geotransform, width, height, data)`` where
             ``data`` is the array returned by ``Dataset.ReadAsArray`` for the
             full extent
    """
    src = gdal.Open(filename)
    cols, rows = src.RasterXSize, src.RasterYSize
    transform = src.GetGeoTransform()
    projection = src.GetProjection()
    pixels = src.ReadAsArray(0, 0, cols, rows)
    # Drop our reference to the dataset handle before returning.
    del src
    return projection, transform, cols, rows, pixels
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a numpy array to a GeoTIFF with the given georeferencing.

    The GDAL data type is chosen from the array dtype: 8-bit -> Byte,
    int16 -> Int16, uint16 -> UInt16, anything else -> Float32.

    :param filename: output file path
    :param im_proj: projection string (e.g. from ``Dataset.GetProjection()``)
    :param im_geotrans: 6-element geotransform (e.g. from ``GetGeoTransform()``)
    :param im_data: array shaped (rows, cols) or (bands, rows, cols)
    :raises RuntimeError: if GDAL cannot create the output file
    """
    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        # A signed 8-bit GDAL type (GDT_Int8) only exists from GDAL 3.7 on,
        # so both 8-bit kinds are still written as Byte.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        # Previously mapped to UInt16, which wrapped negative values.
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
        planes = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        planes = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + filename)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for band_index, plane in enumerate(planes, start=1):
        dataset.GetRasterBand(band_index).WriteArray(plane)
    # Make sure the pixels reach disk before the handle is released.
    dataset.FlushCache()
    del dataset
def write_img_(filename, im_proj, im_geotrans, im_data):
    """Write a numpy array to a GeoTIFF, always as 8-bit unsigned (GDT_Byte).

    The array dtype is ignored: values are handed to GDAL, which converts
    them to Byte on write (out-of-range values are presumably clamped by
    GDAL, so scale the data to 0..255 first).

    :param filename: output file path
    :param im_proj: projection string (e.g. from ``Dataset.GetProjection()``)
    :param im_geotrans: 6-element geotransform (e.g. from ``GetGeoTransform()``)
    :param im_data: array shaped (rows, cols) or (bands, rows, cols)
    :raises RuntimeError: if GDAL cannot create the output file
    """
    # The original also derived a dtype-based GDAL type here but never used
    # it; that dead code has been removed.
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
        planes = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        planes = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, gdal.GDT_Byte)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + filename)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for band_index, plane in enumerate(planes, start=1):
        dataset.GetRasterBand(band_index).WriteArray(plane)
    dataset.FlushCache()
    del dataset
def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch an array between two percentiles onto [img_min, img_max].

    Values at or below the ``lower_percent`` percentile map to ``img_min``,
    values at or above the ``higher_percent`` percentile map to ``img_max``,
    and values in between are scaled linearly.

    :param bands: numeric array of any shape; percentiles are computed over
                  all of its elements
    :param img_min: lower bound of the output range
    :param img_max: upper bound of the output range
    :param lower_percent: lower cut percentile (0 = minimum)
    :param higher_percent: upper cut percentile (100 = maximum)
    :return: float32 array with the same shape as ``bands``
    """
    data = np.asarray(bands, dtype=np.float64)
    a, b = img_min, img_max
    c = np.percentile(data, lower_percent)
    d = np.percentile(data, higher_percent)
    if d == c:
        # Degenerate range (e.g. a constant band). The original divided by
        # zero and returned NaN; send values above the cut to b, the rest to a.
        stretched = np.where(data > c, b, a)
    else:
        stretched = a + (data - c) * (b - a) / (d - c)
        stretched = np.clip(stretched, a, b)
    return np.asarray(stretched).astype(np.float32)
def getTifSize(tif):
    """Return the layout of a raster without reading its pixels.

    :param tif: path of a raster that GDAL can open
    :return: tuple ``(width, height, band_count, geotransform, projection)``
    """
    ds = gdal.Open(tif)
    info = (
        ds.RasterXSize,
        ds.RasterYSize,
        ds.RasterCount,
        ds.GetGeoTransform(),
        ds.GetProjection(),
    )
    return info
def _tile_windows(width, height, tile_size):
    """Yield ``(row, col, x_off, y_off, w, h)`` for every tile, row by row.

    Tiles are ``tile_size`` pixels square except along the right and bottom
    edges, where they are clipped to the raster extent.
    """
    if tile_size <= 0:
        raise ValueError("divisionSize must be a positive number of pixels")
    n_rows = -(-height // tile_size)  # ceiling division
    n_cols = -(-width // tile_size)
    for row in range(n_rows):
        y_off = row * tile_size
        h = min(tile_size, height - y_off)
        for col in range(n_cols):
            x_off = col * tile_size
            w = min(tile_size, width - x_off)
            yield row, col, x_off, y_off, w, h


def _tile_base_name(path):
    """File name of ``path`` without its directory and extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _tile_name(base_name, row, col):
    """Build an unambiguous tile file name.

    The old scheme ``base + str(i) + str(j)`` gave (1, 11) and (11, 1) the
    same name, so one tile was skipped and the mosaic reused the wrong one.
    """
    return "{}_{}_{}.tif".format(base_name, row, col)


def _tile_geotransform(geo_trans, x_off, y_off):
    """Return the geotransform of a tile whose top-left pixel is (x_off, y_off)."""
    x0, px_w, rot_x, y0, rot_y, px_h = geo_trans
    return (x0 + x_off * px_w + y_off * rot_x, px_w, rot_x,
            y0 + x_off * rot_y + y_off * px_h, rot_y, px_h)


def _global_min_max(dataset, windows):
    """Return the minimum and maximum over all bands, reading one tile at a time."""
    lo = hi = None
    for _, _, x_off, y_off, w, h in windows:
        block = dataset.ReadAsArray(x_off, y_off, w, h)
        block_lo, block_hi = block.min(), block.max()
        lo = block_lo if lo is None else min(lo, block_lo)
        hi = block_hi if hi is None else max(hi, block_hi)
    return lo, hi


def _to_uint8(data, lo, hi):
    """Map ``[lo, hi]`` linearly onto 0..255, rounding to the nearest integer."""
    span = float(hi) - float(lo)
    if span == 0:
        # Constant raster: the original divided by zero here.
        return np.zeros(np.shape(data), dtype=np.uint8)
    scaled = 255 * ((np.asarray(data, dtype=np.float64) - float(lo)) / span)
    return np.rint(scaled).astype(np.uint8)


def partDivisionForBoundary(tif1, divisionSize, tempPath):
    """Cut a raster into square 8-bit tiles and write them to ``tempPath``.

    All tiles use one linear mapping: the minimum over every band and pixel
    of the whole raster maps to 0 and the maximum to 255. Each tile keeps the
    source projection and gets a geotransform shifted to its own origin.
    Tiles that already exist in ``tempPath`` are reused, not rewritten.

    :param tif1: path of the source raster
    :param divisionSize: tile edge length in pixels
    :param tempPath: directory for the tiles (created if missing)
    :return: 1 (kept for backward compatibility)
    """
    width, height, bands, geoTrans, proj = getTifSize(tif1)
    base_name = _tile_base_name(tif1)
    os.makedirs(tempPath, exist_ok=True)

    src = gdal.Open(tif1)
    windows = list(_tile_windows(width, height, divisionSize))
    # Scan tile by tile instead of loading the whole (possibly huge) raster.
    lo, hi = _global_min_max(src, windows)

    driver = gdal.GetDriverByName("GTiff")
    for row, col, x_off, y_off, w, h in windows:
        out_path = os.path.join(tempPath, _tile_name(base_name, row, col))
        if os.path.exists(out_path):
            continue
        out_tif = driver.Create(out_path, w, h, bands, gdal.GDT_Byte)
        if out_tif is None:
            raise RuntimeError("GDAL could not create tile " + out_path)
        out_tif.SetGeoTransform(_tile_geotransform(geoTrans, x_off, y_off))
        out_tif.SetProjection(proj)
        tile8 = _to_uint8(src.ReadAsArray(x_off, y_off, w, h), lo, hi)
        if tile8.ndim == 2:
            # Single-band reads come back as (rows, cols); the original then
            # wrote only the first row of the tile.
            tile8 = tile8[np.newaxis]
        for k in range(bands):
            out_tif.GetRasterBand(k + 1).WriteArray(tile8[k])
        out_tif.FlushCache()
        out_tif = None
    return 1


def partStretch(tif1, divisionSize, outStratchPath, tempPath):
    """Mosaic the 8-bit tiles written by ``partDivisionForBoundary`` into one GeoTIFF.

    :param tif1: path of the original source raster (gives size, bands,
                 georeferencing and the tile base name)
    :param divisionSize: tile edge length in pixels; must match the value
                         used when the tiles were written
    :param outStratchPath: path of the 8-bit output GeoTIFF
    :param tempPath: directory holding the tiles
    :raises RuntimeError: if the output cannot be created or a tile is missing
    """
    width, height, bands, geoTrans, proj = getTifSize(tif1)
    base_name = _tile_base_name(tif1)
    driver = gdal.GetDriverByName("GTiff")
    out_tif = driver.Create(outStratchPath, width, height, bands, gdal.GDT_Byte)
    if out_tif is None:
        # The original silently returned without writing anything.
        raise RuntimeError("GDAL could not create " + outStratchPath)
    out_tif.SetGeoTransform(geoTrans)
    out_tif.SetProjection(proj)
    for row, col, x_off, y_off, w, h in _tile_windows(width, height, divisionSize):
        tile_path = os.path.join(tempPath, _tile_name(base_name, row, col))
        tile = gdal.Open(tile_path)
        if tile is None:
            raise RuntimeError("missing tile " + tile_path)
        for k in range(bands):
            data = tile.GetRasterBand(k + 1).ReadAsArray(0, 0, w, h)
            out_tif.GetRasterBand(k + 1).WriteArray(data, x_off, y_off)
        tile = None
    out_tif.FlushCache()
    out_tif = None
if __name__ == "__main__":
    # Two passes: cut + rescale the 16-bit image into 8-bit tiles, then
    # mosaic those tiles back into a single 8-bit GeoTIFF.
    tile_size = 2000
    src_16bit = 'D:/16bit8bit/big_test.tiff'
    dst_8bit = 'D:/16bit8bit/8bit2.tif'
    tile_dir = 'D:/16bit8bit/temp/'
    partDivisionForBoundary(src_16bit, tile_size, tile_dir)
    partStretch(src_16bit, tile_size, dst_8bit, tile_dir)
参考链接:
https://blog.csdn.net/u014311125/article/details/93746867
这个链接里的代码得到的结果也是正常的,感觉效果还更好一点,但是运行有点慢。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author :
@Email :
@Time : 14:36
@Site :
@File : CompressImage.py
@Software: PyCharm
"""
"""
将16位遥感图像压缩至8位,并保持色彩一致
"""
import gdal
import os
import glob
import numpy as np
def read_tiff(input_file, max_bands=3):
    """
    Read every band of a raster into a float64 array.

    :param input_file: path of the input raster
    :param max_bands: upper limit on the band count reported to the caller.
        The default of 3 keeps the previous behaviour of processing only the
        first three bands of multi-band images. Pass None to report all bands.
    :return: (band data shaped (bands, rows, cols), geotransform, projection,
              rows, cols, number of bands to process)
    """
    dataset = gdal.Open(input_file)
    rows = dataset.RasterYSize
    cols = dataset.RasterXSize
    geo = dataset.GetGeoTransform()
    proj = dataset.GetProjection()
    couts = dataset.RasterCount
    array_data = np.zeros((couts, rows, cols))
    for i in range(couts):
        band = dataset.GetRasterBand(i + 1)
        array_data[i, :, :] = band.ReadAsArray()
    del dataset
    # The original returned a hard-coded 3. For 1- or 2-band rasters that
    # made callers index bands that do not exist.
    if max_bands is None:
        n_bands = couts
    else:
        n_bands = min(couts, max_bands)
    return array_data, geo, proj, rows, cols, n_bands
def compress(origin_16, output_8):
    """
    Convert a 16-bit raster to 8 bit, stretching each band independently.

    For each band, the cut levels come from ``cumulativehistogram``. Pixels
    are clipped to [cutmin, cutmax] and mapped linearly onto 0..255.

    :param origin_16: path of the 16-bit input raster
    :param output_8: path of the 8-bit GeoTIFF to write
    """
    array_data, geo, proj, rows, cols, couts = read_tiff(origin_16)
    compress_data = np.zeros((couts, rows, cols))
    for i in range(couts):
        band = array_data[i, :, :]
        band_max = np.max(band)
        band_min = np.min(band)
        cutmin, cutmax = cumulativehistogram(band, rows, cols, band_min, band_max)
        # Vectorised replacement for the original per-pixel Python loop.
        clipped = np.clip(band, cutmin, cutmax)
        if cutmax > cutmin:
            compress_scale = (cutmax - cutmin) / 255
            compress_data[i] = (clipped - cutmin) / compress_scale
        # Otherwise the band is (nearly) constant; leave it at 0 rather than
        # dividing by a zero or negative scale.
    write_tiff(output_8, compress_data, rows, cols, couts, geo, proj)
def write_tiff(output_file, array_data, rows, cols, counts, geo, proj):
    """
    Write the first ``counts`` bands of ``array_data`` to an 8-bit GeoTIFF.

    :param output_file: output path
    :param array_data: array shaped (>= counts, rows, cols); values are handed
        to GDAL, which converts them to Byte on write
    :param rows, cols: raster height and width in pixels
    :param counts: number of bands to write
    :param geo: 6-element geotransform
    :param proj: projection string
    :raises RuntimeError: if GDAL cannot create the output file
    """
    Driver = gdal.GetDriverByName("Gtiff")
    dataset = Driver.Create(output_file, cols, rows, counts, gdal.GDT_Byte)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + output_file)
    dataset.SetGeoTransform(geo)
    dataset.SetProjection(proj)
    for i in range(counts):
        band = dataset.GetRasterBand(i + 1)
        band.WriteArray(array_data[i, :, :])
    # Flush explicitly and release the handle so the file is complete on return.
    dataset.FlushCache()
    dataset = None
def cumulativehistogram(array_data, rows, cols, band_min, band_max,
                        lower_percent=2, upper_percent=98):
    """
    Find the gray levels where the cumulative histogram first reaches the
    lower and upper cut percentages.

    :param array_data: 2-D band; only ``array_data[:rows, :cols]`` is counted
    :param rows, cols: extent of the band to count
    :param band_min, band_max: minimum and maximum of the band (histogram range)
    :param lower_percent: lower cut percentage (default 2)
    :param upper_percent: upper cut percentage (default 98)
    :return: (cutmin, cutmax), both offsets from ``band_min`` within the histogram
    """
    # One bin per integer gray level between band_min and band_max.
    offsets = (np.asarray(array_data)[:rows, :cols] - band_min).astype(np.int64).ravel()
    gray_level = int(band_max - band_min + 1)
    gray_array = np.bincount(offsets, minlength=gray_level)
    cumulative = np.cumsum(gray_array)
    counts = offsets.size
    thresholds = [counts * (lower_percent / 100.0),
                  counts * (upper_percent / 100.0)]
    # First level whose cumulative count reaches each threshold. Bin 0 is
    # included: the original started at bin 1, so cutmin stayed 0 (not
    # band_min) whenever the lowest level already held > 2% of the pixels.
    idx = np.searchsorted(cumulative, thresholds, side='left')
    idx = np.minimum(idx, len(cumulative) - 1)
    cutmin = int(idx[0]) + band_min
    cutmax = int(idx[1]) + band_min
    return cutmin, cutmax
if __name__ == '__main__':
    # Example run on a local test image; adjust the paths before use.
    compress(origin_16="D:/16bit8bit/big_test.tiff",
             output_8="D:/16bit8bit/8bit3.tif")
原图:
方法1:
方法2(参考链接里的方法):