目录
多光谱遥感分类(一):数据集制作;本文。
多光谱遥感分类(二):VGG微调
多光谱遥感分类(三):CNN提取特征+RF分类
多光谱遥感分类(四):使用GLCM+RF
多光谱遥感分类(五):代码优化+自定义模型
也可参考:遥感分类的一种采样方法 。
描述
代码源于很久以前练手的一个Demo,时间长了许多魔改版的都不见了,目前只剩下此简陋版本。读者如有相关需求,可根据只言片语断章取义。由于代码混乱基础,不再上传GitHub。
所用数据为多光谱遥感影像(.tif,由arcgis导出RGB彩色图像),抠图所得点文件(.shp)(由抠图面文件使用arcgis随机生成点生成,至少有一个字段,即标签)。
工具篇
根据点shp文件(样本点集合),对栅格图像的3、2、1波段切图,并保存在相应标签下的文件夹,注意shp、tif的投影坐标一致
from osgeo import gdal
import numpy as np
import shapefile
import cv2
import os
size=64
bands=3
dataset = gdal.Open(r"E:\数据2\test_tif_peizhun_subset_proj_.tif")
rer=shapefile.Reader(r'E:\shps\test.shp')
def __createDir(path):
if not os.path.exists(path):
try:
os.makedirs(path)
except:
print("创建文件夹失败")
exit(1)
def __getACell(geo,pos):
try:
xoffset = int((pos[0] - geo[0]) / geo[1])
yoffset = int((pos[1] - geo[3]) / geo[5])
print("pixels: x= %d,y= %d" % (xoffset, yoffset))
output = []
for i in [3,2,1]:
band = dataset.GetRasterBand(i)
if (int(xoffset - size / 2) < 0 or int(yoffset - size / 2) < 0
or int(xoffset - size / 2) + size > dataset.RasterXSize
or int(yoffset - size / 2) + size > dataset.RasterYSize):
return None
t = band.ReadAsArray(int(xoffset - size / 2), int(yoffset - size / 2), size, size)
output.append(t)
img = np.moveaxis(np.array(output, dtype=np.uint8), 0, 2)
except:
return None
return img
def getShpDataForNum():
labels=[i[0] for i in rer.records()]
for i in set(labels):
__createDir(os.path.join("data/org/"+str(i)))
for i in range(rer.numRecords):#rer.numRecords
print("deal %d: " % (i+1))
sr=rer.shape(i)
img=__getACell(dataset.GetGeoTransform(), sr.points[0])
if(img is None):
print("the area of points %d is out range." %(i))
continue
label=labels[i]
cv2.imwrite("data/org/%s/%s.%d.jpg" % (label, label, i), img)
print("data/org/%s/%s.%d.jpg" % (label, label, i))
print("deal finish,to numpy array.")
getShpDataForNum()
如下,将上述所得文件拆分为测试集和训练集。
import os
import shutil
import random
def createDir(path):
if not os.path.exists(path):
try:
os.makedirs(path)
except:
print("创建文件夹失败")
exit(1)
createDir("data/train/")
createDir("data/test/")
dir='data/org/'
for dir_item in os.listdir(dir):
createDir("data/train/" + dir_item)
createDir("data/test/"+dir_item)
org_data=os.listdir(dir+dir_item+"/")
random.shuffle(org_data)
num=int(len(org_data)*0.25)
print(dir + dir_item + " start.")
for d in org_data[:-num]:
shutil.copyfile(dir + dir_item + "/" + d, "data/train/" + dir_item + "/" + d)
for d in org_data[-num:]:
shutil.copyfile(dir+dir_item+"/"+d,"data/test/"+dir_item+"/"+d)
print(dir+dir_item+" finished")
以下显示制定文件夹下的子文件夹中的文件数目直方图。
import os
import seaborn as sns
import matplotlib.pyplot as plt
def show(path,title):
d=os.listdir(path)
d_len=[len(os.listdir(os.path.join(path,i))) for i in d]
# print(d,d_len)
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
sns.barplot(d,d_len,)
plt.xlabel("样本类型")
plt.ylabel("数量")
plt.title(title)
for i in range(len(d_len)):
plt.text(i,d_len[i]+2,"%d" % d_len[i],ha="center",va="bottom")
plt.show()
show(r"data/1_train","训练集源数据采样集")
由于其他原因,数据更改。如下为使用shp样本点对应的像素坐标所采图集。此时分为train pos.txt和test pos.txt诸如此类。
from osgeo import gdal
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import cv2,shutil
class Tiff:
def createDir(self, path):
if not os.path.exists(path):
try:
os.makedirs(path)
except:
print("创建文件夹失败")
exit(1)
def __init__(self, pos_src,other_feather,contact_src,size=128,bands=[3,2,1],tif_src=r"D:/lishihang/jiangxia_simple/ZY3_GS_jiangxia1.tif"):
self.dataset = gdal.Open(tif_src) # tif数据
self.size = size # 采样窗口大小
self.bands=bands
self.contact_pos_feather(pos_src, other_feather,contact_src)
self.fea =pd.read_csv(contact_src, header=None)
# shutil.rmtree("data/temp.txt")
def get_cell(self, pos_x, pos_y):
try:
output = []
for i in self.bands:
band = self.dataset.GetRasterBand(i)
t = band.ReadAsArray(int(pos_x - self.size / 2), int(pos_y - self.size / 2), self.size, self.size)
output.append(t)
img2 = np.moveaxis(np.array(output, dtype=np.uint8), 0, 2)
# print(img2.shape)
# self.showImg(img2)
except:
return None
return img2
def get_cells(self,target_src):
fea_len=len(self.fea)
self.createDir(target_src)
for label in set(self.fea.iloc[:,-2]):
self.createDir("%s/%s" % (target_src,label))
print("fea length: %d" % fea_len)
for i in range(fea_len):
temp=self.fea.iloc[i,:].values
img = self.get_cell(temp[1], temp[0])
if img is None:
continue
cv2.imwrite("%s/%s/%s.%d.jpg" % (target_src,temp[-2], temp[-2], i), img)
if(i%1000==0):
print("%d/%d hava finsh save." % (i,fea_len))
def contact_pos_feather(self,pos_src, other_feather,target):
if os.path.exists(target):
print("文件已存在")
return
pos = pd.read_csv(pos_src, header=None, sep=' ')
feather = pd.read_csv(other_feather, header=None, sep='\t')
# fea = pd.concat([pos, feather], axis=1).sample(frac=1).reset_index(drop=True)
fea = pd.concat([pos, feather], axis=1)
print("pos Length=%d,feather Length=%d,fea Length=%d" % (len(pos), len(feather), len(fea)))
# print(type(fea))
del feather
del pos
fea = pd.DataFrame(fea)
fea.to_csv(target, index=None, header=None)
if __name__ == '__main__':
tiff=Tiff(r"D:/tr_sample_1.txt",r"D:/train1.txt",r"tr_1.txt")
# tiff=Tiff(r"D:/te_sample_1.txt",r"D:/test1.txt",r"te_1.txt")
# tiff.get_cells("data/1_test")