监督分类:用随机森林做遥感影像像素级分类

30 篇文章 30 订阅
3 篇文章 1 订阅

像素级分类比较耗时,本人能力有限,在此先提供一个初版。
这里导入了传统决策树、随机森林和极端随机树(Extra Trees)三种分类器,大家都可以试一下。
下面是对 4 波段遥感影像运行的代码。由于是逐像素分类,速度很慢,建议先选一张小图测试。代码还有很多需要改进的地方,希望大家在评论区多提建议、多多指导。
这里是对***植被***进行提取

# -*- coding: utf-8 -*-
import os, sys, time
import gdal
from osgeo import ogr
from osgeo import gdal
from osgeo import gdal_array as ga
from gdalconst import *
from skimage import morphology,filters
import numpy as np
from numba import jit, vectorize, int64
import warnings 
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier,GradientBoostingClassifier,AdaBoostClassifier
from sklearn.ensemble import ExtraTreesClassifier

# Read a remote-sensing raster.
def read_img(filename):
    """Read the full extent of the raster *filename* with GDAL.

    Returns:
        (im_proj, im_geotrans, im_width, im_height, im_data), where im_data is
        the full-extent ``Dataset.ReadAsArray`` result (callers in this script
        index it as (bands, rows, cols); single-band rasters come back 2-D).

    Exits the process with status 1 if the file cannot be opened, matching
    the error handling used by getPixels.
    """
    dataset = gdal.Open(filename)
    # gdal.Open returns None on failure instead of raising (unless
    # gdal.UseExceptions() is active); without this check the next line
    # failed with an unhelpful AttributeError.
    if dataset is None:
        print('Could not open ' + filename)
        sys.exit(1)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Dropping the last reference closes the GDAL dataset.
    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data

# Write an image to disk as a GeoTIFF.
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write *im_data* to a new GeoTIFF *filename*.

    Args:
        filename: output path.
        im_proj: projection WKT passed to SetProjection.
        im_geotrans: 6-element geotransform passed to SetGeoTransform.
        im_data: 2-D (rows, cols) array for one band, or 3-D
            (bands, rows, cols) array for several bands.

    The GDAL pixel type is chosen from the numpy dtype; any dtype not listed
    below is written as GDT_Float32. Exits with status 1 if the output file
    cannot be created.
    """
    # Exact dtype-name lookup. The previous substring test ('int16' in name)
    # also matched signed int16 and wrote it as GDT_UInt16, wrapping negative
    # values, and sent int32/uint32 to Float32.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # kept as before; negative int8 values wrap
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        print('Could not create ' + filename)
        sys.exit(1)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_data.ndim == 2:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        # Also covers (1, rows, cols): WriteArray needs a 2-D slice, so the
        # single-band 3-D case must not pass the whole array.
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset

# Look up the raster pixel values under each vector point and collect them in a list.
def getPixels(shp, img):
    """Sample the raster *img* at every point feature of the shapefile *shp*.

    Args:
        shp: path to an ESRI Shapefile; each geometry is read with
            GetX()/GetY(), so point geometries are expected.
        img: path to a GDAL-readable raster.

    Returns:
        A list with one 1-D numpy array per usable point, holding that
        pixel's value in every band of *img*. Features without geometry and
        points that fall outside the raster are skipped with a printed
        message, so the list may be shorter than the feature count.

    Exits the process with status 1 if either file cannot be opened.

    No reprojection is done: the shapefile coordinates are assumed to be in
    the raster's coordinate system. The geotransform's rotation terms are
    ignored (north-up raster assumed).
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    shp_ds = driver.Open(shp, 0)
    if shp_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = shp_ds.GetLayer()

    coords = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        if geometry is None:
            print('Skipping feature %d without geometry' % feature.GetFID())
        else:
            coords.append((geometry.GetX(), geometry.GetY()))
        feature = layer.GetNextFeature()

    gdal.AllRegister()

    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    rows = ds.RasterYSize
    cols = ds.RasterXSize

    transform = ds.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = transform[5]

    values = []
    for x, y in coords:
        # floor() rather than int(): int() truncates toward zero, so a point
        # slightly left of / above the origin mapped to column/row 0 instead
        # of -1 and was silently sampled from the wrong pixel.
        xOffset = int(np.floor((x - xOrigin) / pixelWidth))
        yOffset = int(np.floor((y - yOrigin) / pixelHeight))

        # Out-of-raster reads return None from ReadAsArray, which used to
        # crash on .flatten(); skip such points explicitly instead.
        if not (0 <= xOffset < cols and 0 <= yOffset < rows):
            print('Skipping point (%s, %s): outside %s' % (x, y, img))
            continue

        dt = ds.ReadAsArray(xOffset, yOffset, 1, 1)
        values.append(dt.flatten())
    return values

if __name__ == "__main__":
    img_path = "E:/20200210/forest/gf2/dys_gf2.tif"  # large source image; sample points were picked on it
    img_path2 = "E:/20200210/forest/gf2/dys_gf2_test.tif"  # small test image that gets classified
    shp_false = "E:/20200210/forest/point/1.shp"  # negative samples; make them diverse
    shp_true = "E:/20200210/forest/point/2.shp"  # positive samples; the more the better, never mixed with negatives
    temp_path = "E:/20200210/forest/temp/"  # folder for intermediate/output files

    # One feature vector (the pixel's value in every band) per sample point.
    point_false = getPixels(shp_false, img_path)
    point_true = getPixels(shp_true, img_path)

    data = np.array(point_false + point_true)
    # Label 0 for negative samples, 1 for positive samples.
    label = np.concatenate((np.zeros(len(point_false)), np.ones(len(point_true))))

    clf = RandomForestClassifier(n_estimators=100, max_depth=2, random_state=0)
    clf.fit(data, label)

    im_proj2, im_geotrans2, im_width2, im_height2, im_data2 = read_img(img_path2)
    if im_data2.ndim == 2:
        # A single-band raster is read as (rows, cols); add a band axis so
        # the (bands, rows, cols) indexing below works for every input.
        im_data2 = im_data2[np.newaxis, :, :]

    # Classify pixel by pixel (slow: one predict() call per pixel).
    # range() replaces the Python-2-only xrange(), and the loops now cover
    # the full width/height; the old "-1" left the last column and row as 0.
    seg = np.zeros((im_height2, im_width2), dtype=np.uint8)
    for i in range(im_width2):
        for j in range(im_height2):
            # Use all bands so the features match what getPixels returned for
            # training (the old hard-coded [0:4] slice broke for images with
            # a different band count).
            point = im_data2[:, j, i].reshape(1, -1)
            seg[j, i] = clf.predict(point)[0]

    if not os.path.isdir(temp_path):
        os.makedirs(temp_path)
    seg_path = os.path.join(temp_path, 'random.tif')
    write_img(seg_path, im_proj2, im_geotrans2, seg)

测试图像
测试图像
输出结果
结果

说明:这里我选取的样本点没有完全覆盖所有地物类型,所以可能会出现一些误提,样本点一定要认真选取。目前的结果还比较粗糙,再加些后处理(例如形态学滤波)效果可能会更好。

希望各位看官给我点加速的建议,这个跑得太慢了。逐像素遍历怎样才能更快?或者能不能不逐像素遍历来实现?

更新:
前面的方式太机械了。其实可以把整幅图像重排成"像素数 × 波段数"的二维数组,一次性输入分类器进行预测,再把结果还原成图像的行列形状。下面是优化后的代码,这样快很多!

# -*- coding: utf-8 -*-
import os, sys, time
import gdal
from osgeo import ogr
from osgeo import gdal
from osgeo import gdal_array as ga
from gdalconst import *
from skimage import morphology,filters
import numpy as np
from numba import jit, vectorize, int64
import warnings 
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier

def read_img(filename):
    """Read the full extent of the raster *filename* with GDAL.

    Returns:
        (im_proj, im_geotrans, im_width, im_height, im_data), where im_data is
        the full-extent ``Dataset.ReadAsArray`` result (callers in this script
        index it as (bands, rows, cols); single-band rasters come back 2-D).

    Exits the process with status 1 if the file cannot be opened, matching
    the error handling used by getPixels.
    """
    dataset = gdal.Open(filename)
    # gdal.Open returns None on failure instead of raising (unless
    # gdal.UseExceptions() is active); without this check the next line
    # failed with an unhelpful AttributeError.
    if dataset is None:
        print('Could not open ' + filename)
        sys.exit(1)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Dropping the last reference closes the GDAL dataset.
    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write *im_data* to a new GeoTIFF *filename*.

    Args:
        filename: output path.
        im_proj: projection WKT passed to SetProjection.
        im_geotrans: 6-element geotransform passed to SetGeoTransform.
        im_data: 2-D (rows, cols) array for one band, or 3-D
            (bands, rows, cols) array for several bands.

    The GDAL pixel type is chosen from the numpy dtype; any dtype not listed
    below is written as GDT_Float32. Exits with status 1 if the output file
    cannot be created.
    """
    # Exact dtype-name lookup. The previous substring test ('int16' in name)
    # also matched signed int16 and wrote it as GDT_UInt16, wrapping negative
    # values, and sent int32/uint32 to Float32.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # kept as before; negative int8 values wrap
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        print('Could not create ' + filename)
        sys.exit(1)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_data.ndim == 2:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        # Also covers (1, rows, cols): WriteArray needs a 2-D slice, so the
        # single-band 3-D case must not pass the whole array.
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset


def getPixels(shp, img):
    """Sample the raster *img* at every point feature of the shapefile *shp*.

    Args:
        shp: path to an ESRI Shapefile; each geometry is read with
            GetX()/GetY(), so point geometries are expected.
        img: path to a GDAL-readable raster.

    Returns:
        A list with one 1-D numpy array per usable point, holding that
        pixel's value in every band of *img*. Features without geometry and
        points that fall outside the raster are skipped with a printed
        message, so the list may be shorter than the feature count.

    Exits the process with status 1 if either file cannot be opened.

    No reprojection is done: the shapefile coordinates are assumed to be in
    the raster's coordinate system. The geotransform's rotation terms are
    ignored (north-up raster assumed).
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    shp_ds = driver.Open(shp, 0)
    if shp_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = shp_ds.GetLayer()

    coords = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        if geometry is None:
            print('Skipping feature %d without geometry' % feature.GetFID())
        else:
            coords.append((geometry.GetX(), geometry.GetY()))
        feature = layer.GetNextFeature()

    gdal.AllRegister()

    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    rows = ds.RasterYSize
    cols = ds.RasterXSize

    transform = ds.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = transform[5]

    values = []
    for x, y in coords:
        # floor() rather than int(): int() truncates toward zero, so a point
        # slightly left of / above the origin mapped to column/row 0 instead
        # of -1 and was silently sampled from the wrong pixel.
        xOffset = int(np.floor((x - xOrigin) / pixelWidth))
        yOffset = int(np.floor((y - yOrigin) / pixelHeight))

        # Out-of-raster reads return None from ReadAsArray, which used to
        # crash on .flatten(); skip such points explicitly instead.
        if not (0 <= xOffset < cols and 0 <= yOffset < rows):
            print('Skipping point (%s, %s): outside %s' % (x, y, img))
            continue

        dt = ds.ReadAsArray(xOffset, yOffset, 1, 1)
        values.append(dt.flatten())
    return values


if __name__ == "__main__":
    img_path = "E:/20200210/forest/gf2/dys_gf2.tif"  # large source image; sample points were picked on it
    img_path2 = "E:/20200210/forest/gf2/dys_gf2_test.tif"  # image that gets classified
    shp_false = "E:/20200210/forest/point/1.shp"  # negative samples
    shp_true = "E:/20200210/forest/point/2.shp"  # positive samples
    temp_path = "E:/20200210/forest/temp/"  # folder for output files

    # One feature vector (the pixel's value in every band) per sample point.
    point_false = getPixels(shp_false, img_path)
    point_true = getPixels(shp_true, img_path)

    data = np.array(point_false + point_true)
    # Label 0 for negative samples, 1 for positive samples.
    label = np.concatenate((np.zeros(len(point_false)), np.ones(len(point_true))))

    clf = RandomForestClassifier(n_estimators=100, max_depth=2, random_state=0)
    clf.fit(data, label)

    im_proj2, im_geotrans2, im_width2, im_height2, im_data2 = read_img(img_path2)
    if im_data2.ndim == 2:
        # A single-band raster is read as (rows, cols); add a band axis so
        # the reshape below works for every input (the old 3-axis transpose
        # crashed on 2-D arrays).
        im_data2 = im_data2[np.newaxis, :, :]
    n_bands, n_rows, n_cols = im_data2.shape

    # Classify the whole image with one predict() call: (bands, rows, cols)
    # -> (rows*cols, bands), so each row is one pixel's feature vector in the
    # same band order getPixels used for training.
    pixels = im_data2.reshape(n_bands, n_rows * n_cols).T
    pred = clf.predict(pixels)
    # Fold predictions back to (rows, cols). Labels are 0/1, so store them as
    # uint8 (written as GDT_Byte) instead of a float raster.
    seg = pred.reshape(n_rows, n_cols).astype(np.uint8)

    if not os.path.isdir(temp_path):
        os.makedirs(temp_path)
    seg_path = os.path.join(temp_path, 'random.tif')
    write_img(seg_path, im_proj2, im_geotrans2, seg)

新结果

  • 4
    点赞
  • 77
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 21
    评论
评论 21
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

如雾如电

随缘

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值