随机森林遥感图像语义分割

数据：见文末参考链接

import numpy as np
import os
from osgeo import gdal
from sklearn.ensemble import RandomForestClassifier
from matplotlib import pyplot as plt
from sklearn import metrics


def create_mask_from_vector(vector_data_path, cols, rows, geo_transform, projection, target_value=1):
    """Rasterize the first layer of a vector file onto an in-memory grid.

    Pixels rasterized from the layer's features (GDAL's default rule) are set
    to ``target_value``. The caller treats every other pixel as 0. That
    assumes a freshly created MEM dataset is zero-filled.

    Args:
        vector_data_path: path of a vector dataset GDAL can open
            (e.g. an ESRI Shapefile).
        cols: output grid width in pixels.
        rows: output grid height in pixels.
        geo_transform: GDAL affine geotransform applied to the output grid.
        projection: projection (WKT) applied to the output grid.
        target_value: value burned into the rasterized pixels.

    Returns:
        A single-band ``GDT_UInt16`` dataset created with the ``MEM`` driver.

    Raises:
        RuntimeError: if the vector file cannot be opened, has no layer, or
            rasterization reports an error.
    """
    data_source = gdal.OpenEx(vector_data_path, gdal.OF_VECTOR)
    if data_source is None:
        # Without gdal.UseExceptions(), OpenEx signals failure by returning
        # None; fail here with the path instead of an AttributeError below.
        raise RuntimeError("Could not open vector file: %s" % vector_data_path)
    layer = data_source.GetLayer(0)
    if layer is None:
        raise RuntimeError("Vector file has no layer: %s" % vector_data_path)
    driver = gdal.GetDriverByName('MEM')
    target_ds = driver.Create('', cols, rows, 1, gdal.GDT_UInt16)
    target_ds.SetGeoTransform(geo_transform)
    target_ds.SetProjection(projection)
    err = gdal.RasterizeLayer(target_ds, [1], layer, burn_values=[target_value])
    if err != gdal.CE_None:
        raise RuntimeError("Rasterization failed for: %s" % vector_data_path)
    # Drop the layer before its owning dataset, then release the vector file.
    layer = None
    data_source = None
    return target_ds

def vector_to_raster(file_paths, rows, cols, geo_transform, projection):
    """Build a label image from one vector file per class.

    The i-th path (0-based) is rasterized with label ``i + 1``. Pixels not
    covered by any file stay 0, meaning unlabelled.

    Where polygons of different files overlap, the file that comes later in
    ``file_paths`` wins. The masks are not summed: summing would turn an
    overlap of classes 1 and 2 into the label of an unrelated class 3.

    Args:
        file_paths: vector file paths, one per class, in label order.
        rows: output grid height in pixels.
        cols: output grid width in pixels.
        geo_transform: GDAL geotransform of the output grid.
        projection: projection (WKT) of the output grid.

    Returns:
        A float array of shape (rows, cols) holding the labels.
    """
    labeled_pixels = np.zeros((rows, cols))
    for label, path in enumerate(file_paths, start=1):
        ds = create_mask_from_vector(path, cols, rows, geo_transform, projection, target_value=label)
        mask = ds.GetRasterBand(1).ReadAsArray()
        ds = None  # free the in-memory raster before the next class
        labeled_pixels[mask != 0] = label
    return labeled_pixels

def write_geotiff(fname, data, geo_transform, projection, data_type=gdal.GDT_Byte):
    """Write a 2-D array to a single-band GeoTIFF.

    Args:
        fname: output file path.
        data: 2-D array of shape (rows, cols).
        geo_transform: GDAL geotransform copied onto the output.
        projection: projection (WKT) copied onto the output.
        data_type: GDAL pixel type of the output band. The default,
            ``gdal.GDT_Byte``, only holds 0-255. Pass e.g. ``gdal.GDT_UInt16``
            when ``data`` contains larger values.

    Raises:
        RuntimeError: if the GTiff driver cannot create ``fname``.
    """
    driver = gdal.GetDriverByName('GTiff')
    rows, cols = data.shape
    dataset = driver.Create(fname, cols, rows, 1, data_type)
    if dataset is None:
        # Create() returns None (no exception by default) on failure,
        # e.g. an unwritable directory.
        raise RuntimeError("Could not create GeoTIFF: %s" % fname)
    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(projection)
    band = dataset.GetRasterBand(1)
    band.WriteArray(data)
    band.FlushCache()
    band = None
    dataset = None  # dereferencing the dataset closes the file

# Input/output locations (adjust to your machine).
raster_data_path = "E:/python/遥感影像/data/image/2298119ene2016recorteTT.tif"
output_fname = "classification.tif"
train_data_path = "E:/python/遥感影像/data/train/"
validation_data_path = "E:/python/遥感影像/data/test/"

# Read every band of the input image and stack them into a
# (rows, cols, n_bands) array.
raster_dataset = gdal.Open(raster_data_path, gdal.GA_ReadOnly)
if raster_dataset is None:
    # gdal.Open returns None on failure unless exceptions are enabled.
    raise RuntimeError("Could not open raster: %s" % raster_data_path)
geo_transform = raster_dataset.GetGeoTransform()
proj = raster_dataset.GetProjectionRef()
bands_data = []
for b in range(1, raster_dataset.RasterCount+1):
    band = raster_dataset.GetRasterBand(b)
    bands_data.append(band.ReadAsArray())

bands_data = np.dstack(bands_data)
rows, cols, n_bands = bands_data.shape

# One training shapefile per class; the file's base name is the class name.
# Sorting makes the class -> label mapping deterministic, because os.listdir
# order is arbitrary.
files = sorted(f for f in os.listdir(train_data_path) if f.endswith('.shp'))
if not files:
    raise RuntimeError("No .shp training files found in %s" % train_data_path)
# splitext keeps inner dots ("a.b.shp" -> "a.b"), unlike split('.')[0].
classes = [os.path.splitext(f)[0] for f in files]
shapefiles = [os.path.join(train_data_path, f) for f in files]

# Labelled pixels carry their class label (1..n); 0 means unlabelled.
labeled_pixels = vector_to_raster(shapefiles, rows, cols, geo_transform, proj)
is_train = np.nonzero(labeled_pixels)
training_labels = labeled_pixels[is_train]
training_samples = bands_data[is_train]  # one row of band values per labelled pixel
if training_labels.size == 0:
    raise RuntimeError("Training shapefiles do not cover any pixel of the raster")

classifier = RandomForestClassifier(n_jobs=-1)
classifier.fit(training_samples, training_labels)

# Classify every pixel: flatten to (n_samples, n_bands), predict, then fold
# the predictions back into the image grid.
n_samples = rows * cols  # one sample per pixel
flat_pixels = np.reshape(bands_data, (n_samples, n_bands))
result = classifier.predict(flat_pixels)
classification = np.reshape(result, (rows, cols))

# Side-by-side preview. The left panel shows band indices 3/2/1 (0-based) as
# R/G/B; the right panel shows the predicted class map.
f = plt.figure()
r, g, b = (bands_data[:, :, idx] for idx in (3, 2, 1))
rgb = np.dstack([r, g, b])
# NOTE(review): dividing by 255 assumes 8-bit pixel values; imshow clips
# float RGB outside [0, 1], so higher bit depths will look saturated.
for panel, image in ((1, rgb / 255), (2, classification)):
    f.add_subplot(1, 2, panel)
    plt.imshow(image)

write_geotiff(output_fname, classification, geo_transform, proj)

# Accuracy assessment against independent validation polygons. The validation
# folder must hold one "<class>.shp" per training class so that labels line up.
shapefiles = [os.path.join(validation_data_path, "%s.shp" % c)
              for c in classes]
verification_pixels = vector_to_raster(shapefiles, rows, cols, geo_transform, proj)
for_verification = np.nonzero(verification_pixels)
verification_labels = verification_pixels[for_verification]
predicted_labels = classification[for_verification]
# Pin the label set to 1..n. This keeps matrix rows/columns and target_names
# aligned even when a class is absent from the validation or predicted labels.
# Without it, classification_report raises on a target_names length mismatch.
class_labels = list(range(1, len(classes) + 1))
print("Confussion matrix:\n%s" % metrics.confusion_matrix(verification_labels, predicted_labels, labels=class_labels))
target_names = ['Class %s' % s for s in classes]
print("Classification report:\n%s" % metrics.classification_report(verification_labels, predicted_labels, labels=class_labels, target_names=target_names))
print("Classification accuracy: %f" % metrics.accuracy_score(verification_labels, predicted_labels))

# Show the preview figure; a plain script run would otherwise never display it.
plt.show()

效果图:
（原文此处为效果图，图片未能保留：左图为第 4/3/2 波段合成的 RGB 影像，右图为随机森林分类结果。）
分类报告:

Confussion matrix:
[[ 65   0   0   0]
 [  1  87   0   0]
 [  0   0 180   0]
 [  0   0   0 160]]
Classification report:
              precision    recall  f1-score   support

     Class B       0.98      1.00      0.99        65
     Class C       1.00      0.99      0.99        88
     Class D       1.00      1.00      1.00       180
     Class E       1.00      1.00      1.00       160

    accuracy                           1.00       493
   macro avg       1.00      1.00      1.00       493
weighted avg       1.00      1.00      1.00       493

Classification accuracy: 0.997972

参考:https://zhuanlan.zhihu.com/p/73793438

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值