小组python学习任务3

 

任务内容

光学影像分类

1、读取光学影像,格式不限(envi格式,*.tif等)
2、采用监督分类的方法进行分类
3、计算分类的精度和混淆矩阵
4、保存分类的结果文件(envi格式、*.tif等)

提示:
1、数据来源《遥感软件应用与二次开发》书籍
2、大家可以采用envi对原始的数据进行格式转换,但不是采用envi软件分类
3、python监督分类参考库文件sklearn 

任务数据

 

实现过程:

步骤1:数据格式的转换,(本文中是指把roi转换为tif)

使用envi和arcgis,把roi格式转换为tif(或者其他格式)

 

 

步骤2:打开VScode,书写代码

 

完整代码:

#!/usr/bin/env python

#coding=utf-8

from osgeo import gdal

import numpy as np

import numpy

import pandas as pd

from sklearn.ensemble import RandomForestClassifier

from sklearn.decomposition import PCA

from sklearn import svm

from sklearn.cluster import KMeans

from sklearn.metrics import accuracy_score

import datetime

import os




 

class suanfa:

    #--------------------------------------------读取影像的函数--------------------------------------------

    def read_image(self,filename):

        #获取影像的所有信息

        dataset=gdal.Open(filename)

        #获取影像的宽和高

        im_width=dataset.RasterXSize

        im_height=dataset.RasterYSize

        #得到影像的地理,投影,数据的信息

        im_geotrans=dataset.GetGeoTransform()

        im_proj=dataset.GetProjection()

        im_data=dataset.ReadAsArray(0,0,im_width,im_height)

        #清除变量并返回投影,坐标,和数据信息

        del dataset

        return im_proj,im_geotrans,im_data

 

    #--------------------------------------------保存影像的函数--------------------------------------------

    def save_image(self,filename,im_proj,im_geotrans,im_data,out_format):

        if 'int8' in im_data.dtype.name:

            datatype=gdal.GDT_Byte

        elif 'int16' in im_data.dtype.name:

            datatype=gdal.GDT_UInt16

        else:

            datatype=gdal.GDT_Float32

 

        if len(im_data.shape)==3:

            im_bands,im_height,im_width=im_data.shape

        else:

            im_bands,(im_height,im_width)=1,im_data.shape

 

        #输出数据的格式bmp(BMP)、jpg(JPEG)、tif(GTiff)、img(HFA)、bt(BT)、ecw(ECW)、fits(FITS)、gif(GIF)、hdf(HDF4)、hdr(EHdr)、

        driver=gdal.GetDriverByName(out_format)

        dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

 

        dataset.SetGeoTransform(im_geotrans)              

        dataset.SetProjection(im_proj)    

 

        if im_bands == 1:

            dataset.GetRasterBand(1).WriteArray(im_data)  

        else:

            for i in range(im_bands):

                dataset.GetRasterBand(i+1).WriteArray(im_data[i])

 

        del dataset



 

    

if __name__=="__main__":

    starttime = datetime.datetime.now()#记录开始时间

    

    os.chdir(r'E:\0 云南师范大学\桂林理工大学\python学习\2020.1228\0-监督与非监督分类\\2img')

    run=suanfa()

    # 获取光谱数据

    proj,geotrans,data=run.read_image('can_tmr.img')

    #转换数据类型

    data = data.astype(np.float64)

    #获取data的信息

    [m,n1,n2]=data.shape#m是数据的维度,n1是数据的行,n2是数据的列

    data=data.reshape((m,n1*n2)).T

 

    #获取分类标签数据,label为样本标签

    tr_proj,tr_geotrans,tr_label=run.read_image('train_3.tif')

    tr_label=tr_label.reshape((n1*n2,1))

    tr_index=np.where(tr_label<10)

    #获取分类标签数据,label为样本标签

    te_proj,te_geotrans,te_label=run.read_image('test_3.tif')

    te_label=te_label.reshape((n1*n2,1))

    te_index=np.where(te_label<10)


 

    # 获取训练数据集

    tr_data=data[tr_index[0]]

    # 训练标签集

    tr_label=tr_label[tr_index[0]]

    tr_label=np.reshape(tr_label,(-1,))

 

    # 获取训练数据集

    te_data=data[te_index[0]]

    # 测试标签集

    te_label=te_label[te_index[0]]

 

        

    #clf=svm.SVC(kernel='sigmoid')

    '''

SVC参数解释

(1)C: 目标函数的惩罚系数C,用来平衡分类间隔margin和错分样本的,default C = 1.0;

(2)kernel:参数选择有RBF, Linear, Poly, Sigmoid, 默认的是"RBF";

(3)degree:if you choose 'Poly' in param 2, this is effective, degree决定了多项式的最高次幂;

(4)gamma:核函数的系数('Poly', 'RBF' and 'Sigmoid'), 默认是gamma = 1 / n_features;

(5)coef0:核函数中的独立项,'RBF' and 'Poly'有效;

(6)probablity: 可能性估计是否使用(true or false);

(7)shrinking:是否进行启发式;

(8)tol(default = 1e - 3): svm结束标准的精度;

(9)cache_size: 制定训练所需要的内存(以MB为单位);

(10)class_weight: 每个类所占据的权重,不同的类设置不同的惩罚参数C, 缺省的话自适应;

(11)verbose: 跟多线程有关,不大明白啥意思具体;

(12)max_iter: 最大迭代次数,default = 1, if max_iter = -1, no limited;

(13)decision_function_shape : ‘ovo’ 一对一, ‘ovr’ 多对多  or None 无, default=None

(14)random_state :用于概率估计的数据重排时的伪随机数生成器的种子。

 ps:7,8,9一般不考虑。

'''

    svm_clf=svm.SVC(C=1, kernel='poly', gamma=0.14, decision_function_shape='ovr')

    svm_clf.fit(tr_data,tr_label)

 

    y_hat = svm_clf.predict(te_data)

    acc = accuracy_score(te_label, y_hat)

    np.set_printoptions(suppress=True)

    print (u'预测正确的样本个数:%d,svm_正确率:%.2f%%' % (round(acc*2346), 100*acc))

 

    svm_res=svm_clf.predict(data).reshape(n1,n2)

 

    #********************************************************************************************************************

    # 随机森林

    #分类器参数设置

    rf_clf = RandomForestClassifier(n_estimators=500)  

    #训练分类器

    rf_clf.fit(tr_data,tr_label)

    rf_res=rf_clf.predict(data).reshape(n1,n2)

 

    y_hat = rf_clf.predict(te_data)

    acc = accuracy_score(te_label, y_hat)

    np.set_printoptions(suppress=True)

    print (u'预测正确的样本个数:%d,rf_正确率:%.2f%%' % (round(acc*2346), 100*acc))

 

    rf_res=rf_clf.predict(data).reshape(n1,n2)

 

    #************************************************输出分类结果**************************************************

    filename0='result'

    out_format="GTiff"

    run.save_image(filename0+'svm.tif',proj,geotrans,svm_res,out_format)

    run.save_image(filename0+'rf.tif',proj,geotrans,rf_res,out_format)



 

    endtime = datetime.datetime.now()#记录结束时间

    print((endtime - starttime).seconds)


步骤3:运行结果

 

 

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值