1、遥感解译建立Mask
通过遥感解译,从完全覆盖遥感影像的面要素中切出目标图斑,将图斑的属性赋值为 255,其余区域赋值为 0;同时一定要把 NoData 值重分类为 0(后续筛选步骤是根据 Mask 中是否存在非零像元来判断切片是否含有目标的)。
然后进行面转栅格,生成一张黑白 Mask 用于标记目标区域。注意 Mask 与原始影像的坐标系必须完全一致,影像范围和行列数也必须完全相同,因为后续是按像素位置同步裁剪影像和标签的。
2、使用python滑动裁剪图像及标签
上述得到的影像数据和标签数据尺寸为几千×几千像素,对模型来说过大。这里使用 Python + GDAL 库把它们自动裁剪成 YOLO 适用的 640×640 切片(也可以用其他尺寸,通常取 32 的倍数,如 512、1024),并按数字顺序命名存储(1、2、3……;若输出目录非空,则从"已有文件数 + 1"开始编号)。影像数据和标签数据需要分别用相同的参数各裁剪一次。代码如下:
这一步主要的坑在 GDAL 库的导入方式:不同版本下,有的可以直接 import gdal,有的会提示找不到该模块,而用 pip 安装时却提示已经安装。此时需要改为 from osgeo import gdal(下面的代码用 try/except 兼容了这两种写法)。
import os
import sys
try:
import gdal
except:
from osgeo import gdal
import numpy as np
from PIL import Image
def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write a NumPy array to a GeoTIFF file.

    Parameters
    ----------
    im_data : numpy.ndarray
        A 2-D array (rows, cols), written as a single band, or a 3-D array
        (bands, rows, cols).
    im_geotrans : sequence
        GDAL geotransform to store in the output (as from GetGeoTransform()).
    im_proj : str
        Projection string to store in the output (as from GetProjection()).
    path : str
        Output file path.

    Raises
    ------
    ValueError
        If im_data is neither 2-D nor 3-D.
    RuntimeError
        If GDAL cannot create the output file.
    """
    # Pick the GDAL type that matches the NumPy dtype. Signed 16/32-bit and
    # unsigned 32-bit data get their own types so negative values and large
    # integers survive. int8 stays Byte as before, so negative int8 values wrap.
    # Anything else (floats, bool, 64-bit ints, ...) is still written as Float32.
    dtype_to_gdal = {
        "uint8": gdal.GDT_Byte,
        "int8": gdal.GDT_Byte,
        "uint16": gdal.GDT_UInt16,
        "int16": gdal.GDT_Int16,
        "uint32": gdal.GDT_UInt32,
        "int32": gdal.GDT_Int32,
    }
    datatype = dtype_to_gdal.get(im_data.dtype.name, gdal.GDT_Float32)

    # Normalise to (bands, rows, cols) so single-band data uses the same loop.
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, ...]
    elif im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got shape %r" % (im_data.shape,)
        )
    im_bands, im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        # Fail loudly rather than silently producing no file.
        raise RuntimeError("GDAL could not create " + path)
    dataset.SetGeoTransform(im_geotrans)  # affine geotransform
    dataset.SetProjection(im_proj)  # projection
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    # Releasing the dataset closes the file and flushes pending writes.
    del dataset
def CoordTransf(Xpixel, Ypixel, GeoTransform):
    """Map a pixel (column, row) position to georeferenced coordinates.

    Applies the six-term GDAL affine geotransform ``gt``::

        X = gt[0] + gt[1] * column + gt[2] * row
        Y = gt[3] + gt[4] * column + gt[5] * row

    Returns the ``(X, Y)`` pair.
    """
    gt = GeoTransform
    x_geo = gt[0] + gt[1] * Xpixel + gt[2] * Ypixel
    y_geo = gt[3] + gt[4] * Xpixel + gt[5] * Ypixel
    return x_geo, y_geo
def TifCrop(TifPath, SavePath, CropSize, RepetitionRate):
    """Cut a GeoTIFF into CropSize x CropSize tiles with a sliding window.

    Tiles are written to ``SavePath`` as ``<n>.tif``. Numbering starts at
    ``len(os.listdir(SavePath)) + 1``. Each tile keeps the source projection
    and gets a geotransform shifted to its own upper-left pixel.

    The regular grid uses a step of ``CropSize * (1 - RepetitionRate)``. If
    the grid does not reach the right or bottom border, one extra column or
    row of tiles is cut flush against that border. The bottom-right corner
    tile is added only when both are needed. Border tiles that would exactly
    duplicate a grid tile are not written.

    The tiling depends only on the raster size and the two parameters, so an
    image and its same-sized mask get matching tile numbers.

    Parameters
    ----------
    TifPath : str
        Path of the source raster.
    SavePath : str
        Output directory (created if missing).
    CropSize : int
        Tile edge length in pixels.
    RepetitionRate : float
        Overlap ratio between neighbouring tiles, in [0, 1).

    Raises
    ------
    ValueError
        If CropSize is not positive, RepetitionRate is outside [0, 1), or the
        image is smaller than CropSize in either dimension.
    RuntimeError
        If GDAL cannot open TifPath.
    """
    print("--------------------裁剪影像-----------------------")
    CropSize = int(CropSize)
    RepetitionRate = float(RepetitionRate)
    if CropSize <= 0:
        raise ValueError("CropSize must be a positive integer")
    if not 0 <= RepetitionRate < 1:
        # A rate of 1 would make the step zero (division by zero below).
        raise ValueError("RepetitionRate must be in [0, 1)")

    dataset_img = gdal.Open(TifPath)
    if dataset_img is None:
        raise RuntimeError(TifPath + "文件无法打开")
    if not os.path.exists(SavePath):
        os.makedirs(SavePath)

    width = dataset_img.RasterXSize  # number of columns
    height = dataset_img.RasterYSize  # number of rows
    bands = dataset_img.RasterCount
    print("行数为:", height)
    print("列数为:", width)
    print("波段数为:", bands)
    if width < CropSize or height < CropSize:
        raise ValueError(
            "%s is %dx%d pixels, smaller than CropSize %d"
            % (TifPath, width, height, CropSize)
        )
    proj = dataset_img.GetProjection()
    geotrans = dataset_img.GetGeoTransform()
    img = dataset_img.ReadAsArray(0, 0, width, height)
    dataset_img = None  # pixels are in memory; release the source file

    # Number of regular tiles per axis. Keep at least one: a tile at offset 0
    # always fits once the size check above has passed.
    step = CropSize * (1 - RepetitionRate)
    RowNum = max(1, int((height - CropSize * RepetitionRate) / step))
    ColumnNum = max(1, int((width - CropSize * RepetitionRate) / step))
    print("裁剪后行影像数为:", RowNum)
    print("裁剪后列影像数为:", ColumnNum)

    row_offsets = [int(i * CropSize * (1 - RepetitionRate)) for i in range(RowNum)]
    col_offsets = [int(j * CropSize * (1 - RepetitionRate)) for j in range(ColumnNum)]
    last_row = height - CropSize
    last_col = width - CropSize
    # A border strip is needed only if the regular grid stops short of it.
    need_last_col = col_offsets[-1] < last_col
    need_last_row = row_offsets[-1] < last_row

    # Tile origins (x, y). Order: regular grid (row-major), then the
    # right-border column, then the bottom-border row, then the corner.
    origins = [(x, y) for y in row_offsets for x in col_offsets]
    if need_last_col:
        origins.extend((last_col, y) for y in row_offsets)
    if need_last_row:
        origins.extend((x, last_row) for x in col_offsets)
    if need_last_col and need_last_row:
        origins.append((last_col, last_row))

    # Continue numbering after any files already in the output directory.
    new_name = len(os.listdir(SavePath)) + 1
    for x, y in origins:
        # "..." keeps the leading band axis for multi-band arrays and is a
        # no-op for single-band (2-D) arrays.
        cropped = img[..., y:y + CropSize, x:x + CropSize]
        XGeo, YGeo = CoordTransf(x, y, geotrans)
        crop_geotrans = (XGeo, geotrans[1], geotrans[2], YGeo, geotrans[4], geotrans[5])
        writeTiff(cropped, crop_geotrans, proj, SavePath + "/%d.tif" % new_name)
        new_name += 1
if __name__ == '__main__':
    # Cut the image into 640x640 tiles with no overlap (repetition rate 0).
    # The label mask must be cropped with the same CropSize / RepetitionRate
    # in a separate call, using its own input and output paths.
    source_tif = r"F:/test/MJ.tif"
    output_dir = r"F:/test/TrainMJ"
    TifCrop(source_tif, output_dir, 640, 0)
3、剔除裁切后训练集中的大量黑边切片
裁切后的影像如上图所示:由于原始影像本身不是矩形,会产生不少由 NoData 造成的空值黑边切片,以及不含任何目标区域的切片。这里利用 Python 和 GDAL 把它们剔除,只保留含有标签的影像切片及其对应的标签 Mask。
运行时需要注意:路径分隔符一定要用正斜杠 /,并且前两个输入路径(原始影像目录和标签目录)的末尾必须带一个 /,因为代码是直接把文件名拼接在这两个路径后面的。
在本例中,此步骤从 825 对图像和标签中筛选出了 128 对含有目标地物的样本。
import os, sys
try:
import gdal
except:
from osgeo import gdal
from numba import jit
import numpy as np
def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write a NumPy array to a GeoTIFF file.

    Parameters
    ----------
    im_data : numpy.ndarray
        A 2-D array (rows, cols), written as a single band, or a 3-D array
        (bands, rows, cols).
    im_geotrans : sequence
        GDAL geotransform to store in the output (as from GetGeoTransform()).
    im_proj : str
        Projection string to store in the output (as from GetProjection()).
    path : str
        Output file path.

    Raises
    ------
    ValueError
        If im_data is neither 2-D nor 3-D.
    RuntimeError
        If GDAL cannot create the output file.
    """
    # Pick the GDAL type that matches the NumPy dtype. Signed 16/32-bit and
    # unsigned 32-bit data get their own types so negative values and large
    # integers survive. int8 stays Byte as before, so negative int8 values wrap.
    # Anything else (floats, bool, 64-bit ints, ...) is still written as Float32.
    dtype_to_gdal = {
        "uint8": gdal.GDT_Byte,
        "int8": gdal.GDT_Byte,
        "uint16": gdal.GDT_UInt16,
        "int16": gdal.GDT_Int16,
        "uint32": gdal.GDT_UInt32,
        "int32": gdal.GDT_Int32,
    }
    datatype = dtype_to_gdal.get(im_data.dtype.name, gdal.GDT_Float32)

    # Normalise to (bands, rows, cols) so single-band data uses the same loop.
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, ...]
    elif im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got shape %r" % (im_data.shape,)
        )
    im_bands, im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        # Fail loudly rather than silently producing no file.
        raise RuntimeError("GDAL could not create " + path)
    dataset.SetGeoTransform(im_geotrans)  # affine geotransform
    dataset.SetProjection(im_proj)  # projection
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    # Releasing the dataset closes the file and flushes pending writes.
    del dataset
def _list_numbered_tifs(folder):
    """Return the ``<number>.tif`` file names in *folder*, sorted numerically."""
    # Ignore side-car files (e.g. "1.tif.aux.xml") and anything not numbered.
    names = [
        name for name in os.listdir(folder)
        if name.lower().endswith(".tif") and os.path.splitext(name)[0].isdigit()
    ]
    names.sort(key=lambda name: int(os.path.splitext(name)[0]))
    return names


def _read_tif(path):
    """Read a whole raster and return ``(array, geotransform, projection)``."""
    dataset = gdal.Open(path)
    if dataset is None:
        raise RuntimeError(path + "文件无法打开")
    array = dataset.ReadAsArray(0, 0, dataset.RasterXSize, dataset.RasterYSize)
    geotrans = dataset.GetGeoTransform()
    proj = dataset.GetProjection()
    dataset = None  # release the file handle
    return array, geotrans, proj


def TifScreen(OriTifArrayPath, MaskTifArrayPath, EndImg, EndMask):
    """Copy only the image/mask tile pairs whose mask has labelled pixels.

    Images and masks are paired by file name ("5.tif" in both folders).
    A mask is kept when it contains at least one non-zero pixel. The pair is
    then rewritten, under that same file name, into EndImg and EndMask.
    Tiles are processed one pair at a time, so memory use does not grow with
    the number of tiles. Directory paths may be given with or without a
    trailing slash.

    Parameters
    ----------
    OriTifArrayPath : str
        Directory of cropped image tiles.
    MaskTifArrayPath : str
        Directory of cropped mask tiles.
    EndImg : str
        Output directory for the kept image tiles (created if missing).
    EndMask : str
        Output directory for the kept mask tiles (created if missing).

    Raises
    ------
    RuntimeError
        If a tile cannot be opened by GDAL.
    """
    if not os.path.exists(EndImg):
        os.makedirs(EndImg)
    if not os.path.exists(EndMask):
        os.makedirs(EndMask)

    imgList = _list_numbered_tifs(OriTifArrayPath)
    MaskList = _list_numbered_tifs(MaskTifArrayPath)
    print("读取全部影像数量为:", len(imgList))
    print("读取全部掩膜数量为:", len(MaskList))

    available_images = set(imgList)
    for mask_name in MaskList:
        if mask_name not in available_images:
            print("缺少对应影像,跳过:", mask_name)
            continue
        mask, geotrans_mask, proj_mask = _read_tif(
            os.path.join(MaskTifArrayPath, mask_name))
        # Skip tiles with no labelled pixel (background / NoData only).
        if np.count_nonzero(mask) == 0:
            continue
        img, geotrans, proj = _read_tif(os.path.join(OriTifArrayPath, mask_name))
        print("图像保存成功")
        writeTiff(img, geotrans, proj, os.path.join(EndImg, mask_name))
        writeTiff(mask, geotrans_mask, proj_mask, os.path.join(EndMask, mask_name))
# 循环加速,判断影像中是否有255标签
@jit(nopython=True)
def test_fast(img):
count = 0
for row in range(img.shape[0]):
for col in range(img.shape[1]):
if img[row, col] != 0:
count = count + 1
return count
if __name__ == '__main__':
    # Keep only the tile pairs whose mask contains labelled (non-zero) pixels.
    src_img_dir = r"F:/internship/code/ImagePreprocessing/data/test/img/"
    src_mask_dir = r"F:/internship/code/ImagePreprocessing/data/test/mask/"
    out_img_dir = r"F:/internship/code/ImagePreprocessing/data/test/imgend"
    out_mask_dir = r"F:/internship/code/ImagePreprocessing/data/test/maskend"
    TifScreen(src_img_dir, src_mask_dir, out_img_dir, out_mask_dir)