数据:见文末参考链接。
import numpy as np
import os
from osgeo import gdal
from sklearn.ensemble import RandomForestClassifier
from matplotlib import pyplot as plt
from sklearn import metrics
def create_mask_from_vector(vector_data_path, cols, rows, geo_transform, projection, target_value=1):
    """Rasterize the first layer of a vector file into an in-memory raster.

    Args:
        vector_data_path: Path of the vector file (e.g. a shapefile) to open.
        cols: Width of the output raster in pixels.
        rows: Height of the output raster in pixels.
        geo_transform: GDAL geotransform assigned to the output raster.
        projection: Projection assigned to the output raster.
        target_value: Value burned into every pixel covered by a feature.

    Returns:
        A single-band UInt16 dataset created with GDAL's 'MEM' driver.
        Pixels not covered by any feature keep their initial value
        (presumably 0 for a fresh MEM dataset -- confirm against GDAL docs).

    Raises:
        RuntimeError: If the vector file cannot be opened or has no layer.
    """
    data_source = gdal.OpenEx(vector_data_path, gdal.OF_VECTOR)
    # Unless gdal.UseExceptions() is active, a failed open returns None
    # rather than raising; fail here with a message naming the file.
    if data_source is None:
        raise RuntimeError("Could not open vector file: %s" % vector_data_path)
    layer = data_source.GetLayer(0)
    if layer is None:
        raise RuntimeError("Vector file has no layers: %s" % vector_data_path)

    driver = gdal.GetDriverByName('MEM')
    target_ds = driver.Create('', cols, rows, 1, gdal.GDT_UInt16)
    target_ds.SetGeoTransform(geo_transform)
    target_ds.SetProjection(projection)
    # data_source must stay referenced until rasterization is finished,
    # because the layer belongs to it.
    gdal.RasterizeLayer(target_ds, [1], layer, burn_values=[target_value])
    return target_ds
def vector_to_raster(file_paths, rows, cols, geo_transform, projection):
    """Rasterize several vector files into one label image.

    The i-th path (0-based) is assigned the class label ``i + 1``; pixels
    covered by none of the files stay 0.

    Args:
        file_paths: Vector file paths, one per class, in label order.
        rows: Height of the label image in pixels.
        cols: Width of the label image in pixels.
        geo_transform: GDAL geotransform of the target grid.
        projection: Projection of the target grid.

    Returns:
        A ``(rows, cols)`` float array of class labels (0 = unlabeled).
    """
    labeled_pixels = np.zeros((rows, cols))
    for label, path in enumerate(file_paths, start=1):
        ds = create_mask_from_vector(path, cols, rows, geo_transform, projection, target_value=label)
        burned = ds.GetRasterBand(1).ReadAsArray()
        # Assign instead of accumulating: with `+=`, a pixel covered by
        # polygons of two classes would receive the sum of their labels,
        # i.e. a different (possibly non-existent) class. With assignment
        # the later file wins on overlap.
        labeled_pixels[burned != 0] = label
        ds = None  # release the in-memory dataset
    return labeled_pixels
def write_geotiff(fname, data, geo_transform, projection):
    """Write a 2-D array to a single-band Byte GeoTIFF.

    Args:
        fname: Path of the GeoTIFF file to create.
        data: 2-D array; its shape is interpreted as ``(rows, cols)``.
        geo_transform: GDAL geotransform stored in the file.
        projection: Projection stored in the file.

    Raises:
        RuntimeError: If the output file cannot be created.
    """
    driver = gdal.GetDriverByName('GTiff')
    rows, cols = data.shape
    dataset = driver.Create(fname, cols, rows, 1, gdal.GDT_Byte)
    # Unless gdal.UseExceptions() is active, Create returns None on failure
    # (e.g. unwritable path); report that instead of an AttributeError.
    if dataset is None:
        raise RuntimeError("Could not create GeoTIFF: %s" % fname)
    dataset.SetGeoTransform(geo_transform)
    dataset.SetProjection(projection)
    band = dataset.GetRasterBand(1)
    band.WriteArray(data)
    band.FlushCache()
    # Drop the band before the dataset so no reference outlives the file.
    band = None
    dataset = None  # close the file
# --- Configuration -----------------------------------------------------------
raster_data_path = "E:/python/遥感影像/data/image/2298119ene2016recorteTT.tif"
output_fname = "classification.tif"
train_data_path = "E:/python/遥感影像/data/train/"
validation_data_path = "E:/python/遥感影像/data/test/"

# --- Load the image as a (rows, cols, n_bands) array -------------------------
raster_dataset = gdal.Open(raster_data_path, gdal.GA_ReadOnly)
geo_transform = raster_dataset.GetGeoTransform()
proj = raster_dataset.GetProjectionRef()
bands_data = np.dstack([
    raster_dataset.GetRasterBand(band_index).ReadAsArray()
    for band_index in range(1, raster_dataset.RasterCount + 1)
])
rows, cols, n_bands = bands_data.shape

# --- Build training labels from one shapefile per class ----------------------
# os.listdir returns entries in arbitrary order; sort so that the mapping
# class name -> numeric label is the same on every run and platform.
files = sorted(name for name in os.listdir(train_data_path) if name.endswith('.shp'))
# splitext keeps dots inside the class name ("a.b.shp" -> "a.b"), so the
# matching validation shapefile name is reconstructed correctly below.
classes = [os.path.splitext(name)[0] for name in files]
shapefiles = [os.path.join(train_data_path, name) for name in files]
labeled_pixels = vector_to_raster(shapefiles, rows, cols, geo_transform, proj)
is_train = np.nonzero(labeled_pixels)
training_labels = labeled_pixels[is_train]
training_samples = bands_data[is_train]

# --- Train and classify every pixel -------------------------------------------
classifier = RandomForestClassifier(n_jobs=-1)
classifier.fit(training_samples, training_labels)
n_samples = rows * cols
flat_pixels = bands_data.reshape((n_samples, n_bands))
result = classifier.predict(flat_pixels)
classification = result.reshape((rows, cols))

# --- Plot the image next to the classification ---------------------------------
fig = plt.figure()
# Band indices 3, 2, 1 are shown as red, green, blue; this requires the
# image to have at least four bands.
red = bands_data[:, :, 3]
green = bands_data[:, :, 2]
blue = bands_data[:, :, 1]
rgb = np.dstack([red, green, blue])
fig.add_subplot(1, 2, 1)
plt.imshow(rgb/255)
fig.add_subplot(1, 2, 2)
plt.imshow(classification)

write_geotiff(output_fname, classification, geo_transform, proj)

# --- Validate against the held-out shapefiles ----------------------------------
# Validation files are looked up by class name, in the same order as the
# training classes, so they receive the same numeric labels.
shapefiles = [os.path.join(validation_data_path, "%s.shp" % c)
              for c in classes]
verification_pixels = vector_to_raster(shapefiles, rows, cols, geo_transform, proj)
for_verification = np.nonzero(verification_pixels)
verification_labels = verification_pixels[for_verification]
predicted_labels = classification[for_verification]
print("Confussion matrix:\n%s" % metrics.confusion_matrix(verification_labels, predicted_labels))
target_names = ['Class %s' % s for s in classes]
print("Classification report:\n%s" % metrics.classification_report(
    verification_labels, predicted_labels, target_names=target_names))
print("Classification accuracy: %f" % metrics.accuracy_score(verification_labels, predicted_labels))

# Display the figure last so that, when run as a plain script, the (blocking)
# window does not delay the GeoTIFF output or the accuracy report.
plt.show()
效果图:
分类报告:
Confussion matrix:
[[ 65 0 0 0]
[ 1 87 0 0]
[ 0 0 180 0]
[ 0 0 0 160]]
Classification report:
              precision    recall  f1-score   support

     Class B       0.98      1.00      0.99        65
     Class C       1.00      0.99      0.99        88
     Class D       1.00      1.00      1.00       180
     Class E       1.00      1.00      1.00       160

    accuracy                           1.00       493
   macro avg       1.00      1.00      1.00       493
weighted avg       1.00      1.00      1.00       493
Classification accuracy: 0.997972
参考:https://zhuanlan.zhihu.com/p/73793438