We have already done quite a lot of pixel-level classification in earlier posts; here we go one step deeper, this time using LBP (Local Binary Pattern) features. The idea: manually pick sample points, cut out each point's neighborhood, extract LBP features from it and train a model; then take the neighborhood of every pixel, extract its LBP features, and let the trained model judge each neighborhood's features to decide the class of that pixel.
Two versions are included below, an initial one and an improved one. The initial version contains some notes I wanted to keep, so it stays here as well; feel free to scroll straight down to the final version.
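Before the full scripts (they follow right below), here is a minimal standalone sketch of the feature step described above: take one neighborhood patch of shape (bands, size, size), compute an LBP histogram per band, and concatenate the histograms into a single feature vector. The helper name patch_feature and the demo array are made up for illustration; the 16-point, radius-2, 256-bin setup matches the scripts.

import numpy as np
from skimage.feature import local_binary_pattern

def patch_feature(patch, n_points=16, radius=2, n_bins=256):
    # patch: (bands, size, size) array cut around one sample point / pixel
    hists = []
    for band_img in patch:
        codes = local_binary_pattern(band_img, n_points, radius, method='default')
        hist, _ = np.histogram(codes, bins=n_bins, range=(0, n_bins), density=True)
        hists.append(hist)
    # one feature vector per patch: bands * n_bins values (4 * 256 = 1024 here)
    return np.concatenate(hists)

demo = np.random.randint(0, 256, size=(4, 10, 10)).astype(np.uint8)  # made-up 4-band 10x10 patch
print(patch_feature(demo).shape)  # (1024,)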
# -*- coding: utf-8 -*-
import os, sys
import warnings
import numpy as np
from osgeo import ogr
from osgeo import gdal
from osgeo.gdalconst import GA_ReadOnly
from skimage.feature import local_binary_pattern
from skimage.util.shape import view_as_windows  # slides a window over the image; note it cannot cover the right/bottom borders
from sklearn.svm import SVC
from numba import jit
def read_img(filename):
dataset=gdal.Open(filename)
im_width = dataset.RasterXSize
im_height = dataset.RasterYSize
im_geotrans = dataset.GetGeoTransform()
im_proj = dataset.GetProjection()
im_data = dataset.ReadAsArray(0,0,im_width,im_height)
del dataset
return im_proj,im_geotrans,im_width, im_height,im_data
def write_img(filename, im_proj, im_geotrans, im_data):
if 'int8' in im_data.dtype.name:
datatype = gdal.GDT_Byte
elif 'int16' in im_data.dtype.name:
datatype = gdal.GDT_UInt16
else:
datatype = gdal.GDT_Float32
if len(im_data.shape) == 3:
im_bands, im_height, im_width = im_data.shape
else:
im_bands, (im_height, im_width) = 1,im_data.shape
driver = gdal.GetDriverByName("GTiff")
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
dataset.SetGeoTransform(im_geotrans)
dataset.SetProjection(im_proj)
if im_bands == 1:
dataset.GetRasterBand(1).WriteArray(im_data)
else:
for i in range(im_bands):
dataset.GetRasterBand(i+1).WriteArray(im_data[i])
del dataset
@jit
def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
out = np.zeros_like(bands).astype(np.float32)
# a = 0
# b = 65535
a = img_min
b = img_max
c = np.percentile(bands[:, :], lower_percent)
d = np.percentile(bands[:, :], higher_percent)
    x = d - c
    if x == 0:
        t = np.zeros_like(bands, dtype=np.float32)  # constant band: fall back to zeros so the clipping below still works
    else:
        t = a + (bands[:, :] - c) * (b - a) / (d - c)
    t[t < a] = a
    t[t > b] = b
out[:, :] = t
out = np.uint8(out)
return out
def getPixels(shp, img, size):
driver = ogr.GetDriverByName('ESRI Shapefile')
ds = driver.Open(shp, 0)
if ds is None:
print('Could not open ' + shp)
sys.exit(1)
layer = ds.GetLayer()
xValues = []
yValues = []
feature = layer.GetNextFeature()
while feature:
geometry = feature.GetGeometryRef()
x = geometry.GetX()
y = geometry.GetY()
xValues.append(x)
yValues.append(y)
feature = layer.GetNextFeature()
gdal.AllRegister()
ds = gdal.Open(img, GA_ReadOnly)
if ds is None:
print('Could not open image')
sys.exit(1)
rows = ds.RasterYSize
cols = ds.RasterXSize
bands = ds.RasterCount
transform = ds.GetGeoTransform()
xOrigin = transform[0]
yOrigin = transform[3]
pixelWidth = transform[1]
pixelHeight = transform[5]
values = []
for i in range(len(xValues)):
x = xValues[i]
y = yValues[i]
new_transform=list(transform)
#print new_transform
        new_transform[0] = x - transform[1] * int(size) / 2.0  # use the image's own geotransform rather than the global im_geotrans
        new_transform[3] = y - transform[5] * int(size) / 2.0
new_transfor_mtuple=tuple(new_transform)
x1=x-int(size)/2*transform[1]
y1=y-int(size)/2*transform[5]
x2=x+int(size)/2*transform[1]
y2=y-int(size)/2*transform[5]
x3=x-int(size)/2*transform[1]
y3=y+int(size)/2*transform[5]
x4=x+int(size)/2*transform[1]
y4=y+int(size)/2*transform[5]
Xpix=(x1-transform[0])/transform[1]
Ypix=(new_transform[3]-transform[3])/transform[5]
data = ds.ReadAsArray(int(Xpix),int(Ypix),int(size),int(size))
values.append([data, new_transfor_mtuple])
return values
def lbp(img, n_points, radius, level=256, method='default'):  # note: level=256 matches the 8-bit data range
lbp = local_binary_pattern(img, n_points, radius, method)
# n_bins = int(lbp.max() + 1)
n_bins = level
hist, _ = np.histogram(lbp, density=True, bins=n_bins, range=(0, n_bins))
return hist
def get_data(data_point, img_path, patch_path, im_proj, band, size, radius, n_points, data_type='false'):
data_patches = getPixels(data_point, img_path, size)
data_lbp_hists = []
data_label = []
count = 0
for patch in data_patches:
patch_p = os.path.join(patch_path, data_type + '_' + str(count) + '.tif')
write_img(patch_p, im_proj, patch[1], patch[0])
count += 1
temp = []
for i in range(band):
hist = lbp(patch[0][i,...], n_points, radius)
# print(hist.shape)
temp.append(hist)
        temp_arr = np.concatenate(temp, axis=0)
data_lbp_hists.append(temp_arr)
if data_type == 'false':
data_label.append(0)
else:
data_label.append(1)
return data_lbp_hists, data_label
if __name__ == "__main__":
warnings.filterwarnings("ignore")
img_path = "E:/20200210/forest/gf2/dys_gf2.tif" #大图
false_point = 'E:/20200210/forest/gf2/point/1.shp' #负样本点
true_point = 'E:/20200210/forest/gf2/point/2.shp' #正样本点
patch_path = "E:/20200210/forest/tif_temp/patches2/" #样本点对应的图像块,以样本点为中心裁剪
im_proj,im_geotrans,im_width, im_height,im_data = read_img(img_path)
temp8bit_path = "E:/20200210/forest/gf2/dys_gf2_8bit.tif"
temp8bit = stretch_n(im_data,0,255)
    write_img(temp8bit_path, im_proj, im_geotrans, temp8bit)  # the source image is 16-bit; convert to 8-bit to cut down the computation
band = 4
    size = 10  # neighborhood (patch) size
    radius = 2  # LBP radius
    n_points = 8 * radius  # number of LBP sampling points
false_lbp_hists, false_label = get_data(false_point,temp8bit_path,patch_path,im_proj,band,size,radius,n_points, data_type='false')
true_lbp_hists, true_label = get_data(true_point,temp8bit_path,patch_path,im_proj,band,size,radius,n_points,data_type='true')
train_data = np.array(true_lbp_hists + false_lbp_hists)
train_label = np.array(true_label + false_label)
    svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)  # SVM is used here; other classifiers, e.g. the random forest mentioned before, would also work
svc.fit(train_data, train_label)
test_area = "E:/20200210/forest/gf2/dys_gf2_test.tif" #测试区域,有点慢,用小点的。
ds = gdal.Open(test_area)
im_width0 = ds.RasterXSize
im_height0 = ds.RasterYSize
im_geotrans0 = ds.GetGeoTransform()
im_proj0 = ds.GetProjection()
test_data = ds.ReadAsArray(0,0,im_width0,im_height0)
test_data = stretch_n(test_data, 0, 255)
# xlength = int((im_width+0.0)/size)
# ylength = int((im_height+0.0)/size)
xlength = im_width0 - size
ylength = im_height0 - size
    window_shape = (band, size, size)  # window shape; must match the training patch size, i.e. `size` above
    windows = view_as_windows(test_data, window_shape)  # slices the image with stride 1, so roughly the last `size` pixels along the right and bottom edges cannot be processed; the function itself is quite fast
windows = np.squeeze(windows)
all_arr = []
    for i in range(ylength):
        for j in range(xlength):
            temp = []
            for h in range(band):
                hist = lbp(windows[i, j][h, ...], n_points, radius)
                temp.append(hist)
            temp_arr = np.concatenate(temp, axis=0)
            all_arr.append(temp_arr)
    predict = svc.predict(np.array(all_arr))
    re = predict.reshape((im_height0-size, im_width0-size))  # the right/bottom `size` pixels are not processed, so the result is smaller; subtract them before reshaping
seg_path = "E:/20200210/forest/tif_temp/dys_seg.tif"
write_img(seg_path, im_proj0, im_geotrans0, re)
del ds
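Before moving on to the improved version, a quick shape check makes the border limitation concrete (the image size here is made up): with a 10x10 window and stride 1, view_as_windows yields H-10+1 by W-10+1 window positions, and the script assigns each prediction to the window's top-left pixel, which is why the right and bottom edges end up unclassified and why the windows are not centered on the pixels they describe.

import numpy as np
from skimage.util.shape import view_as_windows

img = np.zeros((4, 100, 120), dtype=np.uint8)         # 4 bands, 100 x 120 pixels (made-up size)
wins = np.squeeze(view_as_windows(img, (4, 10, 10)))
print(wins.shape)                                      # (91, 111, 4, 10, 10)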
Improved version
The slicing above not only skips the borders, but each window is also anchored at its top-left corner rather than centered on the pixel being classified, which is a serious problem. The code below pads the image on all four sides so that border pixels can also be judged. The overall logic is sound now, but the results are still not great; I suspect the feature extraction part is the problem and will dig into it when I have time. Treat it as a reference for further exploration.
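Incidentally, the zero padding that pad_data() builds below with two np.concatenate calls can be written more compactly with np.pad; this is only an equivalent sketch, the script itself keeps the concatenate version.

import numpy as np

def pad_data_np(data, nei_size):
    # zero-pad the two spatial axes by nei_size // 2 on each side; the band axis is left untouched
    p = nei_size // 2
    return np.pad(data, ((0, 0), (p, p), (p, p)), mode='constant', constant_values=0)

# e.g. a (4, H, W) array becomes (4, H + 2*p, W + 2*p)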
# -*- coding: utf-8 -*-
import os, sys
import warnings
import numpy as np
from osgeo import ogr
from osgeo import gdal
from osgeo.gdalconst import GA_ReadOnly
from skimage.feature import local_binary_pattern
from sklearn.svm import SVC
from numba import jit
def read_img(filename):
dataset=gdal.Open(filename)
im_width = dataset.RasterXSize
im_height = dataset.RasterYSize
im_geotrans = dataset.GetGeoTransform()
im_proj = dataset.GetProjection()
im_data = dataset.ReadAsArray(0,0,im_width,im_height)
del dataset
return im_proj,im_geotrans,im_width, im_height,im_data
def write_img(filename, im_proj, im_geotrans, im_data):
if 'int8' in im_data.dtype.name:
datatype = gdal.GDT_Byte
elif 'int16' in im_data.dtype.name:
datatype = gdal.GDT_UInt16
else:
datatype = gdal.GDT_Float32
if len(im_data.shape) == 3:
im_bands, im_height, im_width = im_data.shape
else:
im_bands, (im_height, im_width) = 1,im_data.shape
driver = gdal.GetDriverByName("GTiff")
dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
dataset.SetGeoTransform(im_geotrans)
dataset.SetProjection(im_proj)
if im_bands == 1:
dataset.GetRasterBand(1).WriteArray(im_data)
else:
for i in range(im_bands):
dataset.GetRasterBand(i+1).WriteArray(im_data[i])
del dataset
@jit
def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
out = np.zeros_like(bands).astype(np.float32)
# a = 0
# b = 65535
a = img_min
b = img_max
c = np.percentile(bands[:, :], lower_percent)
d = np.percentile(bands[:, :], higher_percent)
    x = d - c
    if x == 0:
        t = np.zeros_like(bands, dtype=np.float32)  # constant band: fall back to zeros so the clipping below still works
    else:
        t = a + (bands[:, :] - c) * (b - a) / (d - c)
    t[t < a] = a
    t[t > b] = b
out[:, :] = t
out = np.uint8(out)
return out
@jit
def pad_data(data,nei_size):
c,m,n = data.shape
t1 = np.zeros([c,nei_size//2,n])
data = np.concatenate((t1,data,t1),axis=1)
c,m,n = data.shape
t2 = np.zeros([c,m,nei_size//2])
data = np.concatenate((t2,data,t2),axis=2)
return data
def getPixels(shp, img, size):
driver = ogr.GetDriverByName('ESRI Shapefile')
ds = driver.Open(shp, 0)
if ds is None:
print('Could not open ' + shp)
sys.exit(1)
layer = ds.GetLayer()
xValues = []
yValues = []
feature = layer.GetNextFeature()
while feature:
geometry = feature.GetGeometryRef()
x = geometry.GetX()
y = geometry.GetY()
xValues.append(x)
yValues.append(y)
feature = layer.GetNextFeature()
gdal.AllRegister()
ds = gdal.Open(img, GA_ReadOnly)
if ds is None:
print('Could not open image')
sys.exit(1)
rows = ds.RasterYSize
cols = ds.RasterXSize
bands = ds.RasterCount
transform = ds.GetGeoTransform()
xOrigin = transform[0]
yOrigin = transform[3]
pixelWidth = transform[1]
pixelHeight = transform[5]
values = []
for i in range(len(xValues)):
x = xValues[i]
y = yValues[i]
new_transform=list(transform)
#print new_transform
        new_transform[0] = x - transform[1] * int(size) / 2.0  # use the image's own geotransform rather than the global im_geotrans
        new_transform[3] = y - transform[5] * int(size) / 2.0
new_transfor_mtuple=tuple(new_transform)
x1=x-int(size)/2*transform[1]
y1=y-int(size)/2*transform[5]
x2=x+int(size)/2*transform[1]
y2=y-int(size)/2*transform[5]
x3=x-int(size)/2*transform[1]
y3=y+int(size)/2*transform[5]
x4=x+int(size)/2*transform[1]
y4=y+int(size)/2*transform[5]
Xpix=(x1-transform[0])/transform[1]
#Xpix=(new_transform[0]-transform[0])
Ypix=(new_transform[3]-transform[3])/transform[5]
#Ypix=abs(new_transform[3]-transform[3])
data = ds.ReadAsArray(int(Xpix),int(Ypix),int(size),int(size))
values.append([data, new_transfor_mtuple])
return values
def lbp(img, n_points, radius, level=256, method='default'):
lbp = local_binary_pattern(img, n_points, radius, method)
# n_bins = int(lbp.max() + 1)
n_bins = level
hist, _ = np.histogram(lbp, density=True, bins=n_bins, range=(0, n_bins))
return hist
def get_data(data_point, img_path, patch_path, im_proj, band, size, radius, n_points, data_type='false'):
data_patches = getPixels(data_point, img_path, size)
data_lbp_hists = []
data_label = []
count = 0
for patch in data_patches:
patch_p = os.path.join(patch_path, data_type + '_' + str(count) + '.tif')
write_img(patch_p, im_proj, patch[1], patch[0])
count += 1
temp = []
for i in range(band):
hist = lbp(patch[0][i,...], n_points, radius)
# print(hist.shape)
temp.append(hist)
        temp_arr = np.concatenate(temp, axis=0)
data_lbp_hists.append(temp_arr)
if data_type == 'false':
data_label.append(0)
else:
data_label.append(1)
return data_lbp_hists, data_label
if __name__ == "__main__":
warnings.filterwarnings("ignore")
img_path = "E:/20200210/forest/gf2/dys_gf2.tif"
false_point = 'E:/20200210/forest/gf2/point/1.shp'
true_point = 'E:/20200210/forest/gf2/point/2.shp'
patch_path = "E:/20200210/forest/tif_temp/patches2/"
im_proj,im_geotrans,im_width, im_height,im_data = read_img(img_path)
temp8bit_path = "E:/20200210/forest/gf2/dys_gf2_8bit.tif"
temp8bit = stretch_n(im_data,0,255)
write_img(temp8bit_path, im_proj, im_geotrans, temp8bit)
band = 4
size = 10
radius = 2
n_points = 8 * radius
false_lbp_hists, false_label = get_data(false_point,temp8bit_path,patch_path,im_proj,band,size,radius,n_points, data_type='false')
true_lbp_hists, true_label = get_data(true_point,temp8bit_path,patch_path,im_proj,band,size,radius,n_points,data_type='true')
train_data = np.array(true_lbp_hists + false_lbp_hists)
train_label = np.array(true_label + false_label)
svc = SVC(C=0.8, kernel='rbf', gamma='scale', cache_size=1000)
svc.fit(train_data, train_label)
test_area = "E:/20200210/forest/gf2/dys_gf2_test.tif"
ds = gdal.Open(test_area)
im_width0 = ds.RasterXSize
im_height0 = ds.RasterYSize
im_geotrans0 = ds.GetGeoTransform()
im_proj0 = ds.GetProjection()
test_data = ds.ReadAsArray(0,0,im_width0,im_height0)
test_data = stretch_n(test_data, 0, 255)
test_data = pad_data(test_data,size)
all_arr = []
    for i in range(size//2, im_height0 + size//2):
        for j in range(size//2, im_width0 + size//2):
            windows = test_data[:, i-size//2:i+size//2+1, j-size//2:j+size//2+1]
            temp = []
            for h in range(band):
                hist = lbp(windows[h, ...], n_points, radius)
                temp.append(hist)
            temp_arr = np.concatenate(temp, axis=0)
            all_arr.append(temp_arr)
predict = svc.predict(np.array(all_arr))
re = predict.reshape((im_height0, im_width0))
seg_path = "E:/20200210/forest/tif_temp/dys_seg.tif"
write_img(seg_path, im_proj0, im_geotrans0, re)
del ds
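Regarding the suspicion above that the feature step is the weak point: one thing that may be worth checking (my assumption, not verified on this data) is that with method='default' and n_points = 16 the LBP codes range from 0 to 2**16 - 1, so a 256-bin histogram over (0, 256) only counts a small fraction of the possible codes. A 'uniform' LBP histogram, whose codes stay within 0 .. n_points + 1, would be one alternative to try; lbp_uniform_hist below is a hypothetical replacement for the lbp() helper, not something tested here.

import numpy as np
from skimage.feature import local_binary_pattern

def lbp_uniform_hist(img, n_points, radius):
    # 'uniform' LBP codes lie in 0 .. n_points + 1, so n_points + 2 bins cover every code
    codes = local_binary_pattern(img, n_points, radius, method='uniform')
    hist, _ = np.histogram(codes, bins=n_points + 2, range=(0, n_points + 2), density=True)
    return hist

# could be swapped in for lbp() when building both the training and the prediction features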