想到那天头条面试时,让我手撸kmeans,奈何好久不用c++,好多都忘了==淡淡的忧伤
这次刚好赶上机会,可以再试试了,我写成项目了,有多个文件
首先:base.h
#ifndef BASE_H
#define BASE_H
#include<stdlib.h>
#include<cassert>
#include<cmath>
#include<ctime>
#include<iostream>
#include<vector>
#include<opencv2/opencv.hpp>
/// Small geometry helper offering a 3-D Euclidean distance and a k-means
/// clustering routine built on top of it.
class Baseofgeo
{
public:
    /// Clusters the rows of @p listA into @p K groups with k-means.
    ///
    /// Each row stores its coordinates in elements 0..2 and receives the
    /// index of its cluster (stored as a float) in element 3.
    ///
    /// @param listA  points to cluster; element 3 of every row is overwritten
    /// @param K      number of clusters
    void Gekmeasns(std::vector<std::vector<float> >& listA, int K);

    /// Euclidean distance between two points.
    ///
    /// @param p1  first point; must hold exactly three components
    /// @param p2  second point; must hold exactly three components
    /// @return    the straight-line distance between @p p1 and @p p2
    float computedis(const std::vector<float> p1, const std::vector<float> p2);

private:
    // TODO: no private state yet.
};
#endif
再上:base.cpp
#include"base.h"
namespace {

// The refinement loop keeps running while the number of finished iterations
// is <= this value.
const int kLastIterationIndex = 20;

// The refinement loop stops once the summed point-to-centre distance changes
// by no more than this amount between two consecutive iterations.
const double kMinObjectiveChange = 10.0;

// Stand-in for the objective change before two iterations have completed;
// it only has to exceed kMinObjectiveChange so the loop is entered.
const double kInitialObjectiveChange = 100.0;

// Euclidean distance between the first three components of `a` and `b`.
// Takes its arguments by reference so the hot loops below neither copy nor
// allocate. Differences are formed in float and accumulated in double.
float distance3(const std::vector<float>& a, const std::vector<float>& b) {
    const double dx = a[0] - b[0];
    const double dy = a[1] - b[1];
    const double dz = a[2] - b[2];
    return static_cast<float>(std::sqrt(dx * dx + dy * dy + dz * dz));
}

}  // namespace

/// Clusters 3-D points with k-means (Lloyd's iteration).
///
/// Every row of @p listA must hold at least four floats: elements 0..2 are
/// the coordinates and element 3 receives the index of the assigned cluster
/// (stored as a float). Whatever element 3 contains on entry is ignored and
/// overwritten.
///
/// Initial centres are distinct rows picked at random; the C random generator
/// is re-seeded from the wall clock on every call. Iteration stops after at
/// most kLastIterationIndex + 1 passes, or earlier once the summed
/// point-to-centre distance changes by no more than kMinObjectiveChange.
///
/// @param listA  points to cluster; element 3 of each row is written
/// @param K      requested number of clusters. Values <= 0 or an empty
///               @p listA make the call a no-op; a value larger than the
///               number of points is reduced to the number of points.
void Baseofgeo::Gekmeasns(std::vector<std::vector<float> >& listA,int K){
    // Initialisation.
    srand(static_cast<unsigned>(std::time(nullptr)));

    const std::size_t pointCount = listA.size();
    if (pointCount == 0 || K <= 0) {
        return;  // nothing to cluster; also avoids a modulo by zero below
    }
    for (std::size_t i = 0; i < pointCount; ++i) {
        assert(listA[i].size() >= 4);
    }

    const std::size_t requested = static_cast<std::size_t>(K);
    const std::size_t clusterCount = requested < pointCount ? requested : pointCount;

    // Pick `clusterCount` distinct rows as initial centres with a partial
    // Fisher-Yates shuffle. Unlike retry-until-unused sampling this always
    // terminates and does not depend on the incoming label values.
    std::vector<std::vector<float> > centerid(clusterCount, std::vector<float>(3, 0.0f));
    std::vector<std::size_t> order(pointCount);
    for (std::size_t i = 0; i < pointCount; ++i) {
        order[i] = i;
    }
    for (std::size_t c = 0; c < clusterCount; ++c) {
        const std::size_t pick = c + static_cast<std::size_t>(rand()) % (pointCount - c);
        const std::size_t chosen = order[pick];
        order[pick] = order[c];
        order[c] = chosen;
        centerid[c].assign(listA[chosen].begin(), listA[chosen].begin() + 3);
    }

    int count = 0;                                // finished iterations
    double previousJ = 0.0;                       // objective after the previous iteration
    double reserror = kInitialObjectiveChange;    // change of the objective between two iterations
    std::vector<double> sums(clusterCount * 3, 0.0);
    std::vector<std::size_t> members(clusterCount, 0);

    while (count <= kLastIterationIndex && reserror > kMinObjectiveChange) {
        std::cout<<"第"<<count<<"次迭代"<<std::endl;

        // Assignment step: label every point with its nearest centre.
        // Starting from centre 0 guarantees each point ends up with a valid label.
        for (std::vector<float>& point : listA) {
            std::size_t nearest = 0;
            float best = distance3(point, centerid[0]);
            for (std::size_t c = 1; c < clusterCount; ++c) {
                const float d = distance3(point, centerid[c]);
                if (d < best) {
                    best = d;
                    nearest = c;
                }
            }
            point[3] = static_cast<float>(nearest);
        }

        // Update step: one pass over the points accumulates every cluster's
        // coordinate sums (in double to limit rounding on large inputs).
        sums.assign(clusterCount * 3, 0.0);
        members.assign(clusterCount, 0);
        for (const std::vector<float>& point : listA) {
            const std::size_t c = static_cast<std::size_t>(point[3]);
            sums[3 * c]     += point[0];
            sums[3 * c + 1] += point[1];
            sums[3 * c + 2] += point[2];
            ++members[c];
        }
        for (std::size_t c = 0; c < clusterCount; ++c) {
            if (members[c] == 0) {
                // Empty cluster (e.g. two seeds with identical coordinates):
                // keep the previous centre instead of dividing by zero.
                continue;
            }
            const double n = static_cast<double>(members[c]);
            centerid[c][0] = static_cast<float>(sums[3 * c] / n);
            centerid[c][1] = static_cast<float>(sums[3 * c + 1] / n);
            centerid[c][2] = static_cast<float>(sums[3 * c + 2] / n);
        }

        // Termination criterion 1: iteration budget.
        ++count;

        // Termination criterion 2: change of the summed within-cluster distance.
        double J = 0.0;
        for (const std::vector<float>& point : listA) {
            J += distance3(point, centerid[static_cast<std::size_t>(point[3])]);
        }
        if (count > 1) {
            reserror = std::fabs(J - previousJ);
            std::cout<<"reserror:"<<reserror<<std::endl;
        }
        // Remember this iteration's objective for the next comparison.
        previousJ = J;
    }
}
/// Euclidean distance between two 3-D points.
///
/// @param p1  first point; must hold exactly three components
/// @param p2  second point; must hold exactly three components
/// @return    sqrt of the summed squared component differences
float Baseofgeo::computedis(const std::vector<float> p1,const std::vector<float> p2){
    assert(p1.size()==3&&p2.size()==3);
    // Plain multiplication instead of pow(x, 2): same value, without the cost
    // of a general power function. The squares are accumulated in double.
    const double dx = p1[0] - p2[0];
    const double dy = p1[1] - p2[1];
    const double dz = p1[2] - p2[2];
    return static_cast<float>(std::sqrt(dx * dx + dy * dy + dz * dz));
}
主函数嘛:main.cpp
#include<iostream>
#include<vector>
#include<boost/concept_check.hpp>
#include<opencv2/opencv.hpp>
#include<time.h>
#include<stdlib.h>
#include<cassert>
#include<memory>
#include"base.h"
int main(int argc,char** argv){
cv::Mat I=cv::imread("../data/0001.jpg");
cv::imshow("im a pic",I);
std::vector<std::vector<float> > listA(I.cols*I.rows,std::vector<float>(4,-1.0));
int nl=d_Ihsv.rows;
int nc=d_Ihsv.cols;
int ii=0;
for(int i=0;i<nl;i++){
for(int j=0;j<nc;j++){
listA[ii][0]=I.at<cv::Vec3f>(i,j)[0]/10.0;
listA[ii][1]=I.at<cv::Vec3f>(i,j)[1]/10.0;
listA[ii][2]=I.at<cv::Vec3f>(i,j)[2]/10.0;
ii++;
}
}
std::shared_ptr<Baseofgeo> basemethod(new Baseofgeo());
basemethod->Gekmeasns(listA,72);
//就当我是可视化==,可视化第二期再更
for(auto p:listA){
std::cout<<p[3]<<std::endl;
}
return 0;
}
还有CMakeLists.txt
# Each command sits on its own line: CMake requires a newline between
# command invocations.
# NOTE(review): add_compile_options may need a CMake newer than plain 2.8 - confirm.
cmake_minimum_required(VERSION 2.8)
project(image)

# Location of the local OpenCV build used by find_package below.
set(OpenCV_DIR "/home/geo/opencv-2.4.13/build")
add_compile_options(-std=c++11)
find_package(OpenCV REQUIRED)

# Library holding the k-means implementation.
add_library(base base.cpp)
target_link_libraries(base ${OpenCV_LIBS})

# Demo executable.
add_executable(main main.cpp)
target_link_libraries(main base ${OpenCV_LIBS})