基于python将遥感影像裁剪成深度学习数据集并拼接

基于python将遥感影像裁剪成深度学习数据集并拼接

更改二分类mask为0-255

import os, sys
try:
    import gdal
except:
    from osgeo import gdal
import numpy as np
from osgeo import gdal_array
from numba import jit

# 保存tif文件函数
def writeTiff(im_data, im_geotrans, im_proj, path):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    elif len(im_data.shape) == 2:
        im_data = np.array([im_data])
        im_bands, im_height, im_width = im_data.shape
    # 创建文件
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if (dataset != None):
        dataset.SetGeoTransform(im_geotrans)  # 写入仿射变换参数
        dataset.SetProjection(im_proj)  # 写入投影
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset

'''
更改标签图像像素
TifPath 影像路径
'''
def changemaskpixel(TifPath, SavePath):
    dataset_img = gdal.Open(TifPath)
    if dataset_img == None:
        print(TifPath + "文件无法打开")


    width = dataset_img.RasterXSize  # 获取行列数
    height = dataset_img.RasterYSize
    bands = dataset_img.RasterCount  # 获取波段数
    print("行数为:", height)
    print("列数为:", width)
    print("波段数为:", bands)

    proj = dataset_img.GetProjection()  # 获取投影信息
    geotrans = dataset_img.GetGeoTransform()  # 获取仿射矩阵信息
    img = dataset_img.ReadAsArray(0, 0, width, height)  # 获取数据
    # print(img.shape[0], img.shape[1])

    go_fast(img)

    # 生成Tif图像
    writeTiff(img, geotrans, proj, SavePath)
    # gdal_array.SaveArray(img.astype(gdal_array.numpy.uint8),
    #                      SavePath, format="GTIFF", prototype='')

# 循环加速,将标签影像赋值为0-255
@jit(nopython=True)
def go_fast(img):
    for row in range(img.shape[0]):
        for col in range(img.shape[1]):
            # print(img[row, col])
            if img[row, col] == 1:
                img[row, col] = 255
            else:
                img[row, col] = 0
    return img


if __name__ == '__main__':
    # # 更改标签图像的像素为0-255
    changemaskpixel("F:/internship/code/accuracy/results/mask.tif",
                    "F:/internship/code/accuracy/results/groundtruth.tif")

裁剪

裁剪代码参考:python遥感图像裁剪成深度学习样本_支持多波段

import os
import sys
try:
    import gdal
except:
    from osgeo import gdal
import numpy as np
from PIL import Image



# 保存tif文件函数
def writeTiff(im_data, im_geotrans, im_proj, path):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    elif len(im_data.shape) == 2:
        im_data = np.array([im_data])
        im_bands, im_height, im_width = im_data.shape
    # 创建文件
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if (dataset != None):
        dataset.SetGeoTransform(im_geotrans)  # 写入仿射变换参数
        dataset.SetProjection(im_proj)  # 写入投影
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset



# 像素坐标和地理坐标仿射变换
def CoordTransf(Xpixel, Ypixel, GeoTransform):
    XGeo = GeoTransform[0] + GeoTransform[1] * Xpixel + Ypixel * GeoTransform[2]
    YGeo = GeoTransform[3] + GeoTransform[4] * Xpixel + Ypixel * GeoTransform[5]
    return XGeo, YGeo


'''
滑动窗口裁剪Tif影像
TifPath 影像路径
SavePath 裁剪后影像保存目录
CropSize 裁剪尺寸
RepetitionRate 重叠度
'''
def TifCrop(TifPath, SavePath, CropSize, RepetitionRate):
    print("--------------------裁剪影像-----------------------")
    CropSize = int(CropSize)
    RepetitionRate = float(RepetitionRate)
    dataset_img = gdal.Open(TifPath)
    if dataset_img == None:
        print(TifPath + "文件无法打开")

    if not os.path.exists(SavePath):
        os.makedirs(SavePath)

    width = dataset_img.RasterXSize     # 获取行列数
    height = dataset_img.RasterYSize
    bands = dataset_img.RasterCount  # 获取波段数
    print("行数为:", height)
    print("列数为:", width)
    print("波段数为:", bands)

    proj = dataset_img.GetProjection()      # 获取投影信息
    geotrans = dataset_img.GetGeoTransform()        # 获取仿射矩阵信息
    img = dataset_img.ReadAsArray(0, 0, width, height)  # 获取数据

    #  行上图像块数目
    RowNum = int((height - CropSize * RepetitionRate) / (CropSize * (1 - RepetitionRate)))
    #  列上图像块数目
    ColumnNum = int((width - CropSize * RepetitionRate) / (CropSize * (1 - RepetitionRate)))
    print("裁剪后行影像数为:", RowNum)
    print("裁剪后列影像数为:", ColumnNum)

    # 获取当前文件夹的文件个数len,并以len+1命名即将裁剪得到的图像
    new_name = len(os.listdir(SavePath)) + 1

    # 裁剪图片,重复率为RepetitionRate
    for i in range(RowNum):
        for j in range(ColumnNum):
            # 如果图像是单波段
            if (bands == 1):
                cropped = img[
                          int(i * CropSize * (1 - RepetitionRate)): int(i * CropSize * (1 - RepetitionRate)) + CropSize,
                          int(j * CropSize * (1 - RepetitionRate)): int(j * CropSize * (1 - RepetitionRate)) + CropSize]
            # 如果图像是多波段
            else:
                cropped = img[:,
                          int(i * CropSize * (1 - RepetitionRate)): int(i * CropSize * (1 - RepetitionRate)) + CropSize,
                          int(j * CropSize * (1 - RepetitionRate)): int(j * CropSize * (1 - RepetitionRate)) + CropSize]
            # 获取地理坐标
            XGeo, YGeo = CoordTransf(int(j * CropSize * (1 - RepetitionRate)),
                                     int(i * CropSize * (1 - RepetitionRate)),
                                     geotrans)
            crop_geotrans = (XGeo, geotrans[1], geotrans[2], YGeo, geotrans[4], geotrans[5])

            # 生成Tif图像
            writeTiff(cropped, crop_geotrans, proj, SavePath + "/%d.tif" % new_name)

            # 文件名 + 1
            new_name = new_name + 1
    # 向前裁剪最后一行
    for i in range(RowNum):
        if (bands == 1):
            cropped = img[int(i * CropSize * (1 - RepetitionRate)): int(i * CropSize * (1 - RepetitionRate)) + CropSize,
                      (width - CropSize): width]
        else:
            cropped = img[:,
                      int(i * CropSize * (1 - RepetitionRate)): int(i * CropSize * (1 - RepetitionRate)) + CropSize,
                      (width - CropSize): width]
        # 获取地理坐标
        XGeo, YGeo = CoordTransf(width - CropSize,
                                 int(i * CropSize * (1 - RepetitionRate)),
                                 geotrans)
        crop_geotrans = (XGeo, geotrans[1], geotrans[2], YGeo, geotrans[4], geotrans[5])

        # 生成Tif影像
        writeTiff(cropped, crop_geotrans, proj, SavePath + "/%d.tif" % new_name)

        new_name = new_name + 1
    # 向前裁剪最后一列
    for j in range(ColumnNum):
        if (bands == 1):
            cropped = img[(height - CropSize): height,
                      int(j * CropSize * (1 - RepetitionRate)): int(j * CropSize * (1 - RepetitionRate)) + CropSize]
        else:
            cropped = img[:,
                      (height - CropSize): height,
                      int(j * CropSize * (1 - RepetitionRate)): int(j * CropSize * (1 - RepetitionRate)) + CropSize]
        # 获取地理坐标
        XGeo, YGeo = CoordTransf(int(j * CropSize * (1 - RepetitionRate)),
                                 height - CropSize,
                                 geotrans)
        crop_geotrans = (XGeo, geotrans[1], geotrans[2], YGeo, geotrans[4], geotrans[5])
        # 生成tif影像
        writeTiff(cropped, crop_geotrans, proj, SavePath + "/%d.tif" % new_name)

        # 文件名 + 1
        new_name = new_name + 1
    # 裁剪右下角
    if (bands == 1):
        cropped = img[(height - CropSize): height,
                  (width - CropSize): width]
    else:
        cropped = img[:,
                  (height - CropSize): height,
                  (width - CropSize): width]

    XGeo, YGeo = CoordTransf(width - CropSize,
                             height - CropSize,
                             geotrans)
    crop_geotrans = (XGeo, geotrans[1], geotrans[2], YGeo, geotrans[4], geotrans[5])
    # 生成Tif影像
    writeTiff(cropped, crop_geotrans, proj, SavePath + "/%d.tif" % new_name)

    new_name = new_name + 1


if __name__ == '__main__':
    # 将影像裁剪为重复率为0.3的2048×2048的数据集
    TifCrop(r"F:/test/baiyun_ml.tif",
            r"F:/test/img", 2048, 0.3)

剔除全为背景的图像

import os, sys
try:
    import gdal
except:
    from osgeo import gdal
from numba import jit
import numpy as np

# 保存tif文件函数
def writeTiff(im_data, im_geotrans, im_proj, path):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    elif len(im_data.shape) == 2:
        im_data = np.array([im_data])
        im_bands, im_height, im_width = im_data.shape
    # 创建文件
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if (dataset != None):
        dataset.SetGeoTransform(im_geotrans)  # 写入仿射变换参数
        dataset.SetProjection(im_proj)  # 写入投影
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset

'''
水体影像筛选函数
OriTifArrayPath 原始影像路径
MaskTifArrayPath 标签影像路径
EndImg 筛选后的原始影像路径
EndMask 筛选后的标签影像路径
'''
def TifScreen(OriTifArrayPath, MaskTifArrayPath, EndImg, EndMask):
    if not os.path.exists(EndImg):
        os.makedirs(EndImg)
    if not os.path.exists(EndMask):
        os.makedirs(EndMask)
    imgList = os.listdir(OriTifArrayPath)  # 读入文件夹
    imgList.sort(key=lambda x: int(x.split('.')[0]))  # 按照数字进行排序后按顺序读取文件夹下的图片
    ImgArray = []   #创建队列
    geotransArray = []
    projArray = []
    num_img = len(imgList)
    for TifPath in imgList:
        dataset_img = gdal.Open(OriTifArrayPath + TifPath)
        width_crop = dataset_img.RasterXSize  # 获取行列数
        height_crop = dataset_img.RasterYSize
        bands = dataset_img.RasterCount  # 获取波段数
        proj = dataset_img.GetProjection()  # 获取投影信息
        geotrans = dataset_img.GetGeoTransform()  # 获取仿射矩阵信息
        img = dataset_img.ReadAsArray(0, 0, width_crop, height_crop)  # 获取数据
        # print(TifPath)
        ImgArray.append(img)    # 将影像按顺序存入队列
        geotransArray.append(geotrans)
        projArray.append(proj)
    print("行数为:", height_crop)
    print("列数为:", width_crop)
    print("波段数为:", bands)
    print("读取全部影像数量为:", len(ImgArray))

    MaskList = os.listdir(MaskTifArrayPath)  # 读入文件夹
    MaskList.sort(key=lambda x: int(x.split('.')[0]))  # 按照数字进行排序后按顺序读取文件夹下的图片
    MaskArray = []
    geotrans_maskArray = []
    proj_maskArray = []
    for MaskTifPath in MaskList:
        dataset_mask = gdal.Open(MaskTifArrayPath + MaskTifPath)
        width_mask = dataset_mask.RasterXSize  # 获取行列数
        height_mask = dataset_mask.RasterYSize
        bands_mask = dataset_mask.RasterCount  # 获取波段数
        proj_mask = dataset_mask.GetProjection()  # 获取投影信息
        geotrans_mask = dataset_mask.GetGeoTransform()  # 获取仿射矩阵信息
        mask = dataset_mask.ReadAsArray(0, 0, width_mask, height_mask)  # 获取数据
        MaskArray.append(mask)    # 将影像按顺序存入队列
        geotrans_maskArray.append(geotrans_mask)
        proj_maskArray.append(proj_mask)
    print("行数为:", height_mask)
    print("列数为:", width_mask)
    print("波段数为:", bands_mask)
    print("读取全部掩膜数量为:", len(MaskArray))

    for i in range(len(MaskArray)):
        count = test_fast(MaskArray[i])
        name = i + 1
        if count != 0:
            print("图像保存成功")
            writeTiff(ImgArray[i], geotransArray[i], projArray[i], EndImg + "/%d.tif" % name)
            writeTiff(MaskArray[i], geotrans_maskArray[i], proj_maskArray[i], EndMask + "/%d.tif" % name)

# 循环加速,判断影像中是否含有水体
@jit(nopython=True)
def test_fast(img):
    count = 0
    for row in range(img.shape[0]):
        for col in range(img.shape[1]):
            if img[row, col] != 0:
                count = count + 1
    return count

if __name__ == '__main__':
    # 筛选出包含水体的图像
    TifScreen(r"F:/internship/code/ImagePreprocessing/data/test/img/",
              r"F:/internship/code/ImagePreprocessing/data/test/mask/",
              r"F:/internship/code/ImagePreprocessing/data/test/imgend",
              r"F:/internship/code/ImagePreprocessing/data/test/maskend")

拼接

import os
import sys
try:
    import gdal
except:
    from osgeo import gdal
import numpy as np

import os
from PIL import Image
import matplotlib.pyplot as plt



# 保存tif文件函数
def writeTiff(im_data, im_geotrans, im_proj, path):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_Uint16
    else:
        datatype = gdal.GDT_Float32
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    elif len(im_data.shape) == 2:
        im_data = np.array([im_data])
        im_bands, im_height, im_width = im_data.shape
    # 创建文件
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if (dataset != None):
        dataset.SetGeoTransform(im_geotrans)  # 写入仿射变换参数
        dataset.SetProjection(im_proj)  # 写入投影
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset


# 像素坐标和地理坐标仿射变换
def CoordTransf(Xpixel, Ypixel, GeoTransform):
    XGeo = GeoTransform[0] + GeoTransform[1] * Xpixel + Ypixel * GeoTransform[2]
    YGeo = GeoTransform[3] + GeoTransform[4] * Xpixel + Ypixel * GeoTransform[5]
    return XGeo, YGeo
    
'''
影像拼接函数
OriTif 原始影像——提供地理坐标和拼接后的影像长宽
TifArrayPath 要拼接的影像
ResultPath 输出的结果影像
RepetitionRate 重叠度
'''
def TifStitch(OriTif, TifArrayPath, ResultPath, RepetitionRate):
    RepetitionRate = float(RepetitionRate)
    print("--------------------拼接影像-----------------------")
    dataset_img = gdal.Open(OriTif)
    width = dataset_img.RasterXSize  # 获取行列数
    height = dataset_img.RasterYSize
    bands = dataset_img.RasterCount  # 获取波段数
    proj = dataset_img.GetProjection()  # 获取投影信息
    geotrans = dataset_img.GetGeoTransform()  # 获取仿射矩阵信息
    ori_img = dataset_img.ReadAsArray(0, 0, width, height)  # 获取数据
    print("波段数为:", bands)

    # 先创建一个空矩阵
    if bands == 1:
        shape = [height, width]
    else:
        shape = [bands, height, width]
    result = np.zeros(shape, dtype='uint8')

    # 读取裁剪后的影像
    OriImgArray = []   #创建队列
    NameArray = []
    imgList = os.listdir(TifArrayPath)  # 读入文件夹
    imgList.sort(key=lambda x: int(x.split('.')[0]))  # 按照数字进行排序后按顺序读取文件夹下的图片
    for TifPath in imgList:
        # 读取tif影像
        # dataset_img = gdal.Open(TifArrayPath + TifPath)
        # width_crop = dataset_img.RasterXSize  # 获取行列数
        # height_crop = dataset_img.RasterYSize
        # bands_crop = dataset_img.RasterCount  # 获取波段数
        # img = dataset_img.ReadAsArray(0, 0, width_crop, height_crop)  # 获取数据

        # 读取png影像
        img = Image.open(TifArrayPath + TifPath)
        height_crop = img.height
        width_crop = img.width
        img = np.array(img)

        OriImgArray.append(img)    # 将影像按顺序存入队列
        name = TifPath.split('.')[0]
        # print(name)
        NameArray.append(name)
    print("读取全部影像数量为:", len(OriImgArray))

    #  行上图像块数目
    RowNum = int((height - height_crop * RepetitionRate) / (height_crop * (1 - RepetitionRate)))
    #  列上图像块数目
    ColumnNum = int((width - width_crop * RepetitionRate) / (width_crop * (1 - RepetitionRate)))
    # 获取图像总数
    sum_img = RowNum * ColumnNum + RowNum + ColumnNum + 1
    print("行影像数为:", RowNum)
    print("列影像数为:", ColumnNum)
    print("图像总数为:", sum_img)

	# 前面读取的是剔除了背景影像的剩余影像,拼接按照图像名称拼接,因此需再创建全为背景的影像,填充影像列表
    # 创建空矩阵
    if bands_crop == 1:
        shape_crop = [height_crop, width_crop]
    else:
        shape_crop = [bands_crop, height_crop, width_crop]
    img_crop = np.zeros(shape_crop)  # 创建空矩阵
    # 创建整体图像列表
    ImgArray = []
    count = 0
    for i in range(sum_img):
        img_name = i + 1
        for j in range(len(OriImgArray)):
            if img_name == int(NameArray[j]):
                image = OriImgArray[j]
                count = count + 1
                break
            else:
                image = img_crop
        ImgArray.append(image)

    print("包含水体的图象数量为:", count)
    print("整个影像列表数量为:", len(ImgArray))


 # 开始赋值
    num = 0
    for i in range(RowNum):
        for j in range(ColumnNum):
            # 如果图像是单波段
            if (bands == 1):
                result[int(i * height_crop * (1 - RepetitionRate)): int(i * height_crop * (1 - RepetitionRate)) + height_crop,
                            int(j * width_crop * (1 - RepetitionRate)): int(j * width_crop * (1 - RepetitionRate)) + width_crop] = ImgArray[num]
            # 如果图像是多波段
            else:
                result[:,
                            int(i * height_crop * (1 - RepetitionRate)): int(i * height_crop * (1 - RepetitionRate)) + height_crop,
                            int(j * width_crop * (1 - RepetitionRate)): int(j * width_crop * (1 - RepetitionRate)) + width_crop] = ImgArray[num]
            num = num + 1
    # 最后一行
    for i in range(RowNum):
        if (bands == 1):
            result[int(i * height_crop * (1 - RepetitionRate)): int(i * height_crop * (1 - RepetitionRate)) + height_crop,
                      (width - width_crop): width] = ImgArray[num]
        else:
            result[:,
                      int(i * height_crop * (1 - RepetitionRate)): int(i * height_crop * (1 - RepetitionRate)) + height_crop,
                      (width - width_crop): width] = ImgArray[num]
        num = num + 1
    # 最后一列
    for j in range(ColumnNum):
        if (bands == 1):
            result[(height - height_crop): height,
                      int(j * width_crop * (1 - RepetitionRate)): int(j * width_crop * (1 - RepetitionRate)) + width_crop] = ImgArray[num]
        else:
            result[:,
                      (height - height_crop): height,
                      int(j * width_crop * (1 - RepetitionRate)): int(j * width_crop * (1 - RepetitionRate)) + width_crop] = ImgArray[num]
        num = num + 1
    # 右下角
    if (bands == 1):
        result[(height - height_crop): height,
                        (width - width_crop): width] = ImgArray[num]
    else:
        result[:,
                    (height - height_crop): height,
                    (width - width_crop): width] = ImgArray[num]
    num = num + 1
    # 生成Tif影像

    writeTiff(result, geotrans, proj, ResultPath)


if __name__ == '__main__':
    # 拼接影像
    TifStitch(r"G:/水体提取/LY/baiyun.tif",
              r"F:/test/mask2/",
              r"F:/test/merge.tif", 0.3)

Tiff转PNG

import os,sys
import cv2
import numpy as np
from skimage import io#使用IO库读取tif图片


def tif_png_transform(file_path_name, bgr_savepath_name):
    img = io.imread(file_path_name)#读取文件名
    img = img / img.max()#使其所有值不大于一
    img = img * 255 - 0.001  # 减去0.001防止变成负整型
    img = img.astype(np.uint8)#强制转换成8位整型
    # img = np.array([img,img,img])
    # img = img.transpose(1,2,0)
    # print(img.shape)  # 显示图片大小和深度
    if len(img.shape) == 3:
        b = img[:, :, 0]  # 读取蓝通道
        g = img[:, :, 1]  # 读取绿通道
        r = img[:, :, 2]  # 读取红通道
        bgr = cv2.merge([r, g, b])  # 通道拼接
        cv2.imwrite(bgr_savepath_name, bgr)#图片存储
        # print("tif转png")
    elif len(img.shape) == 2:
        cv2.imwrite(bgr_savepath_name, img)  # 图片存储


tif_file_path = r'F:/internship/code/ImagePreprocessing/data/test/maskend'# 为tif图片的文件夹路径
png_path = r'F:/internship/code/ImagePreprocessing/data/test/maskpng'
if not os.path.exists(png_path):
    os.makedirs(png_path)
tif_fileList = os.listdir(tif_file_path)
for tif_file in tif_fileList:
    file_path_name = tif_file_path + '/' + tif_file
    png_path_name = png_path + '/' + tif_file.split('.')[0] + '.png' #.png图片的保存路径
    tif_png_transform(file_path_name, png_path_name)

最终结果

在这里插入图片描述
在这里插入图片描述

  • 19
    点赞
  • 148
    收藏
    觉得还不错? 一键收藏
  • 21
    评论
评论 21
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值