#include<bits/stdc++.h>
using namespace std;
// 64-bit integer shorthand and an "infinity" sentinel.
// NOTE(review): neither is referenced in the visible code; kept for compatibility.
typedef long long ll;
const int INF = 0x3f3f3f3f;
// Static capacities. All model arrays are indexed from 1, hence the small slack.
const int MAX_N = 1000000 + 5; // distinct words (vocabulary size N)
const int MAX_M = 10000 + 5;   // documents (M)
const int MAX_K = 30 + 5;      // topics (K)
struct node{
int num;
double top[MAX_K];
node(){
num=0;
memset(top,0,sizeof(top));
}
};
int M,K,N;//number of documents, number of topics, number of distinct words
string dic[MAX_N+1];//vocabulary: dic[id] is the word with id 1..N
map<string,int> mp;//reverse vocabulary: word -> id
double doc[MAX_M+1][MAX_K+1];//doc[d][z]: topic distribution of document d (row sums to 1)
// top[z][w]: word distribution of topic z (row sums to 1).
// NOTE: this static table is (MAX_K+1)*(MAX_N+1) doubles, about 288 MB.
double top[MAX_K+1][MAX_N+1];
map<int,node> doc_words[MAX_M+1];//latent variables / indicators: per-(document, word) count and topic posterior
// Raw tokens of each document, indexed 1..M like the other per-document arrays.
// NOTE: with `using namespace std;` an unqualified `data` is ambiguous with
// std::data (C++17), so refer to this array as `::data`.
vector<string> data[MAX_M+1];
// Return a pseudo-random integer in [0, n), for n > 0, using rand().
// The divisor is computed in double: `RAND_MAX + 1` in int arithmetic
// overflows (undefined behaviour) on platforms where RAND_MAX == INT_MAX,
// e.g. glibc, and in practice produced negative results.
int Rand(int n){
    const double unit = static_cast<double>(rand()) / (static_cast<double>(RAND_MAX) + 1.0); // in [0, 1)
    return static_cast<int>(unit * n);
}
// Fill row[1..len] with random weights Rand(100)+1, then rescale them so they sum to 1.
static void fill_random_distribution(double *row, int len){
    double total = 0;
    int c = 1;
    while (c <= len) {
        row[c] = Rand(100) + 1;
        total += row[c];
        ++c;
    }
    for (c = 1; c <= len; ++c) row[c] /= total;
}

// Random starting point for EM:
// - each document's topic mixture doc[d][1..K];
// - each topic's word distribution top[z][1..N].
void init(){
    for (int d = 1; d <= M; ++d) fill_random_distribution(doc[d], K);
    for (int z = 1; z <= K; ++z) fill_random_distribution(top[z], N);
}
void get_E(){
for(int i=1;i<=M;i++){
for(map<int,node>::iterator it=doc_words[i].begin();it!=doc_words[i].end();it++){
int id=it->first;
node &w=it->second;
double s=0;
for(int j=1;j<=K;j++)s+=doc[i][j]*top[j][id];
for(int j=1;j<=K;j++)w.top[j]=doc[i][j]*top[j][id]/s;
}
}
}
void get_M(){
for(int i=1;i<=K;i++){
double s=0;
fill(top[i],top[i]+N+1,0);
for(int j=1;j<=M;j++){
for(map<int,node>::iterator it=doc_words[j].begin();it!=doc_words[j].end();it++){
int id=it->first;
node &w=it->second;
top[i][id]+=w.top[i]*w.num;
s+=w.top[i]*w.num;
}
}
for(int j=1;j<=N;j++){
top[i][j]=top[i][j]/s;
}
}
for(int i=1;i<=M;i++){
double s=0;
fill(doc[i],doc[i]+K+1,0);
for(map<int,node>::iterator it=doc_words[i].begin();it!=doc_words[i].end();it++){
node &w=it->second;
for(int j=1;j<=K;j++){
doc[i][j]+=w.top[j]*w.num;
s+=w.top[j]*w.num;
}
}
for(int j=1;j<=K;j++){
doc[i][j]=doc[i][j]/s;
}
}
}
double get_MLE(){
double res=0;
for(int i=1;i<=M;i++){
for(map<int,node>::iterator it=doc_words[i].begin();it!=doc_words[i].end();it++){
int id=it->first;
double p=0,c=it->second.num;
for(int j=1;j<=K;j++)p+=doc[i][j]*top[j][id];
res+=c*log(p);
}
}
return res;
}
int kkk;           // topic whose word probabilities cmp() currently ranks by
int rak[MAX_N+1];  // scratch array of word ids (1-based) reordered by show()
// Strict-weak ordering for sorting word ids by descending top[kkk][id].
bool cmp(int i,int j){
    const double *row = top[kkk];
    return row[j] < row[i];
}
void show(){
for(int i=1;i<=K;i++){
cout<<i<<":";
kkk=i;
for(int j=1;j<=N;j++)rak[j]=j;
sort(rak+1,rak+N+1,cmp);
for(int j=1;j<=10;j++){
int id=rak[j];
cout<<dic[id]<<"*";
printf("%.4f ",top[i][id]);
}
cout<<endl;
}
}
int main(){
srand(time(NULL));
freopen("C:\\Users\\28612\\Desktop\\data","r",stdin);
M=0,K=20,N=0;
cout<<"主题数K:"<<K<<endl;
int t;
while(cin>>t){
M++;
while(t--){
string s;
cin>>s;
data[M].push_back(s);
}
}
cout<<"文档数M:"<<M<<endl;
for(int i=1;i<=M;i++){
for(int j=0;j<data[i].size();j++){
string s=data[i][j];
if(mp[s]==0){
dic[++N]=s;
mp[s]=N;
}
}
}
cout<<"词汇数N:"<<N<<endl;
for(int i=1;i<=M;i++){
for(int j=0;j<data[i].size();j++){
string s=data[i][j];
int id=mp[s];
doc_words[i][id].num++;
}
}
init();
double last=-1e10;
for(int i=1;i<=1000;i++){
get_E();
get_M();
double p=get_MLE();
cout<<i<<":"<<p<<endl;
if(abs(p-last)<0.1)break;
last=p;
}
show();
return 0;
}
// Reference (参考):
// https://www.cnblogs.com/zhangchaoyang/articles/5668024.html