# -*- coding: utf-8 -*-
import os, json
import cv2
from osgeo import gdal
import numpy as np
from osgeo import ogr, gdal, osr
from shapely.geometry import box, shape
from shapely.geometry.polygon import Polygon
import collections
import datetime
import geopandas as gpd
import shutil
import glob
from PIL import Image, ImageDraw
from pycococreatortools import pycococreatortools
def read_img(filename):
    """Open a raster with GDAL and read all of its bands into memory.

    Args:
        filename: path to a raster GDAL can open.

    Returns:
        Tuple ``(im_width, im_height, im_proj, im_geotrans, im_data, dataset)``:
        the raster size in pixels, the projection string, the 6-element
        geotransform, the full pixel array from ``ReadAsArray`` (per GDAL this is
        ``(rows, cols)`` for one band and ``(bands, rows, cols)`` otherwise), and
        the still-open GDAL dataset. The caller owns ``dataset``; dropping the
        last reference closes the file.

    Raises:
        RuntimeError: if GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning None.
        raise RuntimeError("GDAL could not open raster: {}".format(filename))
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    return im_width, im_height, im_proj, im_geotrans, im_data, dataset
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a NumPy array to a GeoTIFF with the given georeferencing.

    Args:
        filename: output GeoTIFF path.
        im_proj: projection string passed to ``SetProjection`` (e.g. from ``read_img``).
        im_geotrans: 6-element geotransform passed to ``SetGeoTransform``.
        im_data: 2-D ``(rows, cols)`` array for a single band, or 3-D
            ``(bands, rows, cols)`` array for one or more bands.

    The output band type is chosen from the array dtype: 8-bit integers -> Byte,
    int16 -> Int16, uint16 -> UInt16, anything else -> Float32 (so int32 or
    float64 data may lose precision).

    Raises:
        RuntimeError: if the GTiff driver cannot create ``filename``.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        # Covers uint8 and int8; negative int8 values will wrap in a Byte band.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        # Signed data needs a signed band type; UInt16 would wrap negative values.
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
        band_arrays = im_data  # iterating axis 0 yields one 2-D array per band
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        band_arrays = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create raster: {}".format(filename))
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    # Always write 2-D slices; this also handles a (1, rows, cols) array correctly.
    for band_index, band_array in enumerate(band_arrays, start=1):
        dataset.GetRasterBand(band_index).WriteArray(band_array)
    # Flush and release the last reference so GDAL closes the file now.
    dataset.FlushCache()
    dataset = None
def get_boundary_points(geom, geo_transform, x_res, y_res):
    """Project the vertices of an OGR geometry from map coordinates into pixel space.

    Only the origin (``geo_transform[0]``, ``geo_transform[3]``) and the given
    resolutions are used; the rotation terms ``geo_transform[2]`` and
    ``geo_transform[4]`` are ignored, so this assumes a north-up raster.

    Args:
        geom: geometry exposing ``GetPointCount``/``GetX``/``GetY`` (e.g. an OGR
            linear ring).
        geo_transform: GDAL 6-element geotransform of the target image.
        x_res: pixel width in map units (normally ``geo_transform[1]``).
        y_res: pixel height in map units (normally ``geo_transform[5]``).

    Returns:
        ``(x_pixels, y_pixels, pixels, polygonpoints)``, all in vertex order:
        column list, row list, ``[col, row]`` lists, and ``(col, row)`` tuples.
        Every coordinate is rounded to 4 decimal places.
    """
    # Collect every vertex in map coordinates first.
    world_xy = [(geom.GetX(idx), geom.GetY(idx)) for idx in range(geom.GetPointCount())]

    x_pixels, y_pixels, pixels, polygonpoints = [], [], [], []
    for wx, wy in world_xy:
        col = round(float((wx - geo_transform[0]) / x_res), 4)
        row = round(float((wy - geo_transform[3]) / y_res), 4)
        x_pixels.append(col)
        y_pixels.append(row)
        pixels.append([col, row])
        polygonpoints.append((col, row))
    return x_pixels, y_pixels, pixels, polygonpoints
# Normalization: pixel bounding box -> YOLO relative (x_center, y_center, w, h).
def convert(size, box):
    """Normalize a pixel bounding box to YOLO ``(x_center, y_center, width, height)``.

    Args:
        size: ``(image_width, image_height)`` in pixels.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels. Inside this function the
            name shadows the module-level ``shapely.geometry.box`` import.

    Returns:
        Tuple ``(x, y, w, h)``: box center and size as fractions of the image
        width/height, each rounded to 4 decimal places. The values lie in [0, 1]
        only if the box lies inside the image; no clipping is done here.

    Raises:
        ZeroDivisionError: if either image dimension is 0.
    """
    # Keep multiply-by-reciprocal; switching to plain division can change the
    # last bit of a result before rounding.
    inv_w = 1. / size[0]
    inv_h = 1. / size[1]
    x_min, x_max, y_min, y_max = box[0], box[1], box[2], box[3]

    center_x = (x_min + x_max) / 2.0 * inv_w   # center x / image width
    center_y = (y_min + y_max) / 2.0 * inv_h   # center y / image height
    rel_w = (x_max - x_min) * inv_w            # box width / image width
    rel_h = (y_max - y_min) * inv_h            # box height / image height

    return tuple(round(value, 4) for value in (center_x, center_y, rel_w, rel_h))
def data2YoloAndCoco(shapefile_path, tif_path, yolo_txt_path, class_id=0):
    """Write a YOLO label file for one image tile from its matching shapefile.

    Despite the name, only YOLO labels are produced here; nothing COCO-related
    is written.

    Each feature's exterior ring is projected into pixel space with the TIFF's
    geotransform. Its axis-aligned bounding box is clipped to the image and
    normalized with ``convert``. One line ``"<class_id> <x> <y> <w> <h>"`` is
    then written to ``<yolo_txt_path>/<tif basename without extension>.txt``.
    The label file is always created, so a tile with no usable features gets
    an empty label file.

    Args:
        shapefile_path: vector label file, opened with ``ogr.Open``.
        tif_path: matching raster tile, opened with ``gdal.Open``; only its
            geotransform and size are used.
        yolo_txt_path: output directory for the label file (created if missing).
        class_id: class index written at the start of every line (default 0).

    Returns:
        None. If either input cannot be opened, a message is printed and
        nothing is written.
    """
    # "<dir>/abc.tif" -> "abc"
    name = os.path.splitext(os.path.basename(tif_path))[0]

    # ogr.Open returns None on failure, so check it before calling GetLayer().
    shapefile_ds = ogr.Open(shapefile_path)
    if shapefile_ds is None:
        print("无法打开Shapefile文件")
        return
    shapefile_layer = shapefile_ds.GetLayer()

    tif_ds = gdal.Open(tif_path)
    if tif_ds is None:
        print("无法打开TIFF文件")
        return
    geo_transform = tif_ds.GetGeoTransform()
    x_res = geo_transform[1]  # pixel width in map units
    y_res = geo_transform[5]  # pixel height in map units (negative for north-up rasters)
    width = tif_ds.RasterXSize
    height = tif_ds.RasterYSize
    tif_ds = None  # nothing else is read from the raster; close it

    os.makedirs(yolo_txt_path, exist_ok=True)
    yolo_label_path = os.path.join(yolo_txt_path, name + ".txt")
    with open(yolo_label_path, 'w') as txt:
        for feature in shapefile_layer:
            geometry = feature.GetGeometryRef()
            if geometry is None:  # feature with a null geometry
                continue
            # For a Polygon, sub-geometry 0 is the exterior ring; holes cannot
            # enlarge the bbox. NOTE(review): for a MultiPolygon this returns the
            # first polygon rather than a ring, which presumably yields no points,
            # so the feature is skipped. Confirm that this is intended.
            ring = geometry.GetGeometryRef(0)
            if ring is None:  # empty polygon
                continue
            x_pixels, y_pixels, _, _ = get_boundary_points(ring, geo_transform, x_res, y_res)
            if not x_pixels:
                continue
            # Clip to the image so normalized YOLO values stay within [0, 1].
            minx = min(max(min(x_pixels), 0.0), width)
            maxx = min(max(max(x_pixels), 0.0), width)
            miny = min(max(min(y_pixels), 0.0), height)
            maxy = min(max(max(y_pixels), 0.0), height)
            if maxx <= minx or maxy <= miny:
                # Entirely outside the tile, or zero-area after clipping.
                continue
            bb = convert((width, height), [minx, maxx, miny, maxy])  # (x, y, w, h)
            txt.write(str(class_id) + " " + " ".join(str(a) for a in bb) + '\n')

    shapefile_ds = None  # release the OGR data source after the layer is consumed
if __name__ == "__main__":
    # Expected layout: <root_shpf_folder>/<site>/<region>/<tile>.shp, with the
    # matching raster at <root_tiff_folder>/<site>/<region>/<tile>.tif.
    root_tiff_folder = './convertdata/clip_image_train/'
    root_shpf_folder = './convertdata/clip_shp_train/'
    yolo_txt_path = './convertdata/output_yolo_train/'
    os.makedirs(yolo_txt_path, exist_ok=True)

    # NOTE(review): all labels go into one flat folder keyed by tile name, so
    # tiles with the same name in different site/region folders overwrite each
    # other's labels. Confirm that tile names are globally unique.
    for sitname in sorted(os.listdir(root_shpf_folder)):
        site_dir = os.path.join(root_shpf_folder, sitname)
        if not os.path.isdir(site_dir):  # skip stray files at the root level
            continue
        for regionn in sorted(os.listdir(site_dir)):
            shpf_folder = os.path.join(site_dir, regionn)
            if not os.path.isdir(shpf_folder):
                continue
            tiff_folder = os.path.join(root_tiff_folder, sitname, regionn)
            for shpfile_path in sorted(glob.glob(os.path.join(shpf_folder, '*.shp'))):
                print('processing shpfile_path:', shpfile_path)
                shpfile_name = os.path.splitext(os.path.basename(shpfile_path))[0]
                tiffile_path = os.path.join(tiff_folder, shpfile_name + '.tif')
                print('processing tiffile_path:', tiffile_path)
                if not os.path.isfile(tiffile_path):
                    print('skipping, matching tif not found:', tiffile_path)
                    continue
                data2YoloAndCoco(shpfile_path, tiffile_path, yolo_txt_path)