监督分类:用随机森林做遥感影像像素级分类(更新:多分类实现)

前面已经发了一个版本了,但是那个看着是二分类,估计很多人也不太好下手改,因为有人问,我就好事做到底吧,来一个多分类吧,大家还可以参考上一篇SVM的更新自己实现一下大影像的分类,我这就不搞重复的了,先上结果。
注:1.可不要看结果不好就不看了喔,这个结果是我随便选的点分的,毕竟做实验,不想浪费太多时间。
2.除了随机森林还有别的方法也在前面import 的时候导入了,你们也可以试下别的呢。
3.如果有用别吝啬你们的赞哈,你们的鼓励就是我的鸡血。
图像
结果

# -*- coding: utf-8 -*-
import os, sys, time
import gdal
from osgeo import ogr
from osgeo import gdal
from osgeo import gdal_array as ga
from gdalconst import *
from skimage import morphology,filters
import numpy as np
from numba import jit, vectorize, int64
import warnings 
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, AdaBoostClassifier
from sklearn.ensemble import ExtraTreesClassifier

def read_img(filename):
    dataset=gdal.Open(filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0,0,im_width,im_height)

    del dataset 
    return im_proj,im_geotrans,im_width, im_height,im_data

def write_img(filename, im_proj, im_geotrans, im_data):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    del dataset


def getPixels(shp, img):
    driver = ogr.GetDriverByName('ESRI Shapefile')
    ds = driver.Open(shp, 0)
    if ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = ds.GetLayer()

    xValues = []
    yValues = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        x = geometry.GetX()
        y = geometry.GetY()
        xValues.append(x)
        yValues.append(y)
        feature = layer.GetNextFeature()

    gdal.AllRegister()

    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    rows = ds.RasterYSize
    cols = ds.RasterXSize
    bands = ds.RasterCount

    transform = ds.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = transform[5]

    values = []
    for i in range(len(xValues)):
        x = xValues[i]
        y = yValues[i]

        xOffset = int((x - xOrigin) / pixelWidth)
        yOffset = int((y - yOrigin) / pixelHeight)

        s = str(int(x)) + ' ' + str(int(y)) + ' ' + str(xOffset) + ' ' + str(yOffset) + ' '

        dt = ds.ReadAsArray(xOffset, yOffset, 1, 1)
        values.append(dt.flatten())
    return values

def array_change(inlist, outlist):
    for i in range(len(inlist[0])):
        outlist.append([j[i] for j in inlist])
    return outlist

def array_change2(inlist, outlist):
    for ele in inlist:
        for ele2 in ele:
            outlist.append(ele2)
    return outlist

def random_test(img_path, point_path,save_path):
    class_list = []
    label_list = []
    count = 0
    for shp in os.listdir(point_path):
        if shp[-4:] == '.shp':
            shp_full_path = os.path.join(point_path, shp)
            class_type  = getPixels(shp_full_path, img_path)
            class_list += class_type
            label_list += [count]*len(class_type)
            count += 1
    
    arr = np.array(class_list)
    label = np.array(label_list)
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(img_path)
    im_data = im_data.transpose((2,1,0))
    clf = RandomForestClassifier(n_estimators=100, max_depth=2,random_state=0)
    clf.fit(arr, label)
    img_arr_temp = im_data
    img_reshape = img_arr_temp.reshape([img_arr_temp.shape[0]*img_arr_temp.shape[1],img_arr_temp.shape[2]])
    seg = clf.predict(img_reshape)
    re = seg.reshape((img_arr_temp.shape[0],img_arr_temp.shape[1]))
    re = re.transpose((1,0))
    write_img(save_path, im_proj, im_geotrans, re)

if __name__ == "__main__":
    img_path = "D:/data/data/test.tif"
    point_path = "D:data/point2/"
    save_path = "/data/data/test_radom.tif"
    random_test(img_path,point_path,save_path)
  • 67
    点赞
  • 147
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 115
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 115
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

如雾如电

随缘

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值