随笔小杂记(三)——将遥感大图随机分割成小图作为训练集
闲谈
这是我师兄给我的代码,以后可能还会用到,所以在这里做个备份~
代码
#coding=utf-8
import numpy as np
import cv2
import tqdm
import os
from tqdm import trange
import random
# Side length of the square tiles cut from each source image.
img_w = img_h = 2048

# File names of the source images.
image_sets = [
    '1_fake.tif',
    '2_fake.tif',
    '3_fake.tif',
    '4_fake.tif',
]

# File names of the images that use MNDWI as the label.
label_sets = [
    '1_label.tif',
    '2_label.tif',
    '3_label.tif',
    '4_label.tif',
]

# Directory the images and labels are read from.
dataroot = 'E:\\data\\SCI(7yue)\\New_dict_src2MNDWI_label(7yue)\\fake_mndwi_label'

# Directory under which the generated tiles are written.
outputpath = 'E:\\data\\SCI(7yue)\\New_dict_src2MNDWI_label(7yue)\\fake_MNDWI_label_slide'
def gamma_transform(img, gamma):
    """Apply gamma correction to an 8-bit image through a lookup table.

    Args:
        img: 8-bit image array to correct.
        gamma: Exponent applied to the intensities after scaling them to
            the 0..1 range.

    Returns:
        The corrected image produced by ``cv2.LUT``.
    """
    # cv2.LUT requires exactly 256 entries for 8-bit input; a longer table is
    # rejected, and entries above 255 would wrap around when cast to uint8.
    levels = np.arange(256) / 255.0
    gamma_table = np.round(np.power(levels, gamma) * 255.0).astype(np.uint8)
    return cv2.LUT(img, gamma_table)
def random_gamma_transform(img, gamma_vari):
    """Gamma-correct ``img`` with a randomly drawn exponent.

    The logarithm of the exponent is drawn uniformly between
    ``-log(gamma_vari)`` and ``log(gamma_vari)``.

    Args:
        img: Image passed on to ``gamma_transform``.
        gamma_vari: Controls how far the exponent may deviate from 1.

    Returns:
        The result of ``gamma_transform`` for the sampled exponent.
    """
    bound = np.log(gamma_vari)
    log_gamma = np.random.uniform(-bound, bound)
    return gamma_transform(img, np.exp(log_gamma))
def rotate(xb, yb, angle):
    """Rotate an image tile and its label tile by the same angle.

    The rotation is about the point ``(img_w / 2, img_h / 2)`` with scale 1,
    and both outputs have size ``(img_w, img_h)``.

    Args:
        xb: Source image tile.
        yb: Label tile belonging to ``xb``.
        angle: Rotation angle handed to ``cv2.getRotationMatrix2D``.

    Returns:
        Tuple ``(xb, yb)`` with the rotated image and label.
    """
    M_rotate = cv2.getRotationMatrix2D((img_w / 2, img_h / 2), angle, 1)
    xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))
    # Labels are class values, not intensities: nearest-neighbour sampling
    # keeps the warp from blending them into values that are not labels.
    yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h), flags=cv2.INTER_NEAREST)
    return xb, yb
def blur(img):
    """Return ``img`` smoothed by ``cv2.blur`` with a 3x3 kernel."""
    return cv2.blur(img, (3, 3))
def add_noise(img, num_points=200):
    """Return a copy of ``img`` with randomly chosen pixels set to 255.

    The input array is left untouched. This matters because callers may pass
    a slice view of a larger image, and writing into that view would also
    alter the larger image.

    Args:
        img: Image array with at least two dimensions (rows, columns).
        num_points: Number of random positions to draw. Positions may repeat,
            so fewer distinct pixels can end up changed.

    Returns:
        A new array of the same shape and dtype as ``img`` with the noise
        points applied.
    """
    noisy = img.copy()
    height, width = noisy.shape[:2]
    if height == 0 or width == 0 or num_points <= 0:
        # Nothing to draw from (or nothing requested).
        return noisy
    rows = np.random.randint(0, height, size=num_points)
    cols = np.random.randint(0, width, size=num_points)
    # Indexing by (row, col) sets every channel of a multi-channel pixel.
    noisy[rows, cols] = 255
    return noisy
def data_augment(xb, yb):
    """Randomly augment an image tile together with its label tile.

    Each step is decided by its own random draw, so any combination of steps
    may be applied. Rotations and the flip act on both tiles; gamma, blur and
    noise act on the image only.

    Args:
        xb: Source image tile.
        yb: Label tile belonging to ``xb``.

    Returns:
        Tuple ``(xb, yb)`` with the augmented image and label.
    """
    # One independent 25% draw per angle, so several rotations can stack.
    for angle in (90, 180, 270):
        if np.random.random() < 0.25:
            xb, yb = rotate(xb, yb, angle)

    if np.random.random() < 0.25:
        # flipCode > 0: flip around the y-axis.
        xb = cv2.flip(xb, 1)
        yb = cv2.flip(yb, 1)

    if np.random.random() < 0.25:
        # NOTE(review): with gamma_vari == 1.0, log(gamma_vari) is 0, so the
        # sampled gamma is always 1 -- confirm the intended variation range.
        xb = random_gamma_transform(xb, 1.0)

    if np.random.random() < 0.25:
        xb = blur(xb)

    if np.random.random() < 0.2:
        xb = add_noise(xb)

    return xb, yb
def creat_dataset(image_num, mode='original'):
    """Cut random tiles from every source image and save them with labels.

    For each entry of ``image_sets`` the image and its label (same index in
    ``label_sets``) are read from ``dataroot``. Random ``img_w`` x ``img_h``
    windows are drawn, and a window is kept only when more than 8% of its
    label pixels equal 0 and more than 8% equal 255. Kept tiles are written
    to ``outputpath`` as ``<n>_src.png`` / ``<n>_label.png``, numbered
    consecutively across all images.

    NOTE: sampling repeats until enough windows pass the label check, so an
    image without any qualifying window keeps the loop running.

    Args:
        image_num: Total number of tiles wanted; each image contributes
            ``image_num / len(image_sets)`` of them.
        mode: ``'augment'`` runs ``data_augment`` on every kept tile; any
            other value saves the tile as cut.

    Raises:
        OSError: If ``cv2.imread`` cannot read an image or label file.
        ValueError: If an image is smaller than one tile.
    """
    print('creating dataset...')
    image_each = image_num / len(image_sets)

    # Create the output directories once instead of re-checking per tile.
    src_dir = outputpath + '\\differentsize\\2048\\src'
    label_dir = outputpath + '\\differentsize\\2048\\label'
    os.makedirs(src_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)

    # A class must cover more than this many pixels for a tile to be kept.
    min_class_pixels = (img_h * img_w) * 0.08

    g_count = 0
    for i in trange(len(image_sets)):
        src_path = dataroot + '\\' + image_sets[i]
        label_path = dataroot + '\\' + label_sets[i]
        src_img = cv2.imread(src_path)  # 3 channels
        label_img = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)  # single channel
        # cv2.imread signals failure by returning None rather than raising.
        if src_img is None:
            raise OSError('cv2.imread could not read image: ' + src_path)
        if label_img is None:
            raise OSError('cv2.imread could not read label: ' + label_path)

        X_height, X_width = src_img.shape[:2]
        if X_height < img_h or X_width < img_w:
            raise ValueError(
                'image %s is %dx%d, smaller than the %dx%d tile size'
                % (src_path, X_width, X_height, img_w, img_h))

        count = 0
        while count < image_each:
            # random.randint includes both ends, so the largest valid offset
            # is exactly (image size - tile size).
            random_width = random.randint(0, X_width - img_w)
            random_height = random.randint(0, X_height - img_h)
            src_roi = src_img[random_height:random_height + img_h,
                              random_width:random_width + img_w, :]
            label_roi = label_img[random_height:random_height + img_h,
                                  random_width:random_width + img_w]

            # Label check: discard tiles that are (almost) all water or
            # (almost) all non-water.
            if not (np.sum(label_roi == 0) > min_class_pixels
                    and np.sum(label_roi == 255) > min_class_pixels):
                continue

            if mode == 'augment':
                # The crops are views into the full images; augment copies so
                # in-place steps cannot alter pixels used by later crops.
                src_roi, label_roi = data_augment(src_roi.copy(),
                                                  label_roi.copy())

            # Generic file naming.
            src_file = src_dir + '\\' + '%d_src.png' % g_count
            label_file = label_dir + '\\' + '%d_label.png' % g_count
            cv2.imwrite(src_file, src_roi)
            cv2.imwrite(label_file, label_roi)
            print('saving:', src_file)
            count += 1
            g_count += 1
def main():
    """Build 800 tiles in ``'original'`` mode via ``creat_dataset``."""
    creat_dataset(mode='original', image_num=800)
if __name__ == "__main__":
    # Print the input directory, then generate the tiles.
    print(dataroot)
    main()