// The theory is omitted here; the k-means implementation follows:
#include <cmath>
#include <cstddef>
#include <ctime>
#include <iostream>
#include <limits>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std;
// Largest finite single-precision value; used as the "no distance seen yet"
// starting point when searching for a minimum distance. Declared as a typed
// constant rather than a preprocessor macro so it obeys normal scoping rules.
constexpr float FLT_MAXX = 3.402823466e+38F;
/// Two-dimensional k-means clustering.
///
/// Points and centroids ("stations") are stored as std::vector<float>; only
/// the first two coordinates (x, y) of each are read, so every point must
/// hold at least two values.
class Solution {
public:
    /// Picks the initial centroids: copies of the first k points.
    ///
    /// @param positions input points.
    /// @param k         requested number of centroids; values below zero are
    ///                  treated as zero and values above positions.size()
    ///                  are clamped to positions.size().
    /// @return the chosen centroids, in input order.
    std::vector<std::vector<float>> init_pos(const std::vector<std::vector<float>>& positions, int k) const {
        const std::size_t count = clamp_count(k, positions.size());
        return std::vector<std::vector<float>>(
            positions.begin(), positions.begin() + static_cast<std::ptrdiff_t>(count));
    }

    /// Euclidean distance between two points, using their x and y coordinates.
    ///
    /// @param pos first point (at least two coordinates).
    /// @param sta second point (at least two coordinates).
    /// @return the distance between pos and sta.
    float cal_len(const std::vector<float>& pos, const std::vector<float>& sta) const {
        // std::hypot computes sqrt(dx*dx + dy*dy) without the intermediate
        // overflow that squaring large differences can cause.
        return std::hypot(pos[0] - sta[0], pos[1] - sta[1]);
    }

    /// Assigns every point to its nearest station.
    ///
    /// @param positions input points.
    /// @param station   current centroids.
    /// @param k         number of leading stations to consider; clamped to
    ///                  the range [0, station.size()].
    /// @return for each point, the index of the closest station (the lowest
    ///         index wins ties), or -1 when no station is at a finite,
    ///         comparable distance (no stations, or NaN coordinates).
    std::vector<int> get_father(const std::vector<std::vector<float>>& positions,
                                const std::vector<std::vector<float>>& station, int k) const {
        const std::size_t usable = clamp_count(k, station.size());
        std::vector<int> father(positions.size(), -1);
        for (std::size_t i = 0; i < positions.size(); ++i) {
            // Distance from point i to each station j; the station with the
            // smallest distance becomes the point's label.
            float best_len = std::numeric_limits<float>::infinity();
            for (std::size_t j = 0; j < usable; ++j) {
                const float len = cal_len(positions[i], station[j]);
                if (len < best_len) {
                    best_len = len;
                    father[i] = static_cast<int>(j);
                }
            }
        }
        return father;
    }

    /// Moves every station to the mean of the points currently assigned to it.
    ///
    /// A station with no assigned points keeps its previous coordinates
    /// instead of being overwritten with the NaN that 0/0 would produce.
    /// Labels outside [0, stations.size()) are ignored.
    ///
    /// @param positions input points.
    /// @param stations  centroids, updated in place.
    /// @param father    station index per point, as produced by get_father().
    void refresh(const std::vector<std::vector<float>>& positions,
                 std::vector<std::vector<float>>& stations,
                 const std::vector<int>& father) const {
        const std::size_t k = stations.size();
        const std::size_t n = positions.size() < father.size() ? positions.size() : father.size();

        // Running coordinate sums and member counts per station; this avoids
        // copying every point into a per-station bucket.
        std::vector<double> sum_x(k, 0.0);
        std::vector<double> sum_y(k, 0.0);
        std::vector<std::size_t> members(k, 0);
        for (std::size_t i = 0; i < n; ++i) {
            const int label = father[i];
            if (label < 0 || static_cast<std::size_t>(label) >= k) {
                continue;
            }
            const std::size_t slot = static_cast<std::size_t>(label);
            sum_x[slot] += positions[i][0];
            sum_y[slot] += positions[i][1];
            ++members[slot];
        }

        for (std::size_t j = 0; j < k; ++j) {
            if (members[j] == 0) {
                continue;  // empty cluster: leave the centroid where it is
            }
            const double size = static_cast<double>(members[j]);
            stations[j] = std::vector<float>{static_cast<float>(sum_x[j] / size),
                                             static_cast<float>(sum_y[j] / size)};
        }
    }

    /// Runs k-means with the first k points as initial centroids.
    ///
    /// Iterates at most 100 times and stops early once the assignment no
    /// longer changes (further iterations could not alter the result).
    ///
    /// @param positions input points.
    /// @param k         requested number of clusters; clamped to
    ///                  positions.size().
    /// @return the cluster index of every point, in input order; empty when
    ///         positions is empty.
    /// @throws std::invalid_argument if positions is non-empty and k <= 0.
    std::vector<int> k_means(const std::vector<std::vector<float>>& positions, int k) const {
        if (positions.empty()) {
            return {};
        }
        if (k <= 0) {
            throw std::invalid_argument("k_means: k must be positive");
        }

        std::vector<std::vector<float>> station = init_pos(positions, k);
        const int clusters = static_cast<int>(station.size());
        // Initial assignment of every point to its nearest centroid.
        std::vector<int> father = get_father(positions, station, clusters);

        const int max_epoch = 100;
        for (int epoch = 0; epoch < max_epoch; ++epoch) {
            refresh(positions, station, father);  // update the centroids
            std::vector<int> updated = get_father(positions, station, clusters);
            if (updated == father) {
                break;  // converged
            }
            father.swap(updated);
        }
        return father;
    }

private:
    /// Clamps a requested count to the range [0, available].
    static std::size_t clamp_count(int k, std::size_t available) {
        if (k <= 0) {
            return 0;
        }
        const std::size_t wanted = static_cast<std::size_t>(k);
        return wanted < available ? wanted : available;
    }
};
int main()
{
vector<vector<float>> positions = {{1.5,2.1}, {0.8,2.1}, {1.3,2.1}, {110.5,260.6}, {21.7, 32.8},{130.9,150.8},{32.6,40.7},{41.5,24.7}};
Solution base;
auto father = base.k_means(positions, 3);
for(int i=0; i<positions.size(); i++)
cout << father[i] << ' ';
return 0;
}
// Expected answer: [1,1,1,0,2,0,2,2] (cluster label of each point, in input order)