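Remote sensing images are usually very large, so feeding a whole scene into the model for prediction runs out of memory. The usual workflow is:
crop into small tiles → predict each tile → stitch the predictions back together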
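The middle "predict" step can be sketched as below. This is a minimal sketch, not the author's code. It assumes a hypothetical `predict_fn` that maps a batch of shape `(batch, H, W, bands)` to per-pixel class scores of shape `(batch, H, W, n_classes)`, which is the kind of callable a Keras model's `predict` method provides. Tiles are fed in mini-batches so memory use stays bounded.

```python
import numpy as np


def predict_tiles(tiles, predict_fn, batch_size=16):
    """Run a (hypothetical) predict_fn over equally sized tiles in mini-batches.

    Returns one (H, W) label map per tile, in the same order as `tiles`.
    """
    predictions = []
    for start in range(0, len(tiles), batch_size):
        # Stack a slice of tiles into one batch so memory use stays bounded
        batch = np.stack(tiles[start:start + batch_size]).astype(np.float32)
        scores = predict_fn(batch)
        # Per-pixel class label = index of the highest score
        predictions.extend(np.argmax(scores, axis=-1))
    return predictions


def dummy_model(x):
    """Stand-in model: class 1 wherever band 0 exceeds 0.5, else class 0."""
    return np.stack([0.5 - x[..., 0], x[..., 0] - 0.5], axis=-1)


tiles = [np.random.rand(256, 256, 3) for _ in range(5)]
labels = predict_tiles(tiles, dummy_model, batch_size=2)
print(len(labels), labels[0].shape)  # 5 (256, 256)
```

The output order matches the input order, which matters because the stitch step below locates each tile purely by its position in the list.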
1. Crop and stitch functions
Taking the sample image as an example, a 4240 × 3120 image is cropped into 256 × 256 tiles:
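As a quick worked example of the tile grid, the sketch below applies the same `(size // tile + 1) * tile` rounding rule that `crop_image` uses. It assumes 4240 is the width and 3120 the height, since the post does not say which is which. `tile_grid` is a hypothetical helper name.

```python
def tile_grid(width, height, tile=256):
    """Tile counts and padded canvas size under the (size // tile + 1) * tile rule."""
    n_x = width // tile + 1
    n_y = height // tile + 1
    return n_x, n_y, n_x * tile, n_y * tile


n_x, n_y, padded_w, padded_h = tile_grid(4240, 3120)
print(n_x, n_y, n_x * n_y)  # 17 13 221
print(padded_w, padded_h)   # 4352 3328
```

When a side is an exact multiple of the tile size, this rule still adds one full row or column of empty tiles. `math.ceil(width / tile)` would avoid that, but the stitch function would then have to use the same rule.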
import os

import numpy as np
import tifffile as tf  # tf.imread / tf.imwrite come from tifffile, not TensorFlow

# Crop function
def crop_image(image_path, output_path, crop_size=(256, 256), background_color=0):
    image = tf.imread(image_path)  # expected shape: (height, width, bands)
    height, width, bands = image.shape
    # Round the canvas up to a multiple of the tile size; crop_size = (width, height)
    new_width = (width // crop_size[0] + 1) * crop_size[0]
    new_height = (height // crop_size[1] + 1) * crop_size[1]
    # Fill the border with background_color and keep the source dtype
    new_image = np.full((new_height, new_width, bands), background_color, dtype=image.dtype)
    new_image[:height, :width, :] = image
    cropped_images = []
    # Row-major order: the stitch function relies on this ordering
    for i in range(0, new_height, crop_size[1]):
        for j in range(0, new_width, crop_size[0]):
            cropped_image = new_image[i:i + crop_size[1], j:j + crop_size[0], :]
            cropped_images.append(cropped_image)
            tf.imwrite(os.path.join(output_path, f"cropped_{i}_{j}.tif"), cropped_image)
    return cropped_images

# Stitch function: predicted_images are single-band label maps in crop_image order
def stitch_images(predicted_images, original_width, original_height, crop_size=(256, 256)):
    new_width = (original_width // crop_size[0] + 1) * crop_size[0]
    new_height = (original_height // crop_size[1] + 1) * crop_size[1]
    stitched_image = np.zeros((new_height, new_width))
    tiles_per_row = new_width // crop_size[0]
    for i_index, i in enumerate(range(0, new_height, crop_size[1])):
        for j_index, j in enumerate(range(0, new_width, crop_size[0])):
            predicted_image = predicted_images[i_index * tiles_per_row + j_index]
            stitched_image[i:i + predicted_image.shape[0], j:j + predicted_image.shape[1]] = predicted_image
    # Drop the padded border
    return stitched_image[:original_height, :original_width]
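A cheap sanity check is a round trip with an identity "model": if each prediction is just band 0 of its tile, stitching must return band 0 of the original image exactly. The sketch below mirrors the crop/stitch logic above but works on in-memory arrays instead of files. `crop_array` and `stitch_array` are hypothetical names, not part of the original code.

```python
import numpy as np


def crop_array(image, tile=256):
    """In-memory variant of crop_image: zero-pad to a multiple of tile, cut row-major."""
    height, width = image.shape[:2]
    new_h = (height // tile + 1) * tile
    new_w = (width // tile + 1) * tile
    canvas = np.zeros((new_h, new_w) + image.shape[2:], dtype=image.dtype)
    canvas[:height, :width] = image
    return [canvas[i:i + tile, j:j + tile]
            for i in range(0, new_h, tile)
            for j in range(0, new_w, tile)]


def stitch_array(tiles, width, height, tile=256):
    """Inverse of crop_array for single-band predictions."""
    n_x = width // tile + 1
    n_y = height // tile + 1
    out = np.zeros((n_y * tile, n_x * tile), dtype=tiles[0].dtype)
    for k, t in enumerate(tiles):
        i, j = divmod(k, n_x)  # row-major order, same as the crop loops
        out[i * tile:(i + 1) * tile, j * tile:(j + 1) * tile] = t
    return out[:height, :width]


rng = np.random.default_rng(0)
img = rng.integers(0, 255, size=(700, 900, 3), dtype=np.uint8)
pred = [t[..., 0] for t in crop_array(img)]  # identity "model" on band 0
restored = stitch_array(pred, width=900, height=700)
print(np.array_equal(restored, img[..., 0]))  # True
```

A round trip like this catches ordering and off-by-one mistakes, but not seams: an identity model has no edge effects.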
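However, this approach can leave visible seams along the tile boundaries:
Seams like the ones in the figure can be removed with padding. The idea I followed is:
Use an overlapping tiling strategy: predict block by block, slide the blocks with a fixed stride so that neighbouring blocks overlap, and keep only the central part of each block's prediction. This avoids the seams caused by inaccurate predictions near block edges.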
Reference: 遥感深度学习预测过程中接边问题解决 (solving edge seams in remote sensing deep learning prediction), 能量鸣新's blog on CSDN
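To see why keeping only the centre works, the toy sketch below uses a 3 × 3 mean filter as a stand-in for the model. This is an assumption for illustration only: like a CNN, its output near a tile edge depends on pixels the tile does not contain. Predicting non-overlapping tiles changes the result along every tile border. Predicting padded tiles and keeping only the core reproduces the whole-image result exactly, once the padding covers the filter's 1-pixel radius. `mean3x3`, `naive_tiled` and `overlap_tiled` are hypothetical names.

```python
import numpy as np


def mean3x3(x):
    """Stand-in 'model': 3x3 mean filter with zero padding at the array border."""
    p = np.pad(x, 1)  # constant (zero) padding by default
    h, w = x.shape
    return sum(p[di:di + h, dj:dj + w] for di in range(3) for dj in range(3)) / 9.0


def naive_tiled(x, tile):
    """Predict each non-overlapping tile independently (sizes assumed to divide evenly)."""
    out = np.zeros_like(x)
    for i in range(0, x.shape[0], tile):
        for j in range(0, x.shape[1], tile):
            out[i:i + tile, j:j + tile] = mean3x3(x[i:i + tile, j:j + tile])
    return out


def overlap_tiled(x, core, pad):
    """Predict (core + 2*pad) tiles on a zero-padded canvas and keep only the core."""
    canvas = np.pad(x, pad)
    out = np.zeros_like(x)
    for i in range(0, x.shape[0], core):
        for j in range(0, x.shape[1], core):
            pred = mean3x3(canvas[i:i + core + 2 * pad, j:j + core + 2 * pad])
            out[i:i + core, j:j + core] = pred[pad:pad + core, pad:pad + core]
    return out


x = np.random.default_rng(0).random((64, 64))
full = mean3x3(x)  # the "model" applied to the whole image at once
print(np.allclose(naive_tiled(x, 16), full))           # False: seams at tile borders
print(np.allclose(overlap_tiled(x, 16, pad=1), full))  # True
```

A real network's receptive field can be far wider than the 32-pixel padding used below, so with a CNN the padding reduces edge effects rather than provably eliminating them.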
2. Removing the seams during crop and stitch
Modify the crop and stitch functions to remove the seams:
import os

import numpy as np
import tifffile as tf

def crop_image(image_path, output_path, crop_size=(256, 256), padding=32, background_color=0):
    image = tf.imread(image_path)
    height, width, bands = image.shape
    # Tiles are (crop_size + 2*padding) on a side and advance by crop_size - padding,
    # so neighbouring tiles overlap by 3*padding pixels
    step_x, step_y = crop_size[0] - padding, crop_size[1] - padding
    n_x, n_y = width // step_x + 1, height // step_y + 1
    # Size the canvas so the last row/column of tiles is not truncated
    new_width = (n_x - 1) * step_x + crop_size[0] + padding * 2
    new_height = (n_y - 1) * step_y + crop_size[1] + padding * 2
    new_image = np.full((new_height, new_width, bands), background_color, dtype=image.dtype)
    # Offset the image by padding so the discarded margin of the first tiles is background
    new_image[padding:padding + height, padding:padding + width, :] = image
    cropped_images = []
    for i in range(0, n_y * step_y, step_y):
        for j in range(0, n_x * step_x, step_x):
            cropped_image = new_image[i:i + crop_size[1] + padding * 2, j:j + crop_size[0] + padding * 2, :]
            cropped_images.append(cropped_image)
            # tf.imwrite(os.path.join(output_path, f"cropped_{i}_{j}.tif"), cropped_image)
    return cropped_images

def stitch_images(predicted_images, original_width, original_height, padding=32, crop_size=(256, 256)):
    step_x, step_y = crop_size[0] - padding, crop_size[1] - padding
    n_x, n_y = original_width // step_x + 1, original_height // step_y + 1
    stitched_image = np.zeros(((n_y - 1) * step_y + crop_size[1], (n_x - 1) * step_x + crop_size[0]))
    for row in range(n_y):
        for col in range(n_x):
            predicted_image = predicted_images[row * n_x + col]
            i, j = row * step_y, col * step_x
            # Keep only the central crop_size block; the next tile overwrites the overlap
            stitched_image[i:i + crop_size[1], j:j + crop_size[0]] = \
                predicted_image[padding:padding + crop_size[1], padding:padding + crop_size[0]]
    return stitched_image[:original_height, :original_width]
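The overlap has a cost. The sketch below computes the grid this version produces for the sample image, again assuming 4240 is the width and 3120 the height. `overlap_grid` is a hypothetical helper that only restates the step and tile-size arithmetic of the functions above.

```python
def overlap_grid(width, height, crop=256, padding=32):
    """Step = crop - padding; each tile is (crop + 2*padding) pixels on a side."""
    step = crop - padding
    n_x = width // step + 1
    n_y = height // step + 1
    tile = crop + 2 * padding
    return step, tile, n_x, n_y


step, tile, n_x, n_y = overlap_grid(4240, 3120)
print(step, tile, n_x, n_y, n_x * n_y)  # 224 320 19 14 266
# Pixels the model processes, relative to the pixels in the image itself
print(round(n_x * n_y * tile * tile / (4240 * 3120), 2))  # 2.06
```

In other words, the model processes roughly twice as many pixels as the image contains. That is the price of the 3 × padding overlap between neighbouring tiles.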
Alternatively, version 2 advances the tiles by the full crop_size, so the kept centres sit edge to edge:
import os

import numpy as np
import tifffile as tf

def crop_image(image_path, output_path, crop_size=(256, 256), padding=32, background_color=0):
    image = tf.imread(image_path)
    height, width, bands = image.shape
    # Tiles are (crop_size + 2*padding) on a side and advance by crop_size
    new_width = (width // crop_size[0] + 1) * crop_size[0] + padding * 2
    new_height = (height // crop_size[1] + 1) * crop_size[1] + padding * 2
    new_image = np.full((new_height, new_width, bands), background_color, dtype=image.dtype)
    new_image[padding:padding + height, padding:padding + width, :] = image
    cropped_images = []
    # Rows advance by the tile height crop_size[1], columns by the tile width crop_size[0]
    for i in range(padding, new_height - padding, crop_size[1]):
        for j in range(padding, new_width - padding, crop_size[0]):
            cropped_image = new_image[i - padding:i + crop_size[1] + padding, j - padding:j + crop_size[0] + padding, :]
            cropped_images.append(cropped_image)
            # tf.imwrite(os.path.join(output_path, f"cropped_{i}_{j}.tif"), cropped_image)
    return cropped_images

def stitch_images(predicted_images, original_width, original_height, padding=32, crop_size=(256, 256)):
    n_x = original_width // crop_size[0] + 1
    n_y = original_height // crop_size[1] + 1
    stitched_image = np.zeros((n_y * crop_size[1], n_x * crop_size[0]))
    for row in range(n_y):
        for col in range(n_x):
            predicted_image = predicted_images[row * n_x + col]
            i, j = row * crop_size[1], col * crop_size[0]
            # Discard the padding margin and place the central block at image coordinates
            stitched_image[i:i + crop_size[1], j:j + crop_size[0]] = \
                predicted_image[padding:padding + crop_size[1], padding:padding + crop_size[0]]
    # Centres are already in image coordinates, so only the bottom/right overhang is cut
    return stitched_image[:original_height, :original_width]
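The same identity-model round trip works for version 2. The sketch below reimplements its geometry on in-memory arrays: tiles of crop + 2 × padding, stride crop, with the central crop × crop block kept. It checks that the output has exactly the original size and content. `crop_array_v2` and `stitch_array_v2` are hypothetical names.

```python
import numpy as np


def crop_array_v2(image, crop=256, padding=32):
    """In-memory version-2 crop: tiles of crop + 2*padding, stride crop, row-major."""
    height, width = image.shape[:2]
    n_x, n_y = width // crop + 1, height // crop + 1
    canvas = np.zeros((n_y * crop + 2 * padding, n_x * crop + 2 * padding) + image.shape[2:],
                      dtype=image.dtype)
    canvas[padding:padding + height, padding:padding + width] = image
    size = crop + 2 * padding
    return [canvas[r * crop:r * crop + size, c * crop:c * crop + size]
            for r in range(n_y) for c in range(n_x)]


def stitch_array_v2(tiles, width, height, crop=256, padding=32):
    """Keep the central crop x crop block of each tile and lay the blocks edge to edge."""
    n_x, n_y = width // crop + 1, height // crop + 1
    out = np.zeros((n_y * crop, n_x * crop), dtype=tiles[0].dtype)
    for k, t in enumerate(tiles):
        r, c = divmod(k, n_x)  # row-major, matching crop_array_v2
        out[r * crop:(r + 1) * crop, c * crop:(c + 1) * crop] = \
            t[padding:padding + crop, padding:padding + crop]
    return out[:height, :width]


img = np.random.default_rng(1).integers(0, 255, size=(700, 900, 3), dtype=np.uint8)
tiles = crop_array_v2(img)
pred = [t[..., 0] for t in tiles]  # identity "model" on band 0
restored = stitch_array_v2(pred, width=900, height=700)
print(len(tiles), tiles[0].shape)             # 12 (320, 320, 3)
print(np.array_equal(restored, img[..., 0]))  # True
```

Compared with the first padded version, this layout needs fewer tiles (17 × 13 = 221 instead of 19 × 14 = 266 for the sample image, under the same width/height assumption). Each kept block still sits `padding` pixels away from its tile's edges.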
The seams are removed!