监督分类:用SVM做遥感影像特征级分类

30 篇文章 30 订阅

前面已经做了很多像素级的分类,这里继续深入,用的特征是LBP(Local Binary Pattern,局部二值模式)。思路是:先手动选取样本点,取样本点的邻域提取LBP特征进行训练,得到模型;然后对每个像素取其邻域提取LBP特征,用训练好的模型对该特征进行判别,最终确定像素的类型。
这里放了初版和改进版两个版本。初版里有一些想记录的东西,所以也一并放在这里;大家可以直接往下翻看终版。

# -*- coding: utf-8 -*-
import os, sys, time
import warnings
import gdal
import numpy as np
from numpy import average, dot, linalg
import cv2
import skimage
from osgeo import ogr
from osgeo import gdal
from osgeo import gdal_array as ga
from gdalconst import *
from skimage.feature import local_binary_pattern
from skimage.util.shape import view_as_windows #这是影像切片的模块,存在边界无法处理的缺陷
from sklearn.svm import SVC
from numba import jit

def read_img(filename):
    """Read a whole raster into memory.

    Returns (projection, geotransform, width, height, data), where ``data`` is
    what GDAL's ``ReadAsArray`` returns for the full extent: (bands, rows, cols)
    for a multi-band raster, (rows, cols) for a single band.

    Raises IOError if GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # gdal.Open returns None (unless gdal.UseExceptions() is active); fail
        # here with the file name instead of an AttributeError a line later.
        raise IOError('Could not open raster: %s' % filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    del dataset  # release the GDAL handle
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a (rows, cols) or (bands, rows, cols) array to a GeoTIFF.

    The GDAL pixel type follows the array dtype:
    uint8/int8 -> Byte, uint16 -> UInt16, int16 -> Int16,
    uint32 -> UInt32, int32 -> Int32, anything else -> Float32.

    Raises IOError if the GTiff driver cannot create ``filename``
    (e.g. the target directory does not exist).
    """
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # GDAL < 3.7 has no signed 8-bit type
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,  # signed data must not be stored as UInt16
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,  # Float32 cannot hold every int32 exactly
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
        band_stack = im_data
    else:
        im_bands = 1
        im_height, im_width = im_data.shape
        band_stack = im_data[np.newaxis, ...]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError('Could not create raster: %s' % filename)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Band.WriteArray expects 2-D arrays, so always write band by band; this
    # also covers a (1, rows, cols) input.
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(band_stack[i])

    del dataset  # closing the dataset flushes it to disk

def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch ``bands`` into [img_min, img_max] and return uint8.

    The cut points are the ``lower_percent`` / ``higher_percent`` percentiles
    of the whole array (all bands pooled, not per band). Values beyond them are
    clipped to the output range. If both percentiles are equal (constant
    input) the result is all zeros with the input's shape.

    The result is always cast to uint8, so [img_min, img_max] should lie
    inside [0, 255].
    """
    # No numba @jit: this is already vectorised NumPy, and numba's nopython
    # mode (the @jit default since numba 0.59) would reject it rather than
    # speed it up.
    low = np.percentile(bands, lower_percent)
    high = np.percentile(bands, higher_percent)
    if high - low == 0:
        return np.zeros(np.shape(bands), dtype=np.uint8)

    scaled = img_min + (bands - low) * (img_max - img_min) / (high - low)
    scaled = np.clip(scaled, img_min, img_max)
    # Go through float32 (the original output buffer type) so the final
    # truncation to uint8 gives exactly the same values.
    return scaled.astype(np.float32).astype(np.uint8)

def getPixels(shp, img, size):
    """Cut a size x size patch out of raster ``img`` around every point in ``shp``.

    Parameters
    ----------
    shp : path to a point ESRI Shapefile; one sample per feature.
    img : path to the raster to sample.
    size : patch edge length in pixels.

    Returns
    -------
    list of [data, geotransform] pairs, one per point whose patch lies fully
    inside the raster. ``data`` is the ``ReadAsArray`` result for the window.
    ``geotransform`` is the raster's geotransform with its origin moved to the
    patch's upper-left pixel corner. The point's own pixel sits at offset
    ``size // 2`` inside the patch. Points whose patch would leave the raster
    are reported and skipped.

    Calls sys.exit(1) if either file cannot be opened.

    NOTE(review): assumes a north-up raster (rotation terms of the
    geotransform are ignored), as the original offset maths did.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    vector_ds = driver.Open(shp, 0)
    if vector_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = vector_ds.GetLayer()
    points = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        points.append((geometry.GetX(), geometry.GetY()))
        feature = layer.GetNextFeature()

    gdal.AllRegister()

    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    # Use this raster's own geotransform (never a global from __main__).
    transform = ds.GetGeoTransform()
    x_origin, pixel_width = transform[0], transform[1]
    y_origin, pixel_height = transform[3], transform[5]
    size = int(size)
    half = size // 2

    values = []
    for x, y in points:
        # Upper-left pixel of the patch = the point's pixel minus half a patch.
        # np.floor (not int(), which truncates toward zero) keeps offsets of
        # points left of / above the raster negative so they are rejected.
        x_off = int(np.floor((x - x_origin) / pixel_width)) - half
        y_off = int(np.floor((y - y_origin) / pixel_height)) - half
        if (x_off < 0 or y_off < 0
                or x_off + size > ds.RasterXSize
                or y_off + size > ds.RasterYSize):
            # ReadAsArray would return None for this window and break callers.
            print('Skipping point (%s, %s): its patch is not fully inside the image' % (x, y))
            continue
        data = ds.ReadAsArray(x_off, y_off, size, size)

        # Georeference the patch on the raster's pixel grid.
        patch_transform = list(transform)
        patch_transform[0] = x_origin + x_off * pixel_width
        patch_transform[3] = y_origin + y_off * pixel_height
        values.append([data, tuple(patch_transform)])
    return values

def lbp(img, n_points, radius, level=256, method='default'):
    """Return the normalised histogram of LBP codes of a single-band image.

    Parameters
    ----------
    img : 2-D array (one band).
    n_points, radius : P and R passed to
        ``skimage.feature.local_binary_pattern``.
    level : number of histogram bins, one per integer code in [0, level).
        Defaults to 256. Pass ``None`` to size the histogram so every
        possible code is counted:
          * 'default' / 'ror': 2 ** n_points
          * 'uniform':          n_points + 2
          * 'nri_uniform':      n_points * (n_points - 1) + 3
        ('var' yields continuous values and needs an explicit ``level``.)
    method : LBP variant, forwarded to ``local_binary_pattern``.

    Note: the bin count depends on the LBP code range, not on the image bit
    depth. With method='default' the codes reach 2 ** n_points - 1, so the
    default 256 bins only cover all codes when n_points <= 8. Codes outside
    [0, level) are not counted correctly, and a warning is issued when that
    happens.
    """
    codes = local_binary_pattern(img, n_points, radius, method)

    if level is None:
        if method in ('default', 'ror'):
            n_bins = 2 ** n_points
        elif method == 'uniform':
            n_bins = n_points + 2
        elif method == 'nri_uniform':
            n_bins = n_points * (n_points - 1) + 3
        else:
            raise ValueError("level=None is not supported for method %r" % method)
    else:
        n_bins = int(level)

    if codes.size and codes.max() >= n_bins:
        warnings.warn('LBP codes up to %d do not fit in %d histogram bins; '
                      'pass level=None or a larger level' % (codes.max(), n_bins))

    hist, _ = np.histogram(codes, density=True, bins=n_bins, range=(0, n_bins))
    return hist
    
def get_data(data_point, img_path, patch_path, im_proj, band, size, radius, n_points, data_type='false'):
    """Build LBP feature vectors and labels for one class of sample points.

    For every point in the shapefile ``data_point``, a size x size patch is cut
    from ``img_path`` (via getPixels). The patch is saved for inspection as
    ``<patch_path>/<data_type>_<k>.tif``, where k is the patch's position in
    the getPixels result. Its feature vector is the concatenation of the LBP
    histograms of its first ``band`` bands.

    Returns (features, labels): two parallel lists. The label is 0 when
    ``data_type == 'false'`` and 1 otherwise.
    """
    data_patches = getPixels(data_point, img_path, size)
    label = 0 if data_type == 'false' else 1

    # GDAL cannot create a file in a missing directory; make sure it exists.
    if patch_path and not os.path.isdir(patch_path):
        os.makedirs(patch_path)

    data_lbp_hists = []
    data_label = []
    for count, (patch_data, patch_transform) in enumerate(data_patches):
        if patch_data is None:
            # ReadAsArray yields None when the window falls outside the image.
            continue
        patch_p = os.path.join(patch_path, data_type + '_' + str(count) + '.tif')
        write_img(patch_p, im_proj, patch_transform, patch_data)

        hists = [lbp(patch_data[i, ...], n_points, radius) for i in range(band)]
        data_lbp_hists.append(np.concatenate(hists, axis=0))
        data_label.append(label)
    return data_lbp_hists, data_label

if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    img_path = "E:/20200210/forest/gf2/dys_gf2.tif"   # full scene
    false_point = 'E:/20200210/forest/gf2/point/1.shp' # negative sample points
    true_point = 'E:/20200210/forest/gf2/point/2.shp'  # positive sample points
    patch_path = "E:/20200210/forest/tif_temp/patches2/" # patches cut around each sample point

    im_proj, im_geotrans, im_width, im_height, im_data = read_img(img_path)
    temp8bit_path = "E:/20200210/forest/gf2/dys_gf2_8bit.tif"
    # The scene is 16-bit; stretch it to 8-bit to reduce the LBP workload.
    temp8bit = stretch_n(im_data, 0, 255)
    write_img(temp8bit_path, im_proj, im_geotrans, temp8bit)

    band = 4
    size = 10  # neighbourhood (patch / window) edge length in pixels

    radius = 2  # LBP radius
    n_points = 8 * radius  # LBP sampling points

    false_lbp_hists, false_label = get_data(false_point, temp8bit_path, patch_path, im_proj, band, size, radius, n_points, data_type='false')

    true_lbp_hists, true_label = get_data(true_point, temp8bit_path, patch_path, im_proj, band, size, radius, n_points, data_type='true')

    train_data = np.array(true_lbp_hists + false_lbp_hists)
    train_label = np.array(true_label + false_label)

    # SVM here; any other classifier (e.g. a random forest) could be used.
    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
    svc.fit(train_data, train_label)

    # Test area: per-pixel LBP is slow, so keep this image small.
    test_area = "E:/20200210/forest/gf2/dys_gf2_test.tif"
    ds = gdal.Open(test_area)
    if ds is None:
        print('Could not open ' + test_area)
        sys.exit(1)
    im_width0 = ds.RasterXSize
    im_height0 = ds.RasterYSize
    im_geotrans0 = ds.GetGeoTransform()
    im_proj0 = ds.GetProjection()
    test_data = ds.ReadAsArray(0, 0, im_width0, im_height0)
    # NOTE(review): stretched with the test area's own percentiles, not the
    # full scene's used for training; LBP is largely insensitive to this.
    test_data = stretch_n(test_data, 0, 255)

    # Every size x size window with stride 1 (same size as the training
    # patches). Assumes a multi-band (bands, rows, cols) test image with at
    # least `band` bands. view_as_windows returns
    # (1, n_rows, n_cols, band, size, size); [0] drops only the leading axis
    # (np.squeeze would also drop n_rows / n_cols when either is 1).
    windows = view_as_windows(test_data[:band], (band, size, size))[0]
    n_rows, n_cols = windows.shape[0], windows.shape[1]

    # Classify row by row so only one row of features is in memory at a time.
    re = np.zeros((n_rows, n_cols), dtype=np.uint8)
    for i in range(n_rows):
        row_feats = [
            np.concatenate([lbp(windows[i, j][h], n_points, radius) for h in range(band)])
            for j in range(n_cols)
        ]
        re[i] = svc.predict(np.array(row_feats))

    # Window (i, j) covers pixels [i, i+size) x [j, j+size). Its label belongs
    # to the window centre (i + size//2, j + size//2), so shift the output
    # origin by half a window instead of pinning labels to window corners.
    half = size // 2
    seg_geotrans = list(im_geotrans0)
    seg_geotrans[0] += half * im_geotrans0[1] + half * im_geotrans0[2]
    seg_geotrans[3] += half * im_geotrans0[4] + half * im_geotrans0[5]

    seg_path = "E:/20200210/forest/tif_temp/dys_seg.tif"
    write_img(seg_path, im_proj0, tuple(seg_geotrans), re)
    del ds

改进版本
由于上述的图像切片不但不能处理边界,而且每次切出的都是以待判别像素为左上角的窗口,而不是以该像素为中心的邻域,问题很大。下面的代码给图像上下左右增加了padding,辅助判别边界点。代码的整体逻辑没什么问题了,但是效果不太理想,我推测是特征部分有问题(一个可能的原因:radius=2、n_points=16 时,default 方法的 LBP 编码取值范围是 0~65535,而 lbp() 的直方图只统计 0~255 的编码,大部分编码被丢弃了),有时间再探究,你们就作为进一步探索的参考吧。

# -*- coding: utf-8 -*-
import os, sys, time
import warnings
import gdal
import numpy as np
from numpy import average, dot, linalg
import cv2
import skimage
from osgeo import ogr
from osgeo import gdal
from osgeo import gdal_array as ga
from gdalconst import *
from skimage.feature import local_binary_pattern
from skimage.util.shape import view_as_windows, view_as_blocks
from sklearn.svm import SVC
from numba import jit

def read_img(filename):
    """Read a whole raster into memory.

    Returns (projection, geotransform, width, height, data), where ``data`` is
    what GDAL's ``ReadAsArray`` returns for the full extent: (bands, rows, cols)
    for a multi-band raster, (rows, cols) for a single band.

    Raises IOError if GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # gdal.Open returns None (unless gdal.UseExceptions() is active); fail
        # here with the file name instead of an AttributeError a line later.
        raise IOError('Could not open raster: %s' % filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    del dataset  # release the GDAL handle
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a (rows, cols) or (bands, rows, cols) array to a GeoTIFF.

    The GDAL pixel type follows the array dtype:
    uint8/int8 -> Byte, uint16 -> UInt16, int16 -> Int16,
    uint32 -> UInt32, int32 -> Int32, anything else -> Float32.

    Raises IOError if the GTiff driver cannot create ``filename``
    (e.g. the target directory does not exist).
    """
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # GDAL < 3.7 has no signed 8-bit type
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,  # signed data must not be stored as UInt16
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,  # Float32 cannot hold every int32 exactly
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
        band_stack = im_data
    else:
        im_bands = 1
        im_height, im_width = im_data.shape
        band_stack = im_data[np.newaxis, ...]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError('Could not create raster: %s' % filename)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Band.WriteArray expects 2-D arrays, so always write band by band; this
    # also covers a (1, rows, cols) input.
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(band_stack[i])

    del dataset  # closing the dataset flushes it to disk

def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch ``bands`` into [img_min, img_max] and return uint8.

    The cut points are the ``lower_percent`` / ``higher_percent`` percentiles
    of the whole array (all bands pooled, not per band). Values beyond them are
    clipped to the output range. If both percentiles are equal (constant
    input) the result is all zeros with the input's shape.

    The result is always cast to uint8, so [img_min, img_max] should lie
    inside [0, 255].
    """
    # No numba @jit: this is already vectorised NumPy, and numba's nopython
    # mode (the @jit default since numba 0.59) would reject it rather than
    # speed it up.
    low = np.percentile(bands, lower_percent)
    high = np.percentile(bands, higher_percent)
    if high - low == 0:
        return np.zeros(np.shape(bands), dtype=np.uint8)

    scaled = img_min + (bands - low) * (img_max - img_min) / (high - low)
    scaled = np.clip(scaled, img_min, img_max)
    # Go through float32 (the original output buffer type) so the final
    # truncation to uint8 gives exactly the same values.
    return scaled.astype(np.float32).astype(np.uint8)

def pad_data(data, nei_size):
    """Zero-pad the last two (spatial) axes of ``data`` by ``nei_size // 2`` per side.

    With a (bands, rows, cols) input the result is
    (bands, rows + 2*(nei_size//2), cols + 2*(nei_size//2)). Leading axes are
    left unpadded, so a 2-D (rows, cols) input works too.

    The dtype of ``data`` is preserved. Concatenating float64 zero blocks
    would silently turn uint8 imagery into float64, unlike the uint8 training
    patches, and use 8x the memory.
    """
    # No numba @jit: np.pad already runs in native code.
    half = nei_size // 2
    pad_width = [(0, 0)] * (np.ndim(data) - 2) + [(half, half), (half, half)]
    return np.pad(data, pad_width, mode='constant', constant_values=0)

def getPixels(shp, img, size):
    """Cut a size x size patch out of raster ``img`` around every point in ``shp``.

    Parameters
    ----------
    shp : path to a point ESRI Shapefile; one sample per feature.
    img : path to the raster to sample.
    size : patch edge length in pixels.

    Returns
    -------
    list of [data, geotransform] pairs, one per point whose patch lies fully
    inside the raster. ``data`` is the ``ReadAsArray`` result for the window.
    ``geotransform`` is the raster's geotransform with its origin moved to the
    patch's upper-left pixel corner. The point's own pixel sits at offset
    ``size // 2`` inside the patch. Points whose patch would leave the raster
    are reported and skipped.

    Calls sys.exit(1) if either file cannot be opened.

    NOTE(review): assumes a north-up raster (rotation terms of the
    geotransform are ignored), as the original offset maths did.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    vector_ds = driver.Open(shp, 0)
    if vector_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = vector_ds.GetLayer()
    points = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        points.append((geometry.GetX(), geometry.GetY()))
        feature = layer.GetNextFeature()

    gdal.AllRegister()

    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    # Use this raster's own geotransform (never a global from __main__).
    transform = ds.GetGeoTransform()
    x_origin, pixel_width = transform[0], transform[1]
    y_origin, pixel_height = transform[3], transform[5]
    size = int(size)
    half = size // 2

    values = []
    for x, y in points:
        # Upper-left pixel of the patch = the point's pixel minus half a patch.
        # np.floor (not int(), which truncates toward zero) keeps offsets of
        # points left of / above the raster negative so they are rejected.
        x_off = int(np.floor((x - x_origin) / pixel_width)) - half
        y_off = int(np.floor((y - y_origin) / pixel_height)) - half
        if (x_off < 0 or y_off < 0
                or x_off + size > ds.RasterXSize
                or y_off + size > ds.RasterYSize):
            # ReadAsArray would return None for this window and break callers.
            print('Skipping point (%s, %s): its patch is not fully inside the image' % (x, y))
            continue
        data = ds.ReadAsArray(x_off, y_off, size, size)

        # Georeference the patch on the raster's pixel grid.
        patch_transform = list(transform)
        patch_transform[0] = x_origin + x_off * pixel_width
        patch_transform[3] = y_origin + y_off * pixel_height
        values.append([data, tuple(patch_transform)])
    return values

def lbp(img, n_points, radius, level=256, method='default'):
    """Return the normalised histogram of LBP codes of a single-band image.

    Parameters
    ----------
    img : 2-D array (one band).
    n_points, radius : P and R passed to
        ``skimage.feature.local_binary_pattern``.
    level : number of histogram bins, one per integer code in [0, level).
        Defaults to 256. Pass ``None`` to size the histogram so every
        possible code is counted:
          * 'default' / 'ror': 2 ** n_points
          * 'uniform':          n_points + 2
          * 'nri_uniform':      n_points * (n_points - 1) + 3
        ('var' yields continuous values and needs an explicit ``level``.)
    method : LBP variant, forwarded to ``local_binary_pattern``.

    Note: the bin count depends on the LBP code range, not on the image bit
    depth. With method='default' the codes reach 2 ** n_points - 1, so the
    default 256 bins only cover all codes when n_points <= 8. Codes outside
    [0, level) are not counted correctly, and a warning is issued when that
    happens.
    """
    codes = local_binary_pattern(img, n_points, radius, method)

    if level is None:
        if method in ('default', 'ror'):
            n_bins = 2 ** n_points
        elif method == 'uniform':
            n_bins = n_points + 2
        elif method == 'nri_uniform':
            n_bins = n_points * (n_points - 1) + 3
        else:
            raise ValueError("level=None is not supported for method %r" % method)
    else:
        n_bins = int(level)

    if codes.size and codes.max() >= n_bins:
        warnings.warn('LBP codes up to %d do not fit in %d histogram bins; '
                      'pass level=None or a larger level' % (codes.max(), n_bins))

    hist, _ = np.histogram(codes, density=True, bins=n_bins, range=(0, n_bins))
    return hist

def get_data(data_point, img_path, patch_path, im_proj, band, size, radius, n_points, data_type='false'):
    """Build LBP feature vectors and labels for one class of sample points.

    For every point in the shapefile ``data_point``, a size x size patch is cut
    from ``img_path`` (via getPixels). The patch is saved for inspection as
    ``<patch_path>/<data_type>_<k>.tif``, where k is the patch's position in
    the getPixels result. Its feature vector is the concatenation of the LBP
    histograms of its first ``band`` bands.

    Returns (features, labels): two parallel lists. The label is 0 when
    ``data_type == 'false'`` and 1 otherwise.
    """
    data_patches = getPixels(data_point, img_path, size)
    label = 0 if data_type == 'false' else 1

    # GDAL cannot create a file in a missing directory; make sure it exists.
    if patch_path and not os.path.isdir(patch_path):
        os.makedirs(patch_path)

    data_lbp_hists = []
    data_label = []
    for count, (patch_data, patch_transform) in enumerate(data_patches):
        if patch_data is None:
            # ReadAsArray yields None when the window falls outside the image.
            continue
        patch_p = os.path.join(patch_path, data_type + '_' + str(count) + '.tif')
        write_img(patch_p, im_proj, patch_transform, patch_data)

        hists = [lbp(patch_data[i, ...], n_points, radius) for i in range(band)]
        data_lbp_hists.append(np.concatenate(hists, axis=0))
        data_label.append(label)
    return data_lbp_hists, data_label

if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    img_path = "E:/20200210/forest/gf2/dys_gf2.tif"
    false_point = 'E:/20200210/forest/gf2/point/1.shp'
    true_point = 'E:/20200210/forest/gf2/point/2.shp'
    patch_path = "E:/20200210/forest/tif_temp/patches2/"

    im_proj, im_geotrans, im_width, im_height, im_data = read_img(img_path)
    temp8bit_path = "E:/20200210/forest/gf2/dys_gf2_8bit.tif"
    # The scene is 16-bit; stretch it to 8-bit to reduce the LBP workload.
    temp8bit = stretch_n(im_data, 0, 255)
    write_img(temp8bit_path, im_proj, im_geotrans, temp8bit)

    band = 4
    size = 10  # neighbourhood (patch / window) edge length in pixels

    radius = 2  # LBP radius
    n_points = 8 * radius  # LBP sampling points

    false_lbp_hists, false_label = get_data(false_point, temp8bit_path, patch_path, im_proj, band, size, radius, n_points, data_type='false')

    true_lbp_hists, true_label = get_data(true_point, temp8bit_path, patch_path, im_proj, band, size, radius, n_points, data_type='true')

    train_data = np.array(true_lbp_hists + false_lbp_hists)
    train_label = np.array(true_label + false_label)

    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
    svc.fit(train_data, train_label)

    test_area = "E:/20200210/forest/gf2/dys_gf2_test.tif"
    ds = gdal.Open(test_area)
    if ds is None:
        print('Could not open ' + test_area)
        sys.exit(1)
    im_width0 = ds.RasterXSize
    im_height0 = ds.RasterYSize
    im_geotrans0 = ds.GetGeoTransform()
    im_proj0 = ds.GetProjection()
    test_data = ds.ReadAsArray(0, 0, im_width0, im_height0)
    test_data = stretch_n(test_data, 0, 255)

    # Pad by size // 2 on every side so border pixels also get a full window.
    test_data = pad_data(test_data, size)

    # Pixel (i, j) of the test area sits at (i + size//2, j + size//2) in the
    # padded array. Its window starts size//2 pixels before it and is exactly
    # `size` wide, i.e. padded[:, i:i+size, j:j+size]. That is the same
    # footprint as a training patch (an i-size//2 : i+size//2+1 slice would be
    # size+1 wide for even sizes).
    # Classify row by row so only one row of features is in memory at a time.
    re = np.zeros((im_height0, im_width0), dtype=np.uint8)
    for i in range(im_height0):
        row_feats = []
        for j in range(im_width0):
            window = test_data[:, i:i + size, j:j + size]
            row_feats.append(np.concatenate(
                [lbp(window[h], n_points, radius) for h in range(band)]))
        re[i] = svc.predict(np.array(row_feats))

    seg_path = "E:/20200210/forest/tif_temp/dys_seg.tif"
    write_img(seg_path, im_proj0, im_geotrans0, re)
    del ds

  • 5
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

如雾如电

随缘

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值