YOLO实例分割和变化检测_学习记录(2)

本文介绍了如何使用Python和GDAL库进行影像切片的水平、垂直和对角翻转,以增加数据集多样性,并处理TIF格式的影像和标签。同时,还涉及了标签MASK到labelme格式的转换,以及使用mask2json脚本批量转换的过程。
摘要由CSDN通过智能技术生成

4、影像切片及标签增强

通过水平翻转、垂直翻转和对角翻转实现数据集增多,这一步需要注意保存一份数据备份。

这里影像切片和标签都采用TIF格式进行编辑,同时名称对应且为1,2,3......的序列名称。

这里采用numpy.flip实现影像翻转变化等操作,需要注意的是三波段影像切片和单波段标签影像采用的变化方式不同(且不能用opencv进行处理,会导致单波段标签变成RGB三波段TIF)。

try:
    import gdal
except:
    from osgeo import gdal
import numpy as np
import os
import cv2


#  读取tif数据集
def readTif(fileName, xoff=0, yoff=0, data_width=0, data_height=0):
    dataset = gdal.Open(fileName)
    if dataset == None:
        print(fileName + "文件无法打开")
    #  栅格矩阵的列数
    width = dataset.RasterXSize
    #  栅格矩阵的行数
    height = dataset.RasterYSize
    #  波段数
    bands = dataset.RasterCount
    #  获取数据
    if (data_width == 0 and data_height == 0):
        data_width = width
        data_height = height
    data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)
    #  获取仿射矩阵信息
    geotrans = dataset.GetGeoTransform()
    #  获取投影信息
    proj = dataset.GetProjection()
    return width, height, bands, data, geotrans, proj
#  保存tif文件函数
def writeTiff(im_data, im_geotrans, im_proj, path):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    elif len(im_data.shape) == 2:
        im_data = np.array([im_data])
        im_bands, im_height, im_width = im_data.shape
    # 创建文件
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if (dataset != None):
        dataset.SetGeoTransform(im_geotrans)  # 写入仿射变换参数
        dataset.SetProjection(im_proj)  # 写入投影
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset

#训练集的数据及标签
train_image_path = r"D:\Data\Dataset\Select\TrainMJ"
train_label_path = r"D:\Data\Dataset\Select\LabelMJ"

#  进行几何变换数据增强
imageList = os.listdir(train_image_path)
labelList = os.listdir(train_label_path)
tran_num = len(imageList) + 1
for i in range(len(imageList)):
    #  图像
    img_file = train_image_path + "\\" + imageList[i]
    im_width, im_height, im_bands, im_data, im_geotrans, im_proj = readTif(img_file)
    #  标签
    label_file = train_label_path + "\\" + labelList[i]
    la_width, la_height, la_bands, la_data, la_geotrans, la_proj = readTif(label_file)

    #  图像水平翻转
    im_data_hor = np.flip(im_data, axis=2)
    hor_path = train_image_path + "\\" + str(tran_num) + imageList[i][-4:]
    writeTiff(im_data_hor, im_geotrans, im_proj, hor_path)
    #  标签水平翻转
    la_data_hor = np.flip(la_data, axis=1)
    hor_path = train_label_path + "\\" + str(tran_num) + imageList[i][-4:]
    writeTiff(la_data_hor, la_geotrans, la_proj, hor_path)
    tran_num += 1

    #  图像垂直翻转
    im_data_vec = np.flip(im_data, axis=1)
    vec_path = train_image_path + "\\" + str(tran_num) + imageList[i][-4:]
    writeTiff(im_data_vec, im_geotrans, im_proj, vec_path)
    #  标签垂直翻转
    la_data_vec = np.flip(la_data, axis=0)
    vec_path = train_label_path + "\\" + str(tran_num) + imageList[i][-4:]
    writeTiff(la_data_vec, la_geotrans, la_proj, vec_path)
    tran_num += 1

    #  图像对角镜像
    im_data_dia = np.flip(im_data_vec, axis=2)
    dia_path = train_image_path + "\\" + str(tran_num) + imageList[i][-4:]
    writeTiff(im_data_dia, im_geotrans, im_proj, dia_path)
    #  标签对角镜像
    la_data_dia = np.flip(la_data_vec, axis=-1)
    dia_path = train_label_path + "\\" + str(tran_num) + imageList[i][-4:]
    writeTiff(la_data_dia, la_geotrans, la_proj, dia_path)
    tran_num += 1

5、标签MASK转labelme格式(json文件)

利用工具转

github地址:https://github.com/guchengxi1994/mask2json

下载代码包后,按照说明进行使用。

和YOLO使用方法类似,通过创建环境,下载requirements,找到mask2json_script,在后面插入运行部分,将路径更改输入其中,即可得到json格式的标签。批量转化就输入文件夹,使用mask2json_script.getJsons,单张转化用getmultishapes。

from convertmask.utils.methods import get_multi_shapes
from convertmask.utils import mask2json_script

imgPath = 'D:\\Data\\Dataset\\Select\\TrainMJ'
maskPath = 'D:\\Data\\Dataset\\Select\\LabelMJ'
savePath = 'D:\\Data\\Dataset\\Select\\JsonMJ'
yamlPath = 'D:\\Data\\Dataset\\Select\\info.yaml'

#单一图像进行mask转json
#get_multi_shapes.getMultiShapes(imgPath, maskPath, savePath, yamlPath)  # with yaml
#get_multi_shapes.getMultiShapes(imgPath, maskPath, savePath)  # without yaml

#folder2json
mask2json_script.getJsons(imgPath, maskPath, savePath, yamlPath)

需要新建一个info.yaml用来给标签赋属性

label_names:
  _background_: 0
  cons: 1

同时由于mask2jaon_script.py里面设置的getJson()方法默认用的jpg,这里需要修改成tif。

'''
lanhuage: python
Descripttion: 
version: beta
Author: xiaoshuyui
Date: 2020-07-10 10:33:39
LastEditors: xiaoshuyui
LastEditTime: 2021-01-05 10:21:49
'''

import glob
import os

from tqdm import tqdm

from convertmask.utils.methods import get_multi_shapes
from convertmask.utils.methods.logger import logger


def getJsons(imgPath, maskPath, savePath, yamlPath=''):
    """
    imgPath: origin image path \n
    maskPath : mask image path \n
    savePath : json file save path \n
    
    >>> getJsons(path-to-your-imgs,path-to-your-maskimgs,path-to-your-jsonfiles) 

    """
    logger.info("currently, only *.jpg supported")

    if os.path.isfile(imgPath):
        get_multi_shapes.getMultiShapes(imgPath, maskPath, savePath, yamlPath)

    elif os.path.isdir(imgPath):
        oriImgs = glob.glob(imgPath + os.sep + '*.tif')
        maskImgs = glob.glob(maskPath + os.sep + '*.tif')
        for i in tqdm(oriImgs):
            i_mask = i.replace(imgPath, maskPath)
            if os.path.exists(i_mask):
                # print(i)
                get_multi_shapes.getMultiShapes(i, i_mask, savePath, yamlPath)
            else:
                logger.warning('corresponding mask image not found!')
                continue
    else:
        logger.error('input error. got [{},{},{},{}]. file maybe missing.'.format(
            imgPath, maskPath, savePath, yamlPath))
    logger.info('Done! See here. {}'.format(savePath))


def getXmls(imgPath, maskPath, savePath):
    logger.info("currently, only *.jpg supported")

    if os.path.isfile(imgPath):
        get_multi_shapes.getMultiObjs_voc(imgPath, maskPath, savePath)
    elif os.path.isdir(imgPath):
        oriImgs = glob.glob(imgPath + os.sep + '*.tif')
        maskImgs = glob.glob(maskPath + os.sep + '*.tif')

        for i in tqdm(oriImgs):
            i_mask = i.replace(imgPath, maskPath)
            # print(i)
            if os.path.exists(i_mask):
                get_multi_shapes.getMultiObjs_voc(i, i_mask, savePath)
            else:
                logger.warning('corresponding mask image not found!')
                continue
    else:
        logger.error('input error. got [{},{},{}]. file maybe missing.'.format(
            imgPath, maskPath, savePath))
    logger.info('Done! See here. {}'.format(savePath))

最终用labelme查看标签是否已经制作完成。

##影像切片拼接##

这里需要特别注意 待拼接切片格式,是PNG还是TIF。这里一般选择TIF。

import os
import sys
try:
    import gdal
except:
    from osgeo import gdal
import numpy as np

import os
from PIL import Image
import matplotlib.pyplot as plt



# 保存tif文件函数
def writeTiff(im_data, im_geotrans, im_proj, path):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_Uint16
    else:
        datatype = gdal.GDT_Float32
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    elif len(im_data.shape) == 2:
        im_data = np.array([im_data])
        im_bands, im_height, im_width = im_data.shape
    # 创建文件
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if (dataset != None):
        dataset.SetGeoTransform(im_geotrans)  # 写入仿射变换参数
        dataset.SetProjection(im_proj)  # 写入投影
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset


# 像素坐标和地理坐标仿射变换
def CoordTransf(Xpixel, Ypixel, GeoTransform):
    XGeo = GeoTransform[0] + GeoTransform[1] * Xpixel + Ypixel * GeoTransform[2]
    YGeo = GeoTransform[3] + GeoTransform[4] * Xpixel + Ypixel * GeoTransform[5]
    return XGeo, YGeo
    
'''
影像拼接函数
OriTif 原始影像——提供地理坐标和拼接后的影像长宽
TifArrayPath 要拼接的影像
ResultPath 输出的结果影像
RepetitionRate 重叠度
'''
def TifStitch(OriTif, TifArrayPath, ResultPath, RepetitionRate):
    RepetitionRate = float(RepetitionRate)
    print("--------------------拼接影像-----------------------")
    dataset_img = gdal.Open(OriTif)
    width = dataset_img.RasterXSize  # 获取行列数
    height = dataset_img.RasterYSize
    bands = dataset_img.RasterCount  # 获取波段数
    proj = dataset_img.GetProjection()  # 获取投影信息
    geotrans = dataset_img.GetGeoTransform()  # 获取仿射矩阵信息
    ori_img = dataset_img.ReadAsArray(0, 0, width, height)  # 获取数据
    print("波段数为:", bands)

    # 先创建一个空矩阵
    if bands == 1:
        shape = [height, width]
    else:
        shape = [bands, height, width]
    result = np.zeros(shape, dtype='uint8')

    # 读取裁剪后的影像
    OriImgArray = []   #创建队列
    NameArray = []
    imgList = os.listdir(TifArrayPath)  # 读入文件夹
    imgList.sort(key=lambda x: int(x.split('.')[0]))  # 按照数字进行排序后按顺序读取文件夹下的图片
    for TifPath in imgList:
        读取tif影像
        dataset_img = gdal.Open(TifArrayPath + TifPath)
        width_crop = dataset_img.RasterXSize  # 获取行列数
        height_crop = dataset_img.RasterYSize
        bands_crop = dataset_img.RasterCount  # 获取波段数
        img = dataset_img.ReadAsArray(0, 0, width_crop, height_crop)  # 获取数据

        # 读取png影像
        # img = Image.open(TifArrayPath + TifPath)
        # height_crop = img.height
        # width_crop = img.width
        # img = np.array(img)

        OriImgArray.append(img)    # 将影像按顺序存入队列
        name = TifPath.split('.')[0]
        # print(name)
        NameArray.append(name)
    print("读取全部影像数量为:", len(OriImgArray))

    #  行上图像块数目
    RowNum = int((height - height_crop * RepetitionRate) / (height_crop * (1 - RepetitionRate)))
    #  列上图像块数目
    ColumnNum = int((width - width_crop * RepetitionRate) / (width_crop * (1 - RepetitionRate)))
    # 获取图像总数
    sum_img = RowNum * ColumnNum + RowNum + ColumnNum + 1
    print("行影像数为:", RowNum)
    print("列影像数为:", ColumnNum)
    print("图像总数为:", sum_img)

	# 前面读取的是剔除了背景影像的剩余影像,拼接按照图像名称拼接,因此需再创建全为背景的影像,填充影像列表
    # 创建空矩阵
    if bands_crop == 1:
        shape_crop = [height_crop, width_crop]
    else:
        shape_crop = [bands_crop, height_crop, width_crop]
    img_crop = np.zeros(shape_crop)  # 创建空矩阵
    # 创建整体图像列表
    ImgArray = []
    count = 0
    for i in range(sum_img):
        img_name = i + 1
        for j in range(len(OriImgArray)):
            if img_name == int(NameArray[j]):
                image = OriImgArray[j]
                count = count + 1
                break
            else:
                image = img_crop
        ImgArray.append(image)

    print("含目标图像数量为:", count)
    print("整个影像列表数量为:", len(ImgArray))


 # 开始赋值
    num = 0
    for i in range(RowNum):
        for j in range(ColumnNum):
            # 如果图像是单波段
            if (bands == 1):
                result[int(i * height_crop * (1 - RepetitionRate)): int(i * height_crop * (1 - RepetitionRate)) + height_crop,
                            int(j * width_crop * (1 - RepetitionRate)): int(j * width_crop * (1 - RepetitionRate)) + width_crop] = ImgArray[num]
            # 如果图像是多波段
            else:
                result[:,
                            int(i * height_crop * (1 - RepetitionRate)): int(i * height_crop * (1 - RepetitionRate)) + height_crop,
                            int(j * width_crop * (1 - RepetitionRate)): int(j * width_crop * (1 - RepetitionRate)) + width_crop] = ImgArray[num]
            num = num + 1
    # 最后一行
    for i in range(RowNum):
        if (bands == 1):
            result[int(i * height_crop * (1 - RepetitionRate)): int(i * height_crop * (1 - RepetitionRate)) + height_crop,
                      (width - width_crop): width] = ImgArray[num]
        else:
            result[:,
                      int(i * height_crop * (1 - RepetitionRate)): int(i * height_crop * (1 - RepetitionRate)) + height_crop,
                      (width - width_crop): width] = ImgArray[num]
        num = num + 1
    # 最后一列
    for j in range(ColumnNum):
        if (bands == 1):
            result[(height - height_crop): height,
                      int(j * width_crop * (1 - RepetitionRate)): int(j * width_crop * (1 - RepetitionRate)) + width_crop] = ImgArray[num]
        else:
            result[:,
                      (height - height_crop): height,
                      int(j * width_crop * (1 - RepetitionRate)): int(j * width_crop * (1 - RepetitionRate)) + width_crop] = ImgArray[num]
        num = num + 1
    # 右下角
    if (bands == 1):
        result[(height - height_crop): height,
                        (width - width_crop): width] = ImgArray[num]
    else:
        result[:,
                    (height - height_crop): height,
                    (width - width_crop): width] = ImgArray[num]
    num = num + 1
    # 生成Tif影像

    writeTiff(result, geotrans, proj, ResultPath)


if __name__ == '__main__':
    # 拼接影像、原始图像为
    TifStitch(r"F:/OriTIF.tif",
              r"F:/test/TrainMJ/",
              r"F:/test/merge.tif", 0)

  • 9
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值