#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include <gdal_priv.h>
using namespace std;
class Kmeans
{
private:
GDALDataset* m_ds;//待分类的影像
int m_xSize;
int m_ySize;
int m_bandCount;
int m_k;//类别数
int m_n;//最大迭代次数
double m_t;//迭代阈值
vector<vector<double>> m_centers;//当前聚类中心
public:
Kmeans(string fileName,int k,int n,double t)
{
GDALAllRegister();
m_ds = (GDALDataset *)GDALOpen(fileName.c_str(), GDALAccess::GA_ReadOnly);
if (m_ds == NULL)
{
cout << "打开数据失败!" << endl;
return;
}
m_xSize = m_ds->GetRasterXSize();
m_ySize = m_ds->GetRasterYSize();
m_bandCount = m_ds->GetRasterCount();
m_k = k;
m_n = n;
m_t = t;
initCenter();
}
~Kmeans()
{
GDALClose(m_ds);
}
void initCenter()
{
//注意在析构函数中释放内存
GDALDataset::Bands bands = m_ds->GetBands();
srand((unsigned int)time(NULL));
for (int i = 0;i < m_k;i++)
{
vector<double>center;
center.resize(m_bandCount);
int col = rand() % m_xSize;
int row = rand() % m_ySize;
for (int j = 0;j < m_bandCount;j++)
{
GDALRasterBand * band = bands[j];
double value[1];
band->RasterIO(GDALRWFlag::GF_Read, col, row, 1, 1, value, 1, 1, GDALDataType::GDT_Float64, 0, 0, NULL);
center[j] = value[0];
}
m_centers.push_back(center);
}
return;
}
void getdist(const vector<double>pixValues,multimap<double,int>&dists)
{
for (int i = 0;i < m_k;i++)
{
double dist = 0;
vector<double>center = m_centers[i];
for (int j = 0;j < m_bandCount;j++)
{
dist += pow((pixValues[j] - center[j]), 2);
}
dist = sqrt(dist);
dists.insert(pair<double, int>(dist, i));
}
}
void execute(string outName)
{
GDALDataset::Bands bands = m_ds->GetBands();
vector<double *>bandsValue;
for (int i = 0;i < m_bandCount;i++)
{
//注意释放内存
double * bandValue= new double[m_xSize*m_ySize];
bands[i]->RasterIO(GDALRWFlag::GF_Read, 0, 0,m_xSize,m_ySize, bandValue, m_xSize, m_ySize, GDALDataType::GDT_Float64, 0, 0, NULL);
bandsValue.push_back(bandValue);
}
int * categories = new int[m_xSize*m_ySize];//注意释放内存
int N = 0;//当前迭代次数
while (true)
{
vector<int>catesPixNum;
catesPixNum.resize(m_k, 0);
vector<vector<double>>centers;
for (int i = 0;i < m_k;i++)
{
vector<double>center;
center.resize(m_bandCount, 0);
centers.push_back(center);
}
for (int i = 0;i < m_ySize;i++)
{
for (int j = 0;j < m_xSize;j++)
{
int index;
index = i * m_xSize + j;
vector<double>pixValues;
pixValues.resize(m_bandCount);
for (int k = 0;k < m_bandCount;k++)
{
double pixValue = bandsValue[k][index];
pixValues[k] = pixValue;
}
multimap<double, int>dists;
getdist(pixValues, dists);
int cate=dists.begin()->second;
categories[index] = cate;
catesPixNum[cate]++;
for (int k = 0;k < m_bandCount;k++)
{
centers[cate][k] += pixValues[k];
}
}
}
for (int i = 0;i < m_k;i++)
{
for (int j = 0;j < m_bandCount;j++)
{
centers[i][j] /= catesPixNum[i];
}
}
N++;
if (N > m_n)
{
//保存最新的聚类中心
m_centers = centers;
break;
}
//比较两次迭代聚类中心的差值
double maxdist;
for (int i = 0;i < m_k;i++)
{
double dist=0;
for (int j = 0;j < m_bandCount;j++)
{
dist += pow(centers[i][j] - m_centers[i][j], 2);
}
dist = sqrt(dist);
maxdist = i == 0 ? dist : MAX(dist, maxdist);
}
if (maxdist < m_t)
{
m_centers = centers;
break;
}
m_centers = centers;
}
createOut(outName, categories);
}
void createOut(string outName,int * categories)
{
GDALAllRegister();
GDALDriver * driver = GetGDALDriverManager()->GetDriverByName("GTiff");
GDALDataset * outds = driver->Create(outName.c_str(), m_xSize, m_ySize, 1, GDALDataType::GDT_Byte,NULL);
GDALRasterBand * outband = outds->GetRasterBand(1);
outband->RasterIO(GDALRWFlag::GF_Write, 0, 0, m_xSize, m_ySize, categories, m_xSize, m_ySize, GDALDataType::GDT_Byte, 0, 0, NULL);
GDALClose(outds);
return;
}
};
int main()
{
Kmeans kmeans("./kmeans/can.tiff", 3, 100, 0.001);
kmeans.execute("./kmeans/can_kmeans.tif");
system("pause");
return 0;
}
12-08
867
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)