4、影像切片及标签增强
通过水平翻转、垂直翻转和对角翻转实现数据集增多,这一步需要注意保存一份数据备份。
这里影像切片和标签都采用TIF格式进行编辑,同时名称对应且为1,2,3......的序列名称。
这里采用numpy.flip实现影像的翻转变换。需要注意的是,三波段影像切片(数组形状为 波段×行×列)和单波段标签影像(数组形状为 行×列)的维数不同,因此翻转时使用的轴编号也不同:例如水平翻转都是沿最后一个轴进行,对三波段影像是axis=2,对单波段标签则是axis=1。另外不能用opencv进行处理,否则会导致单波段标签变成RGB三波段TIF。
try:
import gdal
except:
from osgeo import gdal
import numpy as np
import os
import cv2
def readTif(fileName, xoff=0, yoff=0, data_width=0, data_height=0):
    """Read a raster file, or a window of it, through GDAL.

    Args:
        fileName: Path of the raster to open.
        xoff: Column offset of the window's upper-left corner.
        yoff: Row offset of the window's upper-left corner.
        data_width: Window width in pixels; 0 means "up to the right edge".
        data_height: Window height in pixels; 0 means "up to the bottom edge".

    Returns:
        A tuple ``(width, height, bands, data, geotrans, proj)`` where
        ``width``/``height``/``bands`` describe the whole raster, ``data`` is
        the array returned by ``ReadAsArray`` for the requested window, and
        ``geotrans``/``proj`` are the dataset's geotransform and projection.

    Raises:
        RuntimeError: If GDAL cannot open ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Fail here with a clear message instead of an AttributeError on None.
        raise RuntimeError(fileName + "文件无法打开")

    # Number of columns / rows / bands of the whole raster.
    width = dataset.RasterXSize
    height = dataset.RasterYSize
    bands = dataset.RasterCount

    # Resolve each window dimension on its own so that a caller may give
    # only one of them, and so that a non-zero offset never makes the
    # default window run past the raster edge.
    if data_width == 0:
        data_width = width - xoff
    if data_height == 0:
        data_height = height - yoff
    data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)

    # Affine geotransform and projection of the source raster.
    geotrans = dataset.GetGeoTransform()
    proj = dataset.GetProjection()
    return width, height, bands, data, geotrans, proj
def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write a numpy array to a GeoTIFF file.

    Args:
        im_data: 2-D array ``(height, width)`` or 3-D array
            ``(bands, height, width)``.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection passed to ``SetProjection``.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the GTiff driver cannot create ``path``.
    """
    # Choose the GDAL pixel type from the numpy dtype. Signed and unsigned
    # 16-bit are kept apart so negative int16 values are not lost.
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        # Matches both int8 and uint8, as before.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Normalise to (bands, height, width).
    if im_data.ndim == 2:
        im_data = np.array([im_data])
    elif im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got {} dimensions".format(im_data.ndim))
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        # Previously a failed Create was skipped silently and no file appeared.
        raise RuntimeError("cannot create " + str(path))

    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)  # projection
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Dropping the last reference lets GDAL flush and close the file.
    del dataset
# Folders holding the training image tiles and their label tiles.
train_image_path = r"D:\Data\Dataset\Select\TrainMJ"
train_label_path = r"D:\Data\Dataset\Select\LabelMJ"

# Geometric augmentation: horizontal, vertical and diagonal flips.
# os.listdir returns entries in arbitrary order, so both listings are sorted
# to make the i-th image and the i-th label refer to the same tile name.
imageList = sorted(os.listdir(train_image_path))
labelList = sorted(os.listdir(train_label_path))
if len(imageList) != len(labelList):
    raise ValueError(
        "image/label count mismatch: {} images vs {} labels".format(
            len(imageList), len(labelList)))

# New tiles continue the 1, 2, 3, ... numbering after the existing ones.
tran_num = len(imageList) + 1

# Flip axes counted from the end, so the same tuple works for a
# (bands, rows, cols) image and a (rows, cols) label:
#   -1 -> horizontal flip, -2 -> vertical flip, both -> diagonal mirror.
flip_axes = ((-1,), (-2,), (-2, -1))

# The listings were taken before the loop, so files written below are not
# picked up again.
for image_name, label_name in zip(imageList, labelList):
    # Keep the image's own extension (".tif", ".tiff", ...) for the outputs.
    extension = os.path.splitext(image_name)[1]

    # Image tile.
    img_file = os.path.join(train_image_path, image_name)
    im_width, im_height, im_bands, im_data, im_geotrans, im_proj = readTif(img_file)
    # Label tile.
    label_file = os.path.join(train_label_path, label_name)
    la_width, la_height, la_bands, la_data, la_geotrans, la_proj = readTif(label_file)

    for axes in flip_axes:
        out_name = str(tran_num) + extension
        # Image and label get the same flip and the same new number.
        writeTiff(np.flip(im_data, axis=axes), im_geotrans, im_proj,
                  os.path.join(train_image_path, out_name))
        writeTiff(np.flip(la_data, axis=axes), la_geotrans, la_proj,
                  os.path.join(train_label_path, out_name))
        tran_num += 1
5、标签MASK转labelme格式(json文件)
利用工具转
github地址:https://github.com/guchengxi1994/mask2json
下载代码包后,按照说明进行使用。
和YOLO的使用方法类似:先创建环境并安装requirements中的依赖,然后找到mask2json_script,在其后面加入下面的运行代码,并把其中的路径改成自己的路径,即可得到json格式的标签。批量转换时输入文件夹路径,使用mask2json_script.getJsons;单张转换时使用get_multi_shapes.getMultiShapes。
from convertmask.utils.methods import get_multi_shapes
from convertmask.utils import mask2json_script
# All inputs and outputs live under one dataset folder.
_datasetRoot = 'D:\\Data\\Dataset\\Select'
imgPath = _datasetRoot + '\\TrainMJ'
maskPath = _datasetRoot + '\\LabelMJ'
savePath = _datasetRoot + '\\JsonMJ'
yamlPath = _datasetRoot + '\\info.yaml'

# Single-image conversion (mask -> json), kept here for reference:
#   get_multi_shapes.getMultiShapes(imgPath, maskPath, savePath, yamlPath)  # with yaml
#   get_multi_shapes.getMultiShapes(imgPath, maskPath, savePath)  # without yaml

# Folder conversion: every image in imgPath is converted with its mask.
mask2json_script.getJsons(imgPath, maskPath, savePath, yamlPath)
需要新建一个info.yaml用来给标签赋属性
label_names:
  _background_: 0
  cons: 1
同时由于mask2json_script.py里面的getJsons()方法默认只查找jpg文件,这里需要把查找的后缀修改成tif。修改后的脚本如下:
'''
lanhuage: python
Descripttion:
version: beta
Author: xiaoshuyui
Date: 2020-07-10 10:33:39
LastEditors: xiaoshuyui
LastEditTime: 2021-01-05 10:21:49
'''
import glob
import os
from tqdm import tqdm
from convertmask.utils.methods import get_multi_shapes
from convertmask.utils.methods.logger import logger
def getJsons(imgPath, maskPath, savePath, yamlPath=''):
    """Convert mask images to json files, for one image or a whole folder.

    Args:
        imgPath: Path of one source image, or of a folder of ``*.tif`` images.
        maskPath: Path of the matching mask image, or of the folder that
            holds masks named like the images.
        savePath: Folder the json files are saved to.
        yamlPath: Optional yaml file forwarded to ``getMultiShapes``.

    Example:
        getJsons(path-to-your-imgs, path-to-your-maskimgs, path-to-your-jsonfiles)
    """
    # The folder branch below searches for *.tif, so say so.
    logger.info("currently, only *.tif supported")

    if os.path.isfile(imgPath):
        get_multi_shapes.getMultiShapes(imgPath, maskPath, savePath, yamlPath)
    elif os.path.isdir(imgPath):
        # Escape the folder name so characters such as "[" in it are not
        # interpreted as glob wildcards.
        pattern = os.path.join(glob.escape(imgPath), '*.tif')
        for img_file in tqdm(glob.glob(pattern)):
            # The mask has the same file name, inside maskPath.
            mask_file = os.path.join(maskPath, os.path.basename(img_file))
            if os.path.exists(mask_file):
                get_multi_shapes.getMultiShapes(img_file, mask_file, savePath, yamlPath)
            else:
                logger.warning('corresponding mask image not found!')
    else:
        logger.error('input error. got [{},{},{},{}]. file maybe missing.'.format(
            imgPath, maskPath, savePath, yamlPath))
        # Nothing was converted, so do not report success below.
        return

    logger.info('Done! See here. {}'.format(savePath))
def getXmls(imgPath, maskPath, savePath):
    """Convert mask images to xml files, for one image or a whole folder.

    Each image/mask pair is passed to ``get_multi_shapes.getMultiObjs_voc``
    (presumably a Pascal VOC xml writer, judging by its name).

    Args:
        imgPath: Path of one source image, or of a folder of ``*.tif`` images.
        maskPath: Path of the matching mask image, or of the folder that
            holds masks named like the images.
        savePath: Folder the output files are saved to.
    """
    # The folder branch below searches for *.tif, so say so.
    logger.info("currently, only *.tif supported")

    if os.path.isfile(imgPath):
        get_multi_shapes.getMultiObjs_voc(imgPath, maskPath, savePath)
    elif os.path.isdir(imgPath):
        # Escape the folder name so characters such as "[" in it are not
        # interpreted as glob wildcards.
        pattern = os.path.join(glob.escape(imgPath), '*.tif')
        for img_file in tqdm(glob.glob(pattern)):
            # The mask has the same file name, inside maskPath.
            mask_file = os.path.join(maskPath, os.path.basename(img_file))
            if os.path.exists(mask_file):
                get_multi_shapes.getMultiObjs_voc(img_file, mask_file, savePath)
            else:
                logger.warning('corresponding mask image not found!')
    else:
        logger.error('input error. got [{},{},{}]. file maybe missing.'.format(
            imgPath, maskPath, savePath))
        # Nothing was converted, so do not report success below.
        return

    logger.info('Done! See here. {}'.format(savePath))
最终用labelme查看标签是否已经制作完成。
##影像切片拼接##
这里需要特别注意 待拼接切片格式,是PNG还是TIF。这里一般选择TIF。
import os
import sys
try:
import gdal
except:
from osgeo import gdal
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write a numpy array to a GeoTIFF file.

    Args:
        im_data: 2-D array ``(height, width)`` or 3-D array
            ``(bands, height, width)``.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection passed to ``SetProjection``.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the GTiff driver cannot create ``path``.
    """
    # Choose the GDAL pixel type from the numpy dtype. The GDAL constant is
    # spelled GDT_UInt16; the former "GDT_Uint16" raised AttributeError for
    # every 16-bit array. Signed int16 gets its own type so negative values
    # are not lost.
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        # Matches both int8 and uint8, as before.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Normalise to (bands, height, width).
    if im_data.ndim == 2:
        im_data = np.array([im_data])
    elif im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got {} dimensions".format(im_data.ndim))
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        # Previously a failed Create was skipped silently and no file appeared.
        raise RuntimeError("cannot create " + str(path))

    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)  # projection
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Dropping the last reference lets GDAL flush and close the file.
    del dataset
def CoordTransf(Xpixel, Ypixel, GeoTransform):
    """Convert a pixel position to geographic coordinates.

    Applies the six-coefficient affine transform held in ``GeoTransform``
    (indices 0-2 produce X, indices 3-5 produce Y).

    Args:
        Xpixel: Pixel column.
        Ypixel: Pixel row.
        GeoTransform: Indexable holding at least six affine coefficients.

    Returns:
        Tuple ``(XGeo, YGeo)``.
    """
    x_origin, x_per_col, x_per_row = GeoTransform[0], GeoTransform[1], GeoTransform[2]
    y_origin, y_per_col, y_per_row = GeoTransform[3], GeoTransform[4], GeoTransform[5]

    # Same evaluation order as before: origin + column term + row term.
    geo_x = x_origin + x_per_col * Xpixel + Ypixel * x_per_row
    geo_y = y_origin + y_per_col * Xpixel + Ypixel * y_per_row
    return geo_x, geo_y
def TifStitch(OriTif, TifArrayPath, ResultPath, RepetitionRate):
    """Stitch numbered tiles back into one raster laid out like ``OriTif``.

    Tile files are named ``1``, ``2``, ``3``, ... (plus extension). Numbers
    are consumed in this order: the regular grid row by row, then the tiles
    aligned to the right edge (one per grid row), then the tiles aligned to
    the bottom edge (one per grid column), then the bottom-right corner.
    Numbers with no file are filled with an all-zero tile. Where tiles
    overlap, the tile pasted later overwrites the earlier one.

    Args:
        OriTif: Original raster; supplies the size, band count, projection
            and geotransform of the stitched result.
        TifArrayPath: Folder containing the tiles to stitch.
        ResultPath: Path of the output GeoTIFF.
        RepetitionRate: Overlap ratio between neighbouring tiles, in [0, 1).

    Raises:
        ValueError: If ``RepetitionRate`` is outside [0, 1), the tile folder
            is empty, or a tile is larger than the original raster.
        RuntimeError: If the original raster or a tile cannot be opened.
    """
    RepetitionRate = float(RepetitionRate)
    if not 0 <= RepetitionRate < 1:
        # A rate of 1 would divide by zero in the tile-count formulas below.
        raise ValueError("RepetitionRate must be in [0, 1), got {}".format(RepetitionRate))

    print("--------------------拼接影像-----------------------")
    dataset_img = gdal.Open(OriTif)
    if dataset_img is None:
        raise RuntimeError("cannot open " + str(OriTif))
    # Only the metadata of the original is needed; its pixels are not read.
    width = dataset_img.RasterXSize
    height = dataset_img.RasterYSize
    bands = dataset_img.RasterCount
    proj = dataset_img.GetProjection()
    geotrans = dataset_img.GetGeoTransform()
    print("波段数为:", bands)

    # Empty output buffer; it is uint8 whatever the tile dtype is.
    if bands == 1:
        shape = [height, width]
    else:
        shape = [bands, height, width]
    result = np.zeros(shape, dtype='uint8')

    # Read the tiles, ordered by the number in their file name.
    imgList = os.listdir(TifArrayPath)
    if not imgList:
        raise ValueError("no tiles found in " + str(TifArrayPath))
    imgList.sort(key=lambda x: int(x.split('.')[0]))
    tiles = {}
    for TifPath in imgList:
        # Read the tile with GDAL. (PNG tiles could be read with PIL's
        # Image.open and np.array instead.)
        tile_dataset = gdal.Open(os.path.join(TifArrayPath, TifPath))
        if tile_dataset is None:
            raise RuntimeError("cannot open tile " + TifPath)
        width_crop = tile_dataset.RasterXSize
        height_crop = tile_dataset.RasterYSize
        bands_crop = tile_dataset.RasterCount
        img = tile_dataset.ReadAsArray(0, 0, width_crop, height_crop)
        # If two files share a number, the first one in sorted order wins.
        tiles.setdefault(int(TifPath.split('.')[0]), img)
    print("读取全部影像数量为:", len(imgList))

    if height_crop > height or width_crop > width:
        raise ValueError("tiles are larger than the original raster")

    # Number of regular-grid tiles along the rows / columns.
    RowNum = int((height - height_crop * RepetitionRate) / (height_crop * (1 - RepetitionRate)))
    ColumnNum = int((width - width_crop * RepetitionRate) / (width_crop * (1 - RepetitionRate)))
    # Grid tiles + right-edge tiles + bottom-edge tiles + corner tile.
    sum_img = RowNum * ColumnNum + RowNum + ColumnNum + 1
    print("行影像数为:", RowNum)
    print("列影像数为:", ColumnNum)
    print("图像总数为:", sum_img)

    # Background-only tiles were removed before stitching, so every missing
    # number is replaced by an all-zero tile of the same shape.
    if bands_crop == 1:
        shape_crop = [height_crop, width_crop]
    else:
        shape_crop = [bands_crop, height_crop, width_crop]
    img_crop = np.zeros(shape_crop)
    tile_numbers = range(1, sum_img + 1)
    ImgArray = [tiles.get(number, img_crop) for number in tile_numbers]
    count = sum(1 for number in tile_numbers if number in tiles)
    print("含目标图像数量为:", count)
    print("整个影像列表数量为:", len(ImgArray))

    # Upper-left corner (row, column) of every tile, in tile-number order.
    row_starts = [int(i * height_crop * (1 - RepetitionRate)) for i in range(RowNum)]
    col_starts = [int(j * width_crop * (1 - RepetitionRate)) for j in range(ColumnNum)]
    last_row_start = height - height_crop
    last_col_start = width - width_crop
    # Regular grid, row by row.
    corners = [(top, left) for top in row_starts for left in col_starts]
    # Last column: tiles flush with the right edge, one per grid row.
    corners += [(top, last_col_start) for top in row_starts]
    # Last row: tiles flush with the bottom edge, one per grid column.
    corners += [(last_row_start, left) for left in col_starts]
    # Bottom-right corner.
    corners.append((last_row_start, last_col_start))

    # "..." addresses the last two axes, so one statement serves both the
    # single-band (rows, cols) and multi-band (bands, rows, cols) layouts.
    for image, (top, left) in zip(ImgArray, corners):
        result[..., top:top + height_crop, left:left + width_crop] = image

    # Write the stitched raster with the original georeferencing.
    writeTiff(result, geotrans, proj, ResultPath)
if __name__ == '__main__':
    # Stitch the tiles in the folder back together, using the original
    # raster for size and georeferencing; the overlap ratio is 0.
    TifStitch(
        OriTif=r"F:/OriTIF.tif",
        TifArrayPath=r"F:/test/TrainMJ/",
        ResultPath=r"F:/test/merge.tif",
        RepetitionRate=0,
    )