对GeoTIFF多光谱影像进行批量数据增广(旋转、翻转)
在生成深度学习大量训练数据集时,我们通常都是对普通的RGB三通道图片进行数据增广,如旋转、翻转等;但是对于GeoTIFF这种既包含地理坐标信息、又包含多个通道的数据,据我了解目前并没有相关软件可以实现数据增广,网络上也没有类似的代码资源。难道只有我一个人事多,有这样“奇葩”的需求吗?现在把我的成果免费分享给大家,如果有同学有相同需求,且觉得这个资源有用的话,麻烦给点个赞吧,哈哈哈……
废话不多说直接上代码
运行代码只需要在主函数中修改【img_file】参数即可,GDAL库请自行下载安装。扩充后的影像会保存在原图片所在的文件夹中,文件名分别追加 _y、_x、_90、_180、_270 后缀。
import os
from osgeo import gdal
import numpy as np
def Expand_trainset_pic(img_file):
    """
    Augment every raster in a folder with flips and rotations.

    Five variants are written next to each file that GDAL can open. Each
    variant keeps the original's extension:

    * ``<name>_y``   -- vertical flip (row order reversed)
    * ``<name>_x``   -- horizontal flip (column order reversed)
    * ``<name>_90``  -- rotated 90 degrees counter-clockwise
    * ``<name>_180`` -- rotated 180 degrees
    * ``<name>_270`` -- rotated 270 degrees counter-clockwise

    Every band is transformed, whatever the band count. The output
    width/height are taken from the transformed array, so 90/270 degree
    rotations of non-square images are written with swapped dimensions.

    NOTE: the source geotransform and projection are copied unchanged. The
    augmented rasters therefore carry the original's georeferencing even
    though their pixels have been moved. That is acceptable for training
    data, but the outputs are not geographically correct.

    :param img_file: folder containing the images to augment
    """
    # os.listdir is evaluated once, before anything is written, so the
    # outputs created during this run are not themselves augmented.
    for name in os.listdir(img_file):
        img_path = os.path.join(img_file, name)
        if not os.path.isfile(img_path):
            continue
        filename, file_ext = os.path.splitext(name)
        print(img_path)
        try:
            dataset = gdal.Open(img_path)
        except RuntimeError:
            # gdal raises instead of returning None when UseExceptions() is on.
            dataset = None
        if dataset is None:
            # Not a raster GDAL can read (sidecar .xml/.aux files, etc.): skip it.
            print("跳过无法读取的文件:", img_path)
            continue
        im_geotrans = dataset.GetGeoTransform()  # affine transform
        im_proj = dataset.GetProjection()  # projection WKT
        im_data = dataset.ReadAsArray()
        dataset = None  # release the source file handle
        if im_data.ndim == 2:
            # Single-band rasters are read as (rows, cols); add the band axis.
            im_data = im_data[np.newaxis, ...]
        for suffix, out in _augmentations(im_data):
            bands, height, width = out.shape
            save_path = os.path.join(img_file, filename + suffix + file_ext)
            writeTiff(out, width, height, bands, im_geotrans, im_proj, save_path)
        print("完成对图片:", img_path, " 的扩充")
    print("完成对所有图片的扩充!")


def _augmentations(im_data):
    """Yield ``(suffix, array)`` for each flip/rotation of a (bands, rows, cols) array."""
    # All bands are transformed at once on the row/column axes (1, 2).
    # Each result is copied to a C-contiguous array, so GDAL receives ordinary
    # buffers rather than negative-stride views.
    yield "_y", np.ascontiguousarray(np.flip(im_data, axis=1))
    yield "_x", np.ascontiguousarray(np.flip(im_data, axis=2))
    for suffix, k in (("_90", 1), ("_180", 2), ("_270", 3)):
        # np.rot90 rotates from axes[0] towards axes[1], i.e. counter-clockwise.
        yield suffix, np.ascontiguousarray(np.rot90(im_data, k, axes=(1, 2)))
def fanzhuany(b1, b2, b3, b4):
    """
    Flip four bands upside down (reverse the row order) and stack them.

    Assumes each band is a 2-D (rows, cols) array -- TODO confirm at call sites.

    :return: array with the four flipped bands stacked along a new first axis
    """
    upside_down = [np.flip(band, axis=0) for band in (b1, b2, b3, b4)]
    return np.array(upside_down)
def fanzhuanx(b1, b2, b3, b4):
    """
    Mirror four bands left-to-right (reverse the column order) and stack them.

    Assumes each band is a 2-D (rows, cols) array -- TODO confirm at call sites.

    :return: array with the four mirrored bands stacked along a new first axis
    """
    bands = (b1, b2, b3, b4)
    mirrored = [np.flip(layer, axis=1) for layer in bands]
    return np.array(mirrored)
def xuanzhaun90(b1, b2, b3, b4):
    """
    Rotate four bands 90 degrees counter-clockwise and stack them.

    :return: array with the four rotated bands stacked along a new first axis
    """
    quarter_turns = 1
    turned = [np.rot90(plane, quarter_turns) for plane in (b1, b2, b3, b4)]
    return np.array(turned)
def xuanzhaun180(b1, b2, b3, b4):
    """
    Rotate four bands by 180 degrees and stack them.

    :return: array with the four rotated bands stacked along a new first axis
    """
    quarter_turns = 2
    turned = [np.rot90(plane, quarter_turns) for plane in (b1, b2, b3, b4)]
    return np.array(turned)
def xuanzhaun270(b1, b2, b3, b4):
    """
    Rotate four bands 270 degrees counter-clockwise (90 clockwise) and stack them.

    :return: array with the four rotated bands stacked along a new first axis
    """
    quarter_turns = 3
    turned = [np.rot90(plane, quarter_turns) for plane in (b1, b2, b3, b4)]
    return np.array(turned)
def writeTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """
    Write a band array to a GeoTIFF file.

    :param im_data: band array indexed as ``im_data[band]``, i.e. shaped
        (bands, rows, cols); a 2-D (rows, cols) array is treated as one band
    :param im_width: image width (columns)
    :param im_height: image height (rows)
    :param im_bands: number of bands to write
    :param im_geotrans: affine geotransform (6-tuple from GetGeoTransform)
    :param im_proj: projection string (from GetProjection)
    :param path: output file path
    :raises RuntimeError: if GDAL cannot create the output file
    """
    im_data = np.asarray(im_data)
    if im_data.ndim == 2:
        # A single band given as (rows, cols); add the band axis so that
        # im_data[0] is the whole band rather than its first row.
        im_data = im_data[np.newaxis, ...]

    # Map each NumPy dtype to the GDAL type that stores it losslessly. A
    # substring test would send signed int16 to UInt16 and int32/float64 to
    # Float32, corrupting negative values or losing precision.
    dtype_map = {
        'uint8': gdal.GDT_Byte,
        # GDT_Int8 only exists in GDAL >= 3.7; Int16 holds every int8 value.
        'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Int16),
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float32': gdal.GDT_Float32,
        'float64': gdal.GDT_Float64,
    }
    datatype = dtype_map.get(im_data.dtype.name, gdal.GDT_Float32)

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("无法创建文件: " + path)
    try:
        dataset.SetGeoTransform(im_geotrans)  # affine transform
        dataset.SetProjection(im_proj)  # projection
        for i in range(im_bands):
            # Contiguous copy so flipped/rotated views are written safely.
            dataset.GetRasterBand(i + 1).WriteArray(np.ascontiguousarray(im_data[i]))
        dataset.FlushCache()
    finally:
        # Dropping the last reference closes the file and finishes the write.
        dataset = None
# Program entry point.
if __name__ == "__main__":
    import sys

    # Folder holding the images to augment. Edit the default below, or pass
    # the folder as the first command-line argument.
    img_file = sys.argv[1] if len(sys.argv) > 1 else r'F:\train\Image'
    Expand_trainset_pic(img_file)  # run the batch augmentation