多项式混合模型

===============================================

Multi_Multi_EM


package wangfang_em;



import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;


/**
 * EM clustering of documents with a mixture model whose components emit two
 * multinomial views of each document (count matrices A and B).
 *
 * <p>Input paths, the number of components {@code K} and the iteration count are read
 * from {@code Config}. Every data row holds the feature counts followed by the class
 * label in its last column; the label is only used to print the final cross table.
 *
 * <p>The multinomial coefficient of each document is not computed: it does not depend
 * on the component, so it cancels when the posteriors are normalised.
 */
public class Multi_Multi_EM {

    /**
     * Runs the clustering and prints the elapsed wall-clock time.
     * An {@link IOException} raised while reading the input is printed, not rethrown.
     */
    public static void main(String[] args) {
        System.out.println("start execute.");
        try {
            long start = System.currentTimeMillis();
            run();
            long end = System.currentTimeMillis();
            System.out.println("execute time: " + (double) (end - start) / 1000 + " s.");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads both matrices, iterates the M-step and E-step {@code Config.MAX_ITR} times
     * and prints the resulting assignment.
     *
     * @throws IOException if a data file cannot be read
     * @throws IllegalArgumentException if matrix A has no rows or matrix B has fewer rows than A
     */
    static void run() throws IOException {
        int[][] matrixA = getData(Config.MATRIX_A_PATH);
        int[][] matrixB = getData(Config.MATRIX_B_PATH);

        int K = Config.K; // number of topics (mixture components)

        int D = matrixA.length; // number of documents
        if (D == 0) {
            throw new IllegalArgumentException("no data rows found in " + Config.MATRIX_A_PATH);
        }
        if (matrixB.length < D) {
            throw new IllegalArgumentException(
                    "matrix B has " + matrixB.length + " rows but matrix A has " + D);
        }
        int X = matrixA[0].length - 1; // features of view A (last column is the class label)
        int Y = matrixB[0].length - 1; // features of view B (last column is the class label)

        int maxItr = Config.MAX_ITR;

        System.out.println("docsNumber:" + D + ", X:" + X + ", Y:" + Y + ", K:" + K + ", maxItr:" + maxItr);

        double[] pi_k = initPIK(K);
        double[][] theta_ktx = initTheta_kt(K, X);
        double[][] theta_kry = initTheta_kt(K, Y);

        double[][] g_n_k = init_gnk(D, K);

        for (int itr = 0; itr < maxItr; itr++) {
            System.out.println("start itr: " + itr + " ...");

            // M-step: re-estimate priors and emission parameters from the responsibilities.
            for (int k = 0; k < K; k++) {
                pi_k[k] = responsibilitySum(g_n_k, k) / D;
                updateMultinomialTheta(matrixA, g_n_k, k, theta_ktx[k]);
                updateMultinomialTheta(matrixB, g_n_k, k, theta_kry[k]);
            }

            System.out.println("    start update g_n_k ...");
            // E-step. The logarithms depend only on (k, feature), so take them once
            // per iteration instead of once per document.
            double[][] logThetaX = elementwiseLog(theta_ktx);
            double[][] logThetaY = elementwiseLog(theta_kry);
            for (int i = 0; i < D; i++) {
                for (int k = 0; k < K; k++) {
                    // The log prior is part of the score so that the max used for
                    // stabilisation always belongs to a component with non-zero mass.
                    double logScore = Math.log(pi_k[k]);
                    for (int x = 0; x < X; x++) {
                        logScore += logThetaX[k][x] * matrixA[i][x];
                    }
                    for (int y = 0; y < Y; y++) {
                        logScore += logThetaY[k][y] * matrixB[i][y];
                    }
                    g_n_k[i][k] = logScore;
                }
                normalizeLogScores(g_n_k[i]);
            }

            System.out.println("end itr: " + itr + " ...");
        }

        // print the result
        out(D, K, matrixA, g_n_k);
    }

    /** Returns the sum over all documents of the responsibility of component {@code k}. */
    private static double responsibilitySum(double[][] g_n_k, int k) {
        double sum = 0;
        for (int i = 0; i < g_n_k.length; i++) {
            sum += g_n_k[i][k];
        }
        return sum;
    }

    /**
     * Multinomial M-step for component {@code k} with add-one smoothing: writes the
     * normalised expected counts of the first {@code theta.length} columns into {@code theta}.
     */
    private static void updateMultinomialTheta(int[][] counts, double[][] g_n_k, int k, double[] theta) {
        double fenmu = 0; // denominator: total smoothed expected count
        for (int x = 0; x < theta.length; x++) {
            double fenzi = 1; // numerator, starting at 1 for add-one smoothing
            for (int i = 0; i < g_n_k.length; i++) {
                fenzi += counts[i][x] * g_n_k[i][k];
            }
            fenmu += fenzi;
            theta[x] = fenzi;
        }
        for (int x = 0; x < theta.length; x++) {
            theta[x] /= fenmu;
        }
    }

    /** Returns a new matrix holding the natural logarithm of every entry of {@code m}. */
    private static double[][] elementwiseLog(double[][] m) {
        double[][] logs = new double[m.length][];
        for (int r = 0; r < m.length; r++) {
            logs[r] = new double[m[r].length];
            for (int c = 0; c < m[r].length; c++) {
                logs[r][c] = Math.log(m[r][c]);
            }
        }
        return logs;
    }

    /**
     * Converts unnormalised log scores into probabilities in place, subtracting the
     * maximum before exponentiating so that the largest term is exactly 1.
     */
    private static void normalizeLogScores(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < scores.length; k++) {
            if (max < scores[k]) {
                max = scores[k];
            }
        }
        double sum = 0;
        for (int k = 0; k < scores.length; k++) {
            scores[k] = Math.exp(scores[k] - max);
            sum += scores[k];
        }
        for (int k = 0; k < scores.length; k++) {
            scores[k] /= sum;
        }
    }

    /**
     * Reads a comma-separated matrix from a UTF-8 text file.
     *
     * <p>Lines starting with {@code @} and lines with fewer than two fields are skipped.
     * Every field is parsed as a double and truncated to int; rows shorter than the
     * longest row are padded with zeros on the right.
     *
     * @param path file to read
     * @return one row per accepted line
     * @throws IOException if the file cannot be opened or read
     */
    static int[][] getData(String path) throws IOException {
        System.out.println("start getData " + path + " ...");
        List<String[]> ls = new ArrayList<String[]>();
        int maxV = 0;
        BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(path)), "utf-8"));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                if (line.startsWith("@")) {
                    continue;
                }
                String[] es = line.split(",");
                if (es.length > 1) {
                    if (es.length > maxV) {
                        maxV = es.length;
                    }
                    ls.add(es);
                }
            }
        } finally {
            // Release the file handle even when reading fails.
            br.close();
        }

        int[][] matrix = new int[ls.size()][maxV];
        for (int i = 0; i < ls.size(); i++) {
            String[] es = ls.get(i);
            for (int j = 0; j < es.length; j++) {
                matrix[i][j] = (int) Double.parseDouble(es[j]);
            }
        }
        System.out.println("end getData " + path + " ...");
        return matrix;
    }

    /** Returns {@code K} random mixing weights normalised to sum to 1. */
    static double[] initPIK(int K) {
        double[] pi_k = new double[K];
        double sum = 0;
        for (int i = 0; i < K; i++) {
            pi_k[i] = Math.random();
            sum += pi_k[i];
        }
        for (int i = 0; i < K; i++) {
            pi_k[i] /= sum;
        }
        return pi_k;
    }

    /** Returns a {@code K x len} matrix of random values whose rows each sum to 1. */
    static double[][] initTheta_kt(int K, int len) {
        double[][] theta_kt = new double[K][len];
        for (int i = 0; i < K; i++) { // topic
            double sum = 0;
            for (int j = 0; j < len; j++) { // feature
                theta_kt[i][j] = Math.random();
                sum += theta_kt[i][j];
            }
            for (int j = 0; j < len; j++) {
                theta_kt[i][j] /= sum;
            }
        }
        return theta_kt;
    }

    /** Returns a {@code D x K} matrix with every entry equal to {@code 1/K}. */
    static double[][] init_pdn_theta_k(int D, int K) {
        double[][] pdn_theta_k = new double[D][K];
        for (int i = 0; i < D; i++) { // document
            double val = (double) 1 / K;
            for (int k = 0; k < K; k++) { // topic
                pdn_theta_k[i][k] = val;
            }
        }
        return pdn_theta_k;
    }

    /** Returns a {@code D x K} matrix of random responsibilities whose rows each sum to 1. */
    static double[][] init_gnk(int D, int K) {
        double[][] g_n_k = new double[D][K];
        for (int i = 0; i < D; i++) { // document
            double sum = 0;
            for (int k = 0; k < K; k++) { // topic
                g_n_k[i][k] = Math.random();
                sum += g_n_k[i][k];
            }
            for (int k = 0; k < K; k++) {
                g_n_k[i][k] /= sum;
            }
        }
        return g_n_k;
    }

    /**
     * Prints every document's posterior, then a cross table of assigned component
     * (rows) against the class label stored in the last column of {@code matrixA}
     * (columns).
     *
     * <p>The table has at least {@code K} columns and is widened when a label is
     * {@code >= K}, so data with more classes than components no longer overruns it.
     *
     * @throws IllegalArgumentException if a class label is negative
     */
    static void out(int D, int K, int[][] matrixA, double[][] g_n_k) {
        int classes = K;
        for (int i = 0; i < D; i++) {
            int label = matrixA[i][matrixA[i].length - 1];
            if (label < 0) {
                throw new IllegalArgumentException("negative class label " + label + " in document " + i);
            }
            if (label >= classes) {
                classes = label + 1;
            }
        }
        int[][] retDistribution = new int[K][classes]; // component x class counts

        for (int i = 0; i < D; i++) { // document
            System.out.println();
            for (int k = 0; k < K; k++) { // topic
                System.out.print("," + g_n_k[i][k]);
            }
        }
        System.out.println();

        for (int i = 0; i < D; i++) { // document
            // Hard assignment: the component with the largest posterior (0 on ties at 0).
            double max_g_n_k = 0;
            int cur_k = 0;
            for (int k = 0; k < K; k++) {
                if (g_n_k[i][k] > max_g_n_k) {
                    max_g_n_k = g_n_k[i][k];
                    cur_k = k;
                }
            }
            retDistribution[cur_k][matrixA[i][matrixA[i].length - 1]]++;
        }

        // print the table
        System.out.println("=============================");
        System.out.print("\t");
        for (int j = 0; j < classes; j++) {
            System.out.print("cls" + j + "\t");
        }
        System.out.println();

        for (int i = 0; i < K; i++) {
            System.out.print("newCls" + i);
            for (int j = 0; j < classes; j++) {
                System.out.print("\t" + retDistribution[i][j]);
            }
            System.out.println();
        }
        System.out.println("=============================");
        System.out.println("the end.");
    }

}




===============================================

Multi_EM



package wangfang_em;


import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;


/**
 * EM clustering of documents with a mixture of multinomial distributions over a
 * single count matrix A.
 *
 * <p>The input path, the number of components {@code K} and the iteration count are
 * read from {@code Config}. Every data row holds the feature counts followed by the
 * class label in its last column; the label is only used to print the final cross table.
 *
 * <p>The multinomial coefficient of each document is not computed: it does not depend
 * on the component, so it cancels when the posteriors are normalised.
 */
public class Multi_EM {

    /**
     * Runs the clustering and prints the elapsed wall-clock time.
     * An {@link IOException} raised while reading the input is printed, not rethrown.
     */
    public static void main(String[] args) {
        System.out.println("start execute.");
        try {
            long start = System.currentTimeMillis();
            run();
            long end = System.currentTimeMillis();
            System.out.println("execute time: " + (double) (end - start) / 1000 + " s.");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads matrix A, iterates the M-step and E-step {@code Config.MAX_ITR} times and
     * prints the resulting assignment.
     *
     * @throws IOException if the data file cannot be read
     * @throws IllegalArgumentException if matrix A has no rows
     */
    static void run() throws IOException {
        int[][] matrixA = getData(Config.MATRIX_A_PATH);

        int K = Config.K; // number of topics (mixture components)

        int D = matrixA.length; // number of documents
        if (D == 0) {
            throw new IllegalArgumentException("no data rows found in " + Config.MATRIX_A_PATH);
        }
        int X = matrixA[0].length - 1; // number of features (last column is the class label)

        int maxItr = Config.MAX_ITR;

        System.out.println("docsNumber:" + D + ", X:" + X + ", K:" + K + ", maxItr:" + maxItr);

        double[] pi_k = initPIK(K);
        double[][] theta_ktx = initTheta_kt(K, X);

        double[][] g_n_k = init_gnk(D, K);

        for (int itr = 0; itr < maxItr; itr++) {
            System.out.println("start itr: " + itr + " ...");

            // M-step: re-estimate priors and emission parameters from the responsibilities.
            for (int k = 0; k < K; k++) {
                pi_k[k] = responsibilitySum(g_n_k, k) / D;
                updateMultinomialTheta(matrixA, g_n_k, k, theta_ktx[k]);
            }

            System.out.println("    start update g_n_k ...");
            // E-step. The logarithms depend only on (k, feature), so take them once
            // per iteration instead of once per document.
            double[][] logThetaX = elementwiseLog(theta_ktx);
            for (int i = 0; i < D; i++) {
                for (int k = 0; k < K; k++) {
                    // The log prior is part of the score so that the max used for
                    // stabilisation always belongs to a component with non-zero mass.
                    double logScore = Math.log(pi_k[k]);
                    for (int x = 0; x < X; x++) {
                        logScore += logThetaX[k][x] * matrixA[i][x];
                    }
                    g_n_k[i][k] = logScore;
                }
                normalizeLogScores(g_n_k[i]);
            }

            System.out.println("end itr: " + itr + " ...");
        }

        // print the result
        out(D, K, matrixA, g_n_k);
    }

    /** Returns the sum over all documents of the responsibility of component {@code k}. */
    private static double responsibilitySum(double[][] g_n_k, int k) {
        double sum = 0;
        for (int i = 0; i < g_n_k.length; i++) {
            sum += g_n_k[i][k];
        }
        return sum;
    }

    /**
     * Multinomial M-step for component {@code k} with add-one smoothing: writes the
     * normalised expected counts of the first {@code theta.length} columns into {@code theta}.
     */
    private static void updateMultinomialTheta(int[][] counts, double[][] g_n_k, int k, double[] theta) {
        double fenmu = 0; // denominator: total smoothed expected count
        for (int x = 0; x < theta.length; x++) {
            double fenzi = 1; // numerator, starting at 1 for add-one smoothing
            for (int i = 0; i < g_n_k.length; i++) {
                fenzi += counts[i][x] * g_n_k[i][k];
            }
            fenmu += fenzi;
            theta[x] = fenzi;
        }
        for (int x = 0; x < theta.length; x++) {
            theta[x] /= fenmu;
        }
    }

    /** Returns a new matrix holding the natural logarithm of every entry of {@code m}. */
    private static double[][] elementwiseLog(double[][] m) {
        double[][] logs = new double[m.length][];
        for (int r = 0; r < m.length; r++) {
            logs[r] = new double[m[r].length];
            for (int c = 0; c < m[r].length; c++) {
                logs[r][c] = Math.log(m[r][c]);
            }
        }
        return logs;
    }

    /**
     * Converts unnormalised log scores into probabilities in place, subtracting the
     * maximum before exponentiating so that the largest term is exactly 1.
     */
    private static void normalizeLogScores(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < scores.length; k++) {
            if (max < scores[k]) {
                max = scores[k];
            }
        }
        double sum = 0;
        for (int k = 0; k < scores.length; k++) {
            scores[k] = Math.exp(scores[k] - max);
            sum += scores[k];
        }
        for (int k = 0; k < scores.length; k++) {
            scores[k] /= sum;
        }
    }

    /**
     * Reads a comma-separated matrix from a UTF-8 text file.
     *
     * <p>Lines starting with {@code @} and lines with fewer than two fields are skipped.
     * Every field is parsed as a double and truncated to int; rows shorter than the
     * longest row are padded with zeros on the right.
     *
     * @param path file to read
     * @return one row per accepted line
     * @throws IOException if the file cannot be opened or read
     */
    static int[][] getData(String path) throws IOException {
        System.out.println("start getData " + path + " ...");
        List<String[]> ls = new ArrayList<String[]>();
        int maxV = 0;
        BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(path)), "utf-8"));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                if (line.startsWith("@")) {
                    continue;
                }
                String[] es = line.split(",");
                if (es.length > 1) {
                    if (es.length > maxV) {
                        maxV = es.length;
                    }
                    ls.add(es);
                }
            }
        } finally {
            // Release the file handle even when reading fails.
            br.close();
        }

        int[][] matrix = new int[ls.size()][maxV];
        for (int i = 0; i < ls.size(); i++) {
            String[] es = ls.get(i);
            for (int j = 0; j < es.length; j++) {
                matrix[i][j] = (int) Double.parseDouble(es[j]);
            }
        }
        System.out.println("end getData " + path + " ...");
        return matrix;
    }

    /** Returns {@code K} random mixing weights normalised to sum to 1. */
    static double[] initPIK(int K) {
        double[] pi_k = new double[K];
        double sum = 0;
        for (int i = 0; i < K; i++) {
            pi_k[i] = Math.random();
            sum += pi_k[i];
        }
        for (int i = 0; i < K; i++) {
            pi_k[i] /= sum;
        }
        return pi_k;
    }

    /** Returns a {@code K x len} matrix of random values whose rows each sum to 1. */
    static double[][] initTheta_kt(int K, int len) {
        double[][] theta_kt = new double[K][len];
        for (int i = 0; i < K; i++) { // topic
            double sum = 0;
            for (int j = 0; j < len; j++) { // feature
                theta_kt[i][j] = Math.random();
                sum += theta_kt[i][j];
            }
            for (int j = 0; j < len; j++) {
                theta_kt[i][j] /= sum;
            }
        }
        return theta_kt;
    }

    /** Returns a {@code D x K} matrix with every entry equal to {@code 1/K}. */
    static double[][] init_pdn_theta_k(int D, int K) {
        double[][] pdn_theta_k = new double[D][K];
        for (int i = 0; i < D; i++) { // document
            double val = (double) 1 / K;
            for (int k = 0; k < K; k++) { // topic
                pdn_theta_k[i][k] = val;
            }
        }
        return pdn_theta_k;
    }

    /** Returns a {@code D x K} matrix of random responsibilities whose rows each sum to 1. */
    static double[][] init_gnk(int D, int K) {
        double[][] g_n_k = new double[D][K];
        for (int i = 0; i < D; i++) { // document
            double sum = 0;
            for (int k = 0; k < K; k++) { // topic
                g_n_k[i][k] = Math.random();
                sum += g_n_k[i][k];
            }
            for (int k = 0; k < K; k++) {
                g_n_k[i][k] /= sum;
            }
        }
        return g_n_k;
    }

    /**
     * Prints every document's posterior, then a cross table of assigned component
     * (rows) against the class label stored in the last column of {@code matrixA}
     * (columns).
     *
     * <p>The table has at least {@code K} columns and is widened when a label is
     * {@code >= K}, so data with more classes than components no longer overruns it.
     *
     * @throws IllegalArgumentException if a class label is negative
     */
    static void out(int D, int K, int[][] matrixA, double[][] g_n_k) {
        int classes = K;
        for (int i = 0; i < D; i++) {
            int label = matrixA[i][matrixA[i].length - 1];
            if (label < 0) {
                throw new IllegalArgumentException("negative class label " + label + " in document " + i);
            }
            if (label >= classes) {
                classes = label + 1;
            }
        }
        int[][] retDistribution = new int[K][classes]; // component x class counts

        for (int i = 0; i < D; i++) { // document
            System.out.println();
            for (int k = 0; k < K; k++) { // topic
                System.out.print("," + g_n_k[i][k]);
            }
        }
        System.out.println();

        for (int i = 0; i < D; i++) { // document
            // Hard assignment: the component with the largest posterior (0 on ties at 0).
            double max_g_n_k = 0;
            int cur_k = 0;
            for (int k = 0; k < K; k++) {
                if (g_n_k[i][k] > max_g_n_k) {
                    max_g_n_k = g_n_k[i][k];
                    cur_k = k;
                }
            }
            retDistribution[cur_k][matrixA[i][matrixA[i].length - 1]]++;
        }

        // print the table
        System.out.println("=============================");
        System.out.print("\t");
        for (int j = 0; j < classes; j++) {
            System.out.print("cls" + j + "\t");
        }
        System.out.println();

        for (int i = 0; i < K; i++) {
            System.out.print("newCls" + i);
            for (int j = 0; j < classes; j++) {
                System.out.print("\t" + retDistribution[i][j]);
            }
            System.out.println();
        }
        System.out.println("=============================");
        System.out.println("the end.");
    }
}





=====================================

Multi_Bernoulli_EM



package wangfang_em;


import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;


/**
 * EM clustering of documents with a mixture model whose components emit a
 * multinomial view (count matrix A) and a multivariate Bernoulli view (matrix B).
 *
 * <p>Input paths, the number of components {@code K} and the iteration count are read
 * from {@code Config}. Every data row holds the features followed by the class label
 * in its last column; the label is only used to print the final cross table.
 *
 * <p>Entries of matrix B are interpreted as presence indicators: any non-zero value
 * counts as 1. For 0/1 data this is the identity; for other data it keeps the
 * Bernoulli parameters strictly inside (0, 1).
 */
public class Multi_Bernoulli_EM {

    /**
     * Runs the clustering and prints the elapsed wall-clock time.
     * An {@link IOException} raised while reading the input is printed, not rethrown.
     */
    public static void main(String[] args) {
        System.out.println("start execute.");
        try {
            long start = System.currentTimeMillis();
            run();
            long end = System.currentTimeMillis();
            System.out.println("execute time: " + (double) (end - start) / 1000 + " s.");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads both matrices, iterates the M-step and E-step {@code Config.MAX_ITR} times
     * and prints the resulting assignment.
     *
     * @throws IOException if a data file cannot be read
     * @throws IllegalArgumentException if matrix A has no rows or matrix B has fewer rows than A
     */
    static void run() throws IOException {
        int[][] matrixA = getData(Config.MATRIX_A_PATH);
        int[][] matrixB = getData(Config.MATRIX_B_PATH);

        int K = Config.K; // number of topics (mixture components)

        int D = matrixA.length; // number of documents
        if (D == 0) {
            throw new IllegalArgumentException("no data rows found in " + Config.MATRIX_A_PATH);
        }
        if (matrixB.length < D) {
            throw new IllegalArgumentException(
                    "matrix B has " + matrixB.length + " rows but matrix A has " + D);
        }
        int X = matrixA[0].length - 1; // features of view A (last column is the class label)
        int Y = matrixB[0].length - 1; // features of view B (last column is the class label)

        int maxItr = Config.MAX_ITR;

        System.out.println("docsNumber:" + D + ", X:" + X + ", Y:" + Y + ", K:" + K + ", maxItr:" + maxItr);

        double[] pi_k = initPIK(K);
        double[][] theta_ktx = initTheta_kt(K, X);
        double[][] theta_kry = initTheta_kt(K, Y);

        double[][] g_n_k = init_gnk(D, K);

        for (int itr = 0; itr < maxItr; itr++) {
            System.out.println("start itr: " + itr + " ...");

            // M-step: re-estimate priors and emission parameters from the responsibilities.
            for (int k = 0; k < K; k++) {
                double sum_gnk = responsibilitySum(g_n_k, k);
                pi_k[k] = sum_gnk / D;
                updateMultinomialTheta(matrixA, g_n_k, k, theta_ktx[k]);
                updateBernoulliTheta(matrixB, g_n_k, k, sum_gnk, theta_kry[k]);
            }

            System.out.println("    start update g_n_k ...");
            // E-step. The logarithms depend only on (k, feature), so take them once
            // per iteration instead of once per document.
            double[][] logThetaX = elementwiseLog(theta_ktx, false);
            double[][] logThetaY = elementwiseLog(theta_kry, false);
            double[][] logNotThetaY = elementwiseLog(theta_kry, true);
            for (int i = 0; i < D; i++) {
                for (int k = 0; k < K; k++) {
                    // The log prior is part of the score so that the max used for
                    // stabilisation always belongs to a component with non-zero mass.
                    double logScore = Math.log(pi_k[k]);
                    for (int x = 0; x < X; x++) {
                        logScore += logThetaX[k][x] * matrixA[i][x];
                    }
                    for (int y = 0; y < Y; y++) {
                        // Bernoulli term: log(theta) if present, log(1 - theta) if absent.
                        logScore += matrixB[i][y] != 0 ? logThetaY[k][y] : logNotThetaY[k][y];
                    }
                    g_n_k[i][k] = logScore;
                }
                normalizeLogScores(g_n_k[i]);
            }

            System.out.println("end itr: " + itr + " ...");
        }

        // print the result
        out(D, K, matrixA, g_n_k);
    }

    /** Returns the sum over all documents of the responsibility of component {@code k}. */
    private static double responsibilitySum(double[][] g_n_k, int k) {
        double sum = 0;
        for (int i = 0; i < g_n_k.length; i++) {
            sum += g_n_k[i][k];
        }
        return sum;
    }

    /**
     * Multinomial M-step for component {@code k} with add-one smoothing: writes the
     * normalised expected counts of the first {@code theta.length} columns into {@code theta}.
     */
    private static void updateMultinomialTheta(int[][] counts, double[][] g_n_k, int k, double[] theta) {
        double fenmu = 0; // denominator: total smoothed expected count
        for (int x = 0; x < theta.length; x++) {
            double fenzi = 1; // numerator, starting at 1 for add-one smoothing
            for (int i = 0; i < g_n_k.length; i++) {
                fenzi += counts[i][x] * g_n_k[i][k];
            }
            fenmu += fenzi;
            theta[x] = fenzi;
        }
        for (int x = 0; x < theta.length; x++) {
            theta[x] /= fenmu;
        }
    }

    /**
     * Bernoulli M-step for component {@code k} with add-one smoothing:
     * {@code theta[y] = (1 + sum_n g(n,k) * b(n,y)) / (2 + sum_n g(n,k))}.
     *
     * <p>The denominator is the expected number of documents in the component
     * ({@code sum_gnk}), not the number of documents that contain the feature; this is
     * the value that maximises the expected Bernoulli log-likelihood used in the E-step.
     */
    private static void updateBernoulliTheta(int[][] data, double[][] g_n_k, int k, double sum_gnk,
            double[] theta) {
        double fenmu = 2 + sum_gnk;
        for (int y = 0; y < theta.length; y++) {
            double fenzi = 1;
            for (int i = 0; i < g_n_k.length; i++) {
                if (data[i][y] != 0) {
                    fenzi += g_n_k[i][k];
                }
            }
            theta[y] = fenzi / fenmu;
        }
    }

    /**
     * Returns a new matrix holding {@code log(v)} for every entry {@code v} of {@code m},
     * or {@code log(1 - v)} when {@code complement} is true.
     */
    private static double[][] elementwiseLog(double[][] m, boolean complement) {
        double[][] logs = new double[m.length][];
        for (int r = 0; r < m.length; r++) {
            logs[r] = new double[m[r].length];
            for (int c = 0; c < m[r].length; c++) {
                logs[r][c] = Math.log(complement ? 1 - m[r][c] : m[r][c]);
            }
        }
        return logs;
    }

    /**
     * Converts unnormalised log scores into probabilities in place, subtracting the
     * maximum before exponentiating so that the largest term is exactly 1.
     */
    private static void normalizeLogScores(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < scores.length; k++) {
            if (max < scores[k]) {
                max = scores[k];
            }
        }
        double sum = 0;
        for (int k = 0; k < scores.length; k++) {
            scores[k] = Math.exp(scores[k] - max);
            sum += scores[k];
        }
        for (int k = 0; k < scores.length; k++) {
            scores[k] /= sum;
        }
    }

    /**
     * Reads a comma-separated matrix from a UTF-8 text file.
     *
     * <p>Lines starting with {@code @} and lines with fewer than two fields are skipped.
     * Every field is parsed as a double and truncated to int; rows shorter than the
     * longest row are padded with zeros on the right.
     *
     * @param path file to read
     * @return one row per accepted line
     * @throws IOException if the file cannot be opened or read
     */
    static int[][] getData(String path) throws IOException {
        System.out.println("start getData " + path + " ...");
        List<String[]> ls = new ArrayList<String[]>();
        int maxV = 0;
        BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(path)), "utf-8"));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                if (line.startsWith("@")) {
                    continue;
                }
                String[] es = line.split(",");
                if (es.length > 1) {
                    if (es.length > maxV) {
                        maxV = es.length;
                    }
                    ls.add(es);
                }
            }
        } finally {
            // Release the file handle even when reading fails.
            br.close();
        }

        int[][] matrix = new int[ls.size()][maxV];
        for (int i = 0; i < ls.size(); i++) {
            String[] es = ls.get(i);
            for (int j = 0; j < es.length; j++) {
                matrix[i][j] = (int) Double.parseDouble(es[j]);
            }
        }
        System.out.println("end getData " + path + " ...");
        return matrix;
    }

    /** Returns {@code K} random mixing weights normalised to sum to 1. */
    static double[] initPIK(int K) {
        double[] pi_k = new double[K];
        double sum = 0;
        for (int i = 0; i < K; i++) {
            pi_k[i] = Math.random();
            sum += pi_k[i];
        }
        for (int i = 0; i < K; i++) {
            pi_k[i] /= sum;
        }
        return pi_k;
    }

    /** Returns a {@code K x len} matrix of random values whose rows each sum to 1. */
    static double[][] initTheta_kt(int K, int len) {
        double[][] theta_kt = new double[K][len];
        for (int i = 0; i < K; i++) { // topic
            double sum = 0;
            for (int j = 0; j < len; j++) { // feature
                theta_kt[i][j] = Math.random();
                sum += theta_kt[i][j];
            }
            for (int j = 0; j < len; j++) {
                theta_kt[i][j] /= sum;
            }
        }
        return theta_kt;
    }

    /** Returns a {@code D x K} matrix with every entry equal to {@code 1/K}. */
    static double[][] init_pdn_theta_k(int D, int K) {
        double[][] pdn_theta_k = new double[D][K];
        for (int i = 0; i < D; i++) { // document
            double val = (double) 1 / K;
            for (int k = 0; k < K; k++) { // topic
                pdn_theta_k[i][k] = val;
            }
        }
        return pdn_theta_k;
    }

    /** Returns a {@code D x K} matrix of random responsibilities whose rows each sum to 1. */
    static double[][] init_gnk(int D, int K) {
        double[][] g_n_k = new double[D][K];
        for (int i = 0; i < D; i++) { // document
            double sum = 0;
            for (int k = 0; k < K; k++) { // topic
                g_n_k[i][k] = Math.random();
                sum += g_n_k[i][k];
            }
            for (int k = 0; k < K; k++) {
                g_n_k[i][k] /= sum;
            }
        }
        return g_n_k;
    }

    /**
     * Prints every document's posterior, then a cross table of assigned component
     * (rows) against the class label stored in the last column of {@code matrixA}
     * (columns).
     *
     * <p>The table has at least {@code K} columns and is widened when a label is
     * {@code >= K}, so data with more classes than components no longer overruns it.
     *
     * @throws IllegalArgumentException if a class label is negative
     */
    static void out(int D, int K, int[][] matrixA, double[][] g_n_k) {
        int classes = K;
        for (int i = 0; i < D; i++) {
            int label = matrixA[i][matrixA[i].length - 1];
            if (label < 0) {
                throw new IllegalArgumentException("negative class label " + label + " in document " + i);
            }
            if (label >= classes) {
                classes = label + 1;
            }
        }
        int[][] retDistribution = new int[K][classes]; // component x class counts

        for (int i = 0; i < D; i++) { // document
            System.out.println();
            for (int k = 0; k < K; k++) { // topic
                System.out.print("," + g_n_k[i][k]);
            }
        }
        System.out.println();

        for (int i = 0; i < D; i++) { // document
            // Hard assignment: the component with the largest posterior (0 on ties at 0).
            double max_g_n_k = 0;
            int cur_k = 0;
            for (int k = 0; k < K; k++) {
                if (g_n_k[i][k] > max_g_n_k) {
                    max_g_n_k = g_n_k[i][k];
                    cur_k = k;
                }
            }
            retDistribution[cur_k][matrixA[i][matrixA[i].length - 1]]++;
        }

        // print the table
        System.out.println("=============================");
        System.out.print("\t");
        for (int j = 0; j < classes; j++) {
            System.out.print("cls" + j + "\t");
        }
        System.out.println();

        for (int i = 0; i < K; i++) {
            System.out.print("newCls" + i);
            for (int j = 0; j < classes; j++) {
                System.out.print("\t" + retDistribution[i][j]);
            }
            System.out.println();
        }
        System.out.println("=============================");
        System.out.println("the end.");
    }
}




============================================

Bernoulli_EM


package wangfang_em;


import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;


/**
 * EM clustering of documents with a mixture of multivariate Bernoulli distributions
 * over a single matrix A.
 *
 * <p>The input path, the number of components {@code K} and the iteration count are
 * read from {@code Config}. Every data row holds the features followed by the class
 * label in its last column; the label is only used to print the final cross table.
 *
 * <p>Feature entries are interpreted as presence indicators: any non-zero value counts
 * as 1. For 0/1 data this is the identity; for other data it keeps the Bernoulli
 * parameters strictly inside (0, 1).
 */
public class Bernoulli_EM {

    /**
     * Runs the clustering and prints the elapsed wall-clock time.
     * An {@link IOException} raised while reading the input is printed, not rethrown.
     */
    public static void main(String[] args) {
        System.out.println("start execute.");
        try {
            long start = System.currentTimeMillis();
            run();
            long end = System.currentTimeMillis();
            System.out.println("execute time: " + (double) (end - start) / 1000 + " s.");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads matrix A, iterates the M-step and E-step {@code Config.MAX_ITR} times and
     * prints the resulting assignment.
     *
     * @throws IOException if the data file cannot be read
     * @throws IllegalArgumentException if matrix A has no rows
     */
    static void run() throws IOException {
        int[][] matrixA = getData(Config.MATRIX_A_PATH);

        int K = Config.K; // number of topics (mixture components)

        int D = matrixA.length; // number of documents
        if (D == 0) {
            throw new IllegalArgumentException("no data rows found in " + Config.MATRIX_A_PATH);
        }
        int X = matrixA[0].length - 1; // number of features (last column is the class label)

        int maxItr = Config.MAX_ITR;

        System.out.println("docsNumber:" + D + ", X:" + X + ", K:" + K + ", maxItr:" + maxItr);

        double[] pi_k = initPIK(K);
        double[][] theta_ktx = initTheta_kt(K, X);

        double[][] g_n_k = init_gnk(D, K);

        for (int itr = 0; itr < maxItr; itr++) {
            System.out.println("start itr: " + itr + " ...");

            // M-step: re-estimate priors and emission parameters from the responsibilities.
            for (int k = 0; k < K; k++) {
                double sum_gnk = responsibilitySum(g_n_k, k);
                pi_k[k] = sum_gnk / D;
                updateBernoulliTheta(matrixA, g_n_k, k, sum_gnk, theta_ktx[k]);
            }

            System.out.println("    start update g_n_k ...");
            // E-step. The logarithms depend only on (k, feature), so take them once
            // per iteration instead of once per document.
            double[][] logTheta = elementwiseLog(theta_ktx, false);
            double[][] logNotTheta = elementwiseLog(theta_ktx, true);
            for (int i = 0; i < D; i++) {
                for (int k = 0; k < K; k++) {
                    // The log prior is part of the score so that the max used for
                    // stabilisation always belongs to a component with non-zero mass.
                    double logScore = Math.log(pi_k[k]);
                    for (int x = 0; x < X; x++) {
                        // Bernoulli term: log(theta) if present, log(1 - theta) if absent.
                        logScore += matrixA[i][x] != 0 ? logTheta[k][x] : logNotTheta[k][x];
                    }
                    g_n_k[i][k] = logScore;
                }
                normalizeLogScores(g_n_k[i]);
            }

            System.out.println("end itr: " + itr + " ...");
        }

        // print the result
        out(D, K, matrixA, g_n_k);
    }

    /** Returns the sum over all documents of the responsibility of component {@code k}. */
    private static double responsibilitySum(double[][] g_n_k, int k) {
        double sum = 0;
        for (int i = 0; i < g_n_k.length; i++) {
            sum += g_n_k[i][k];
        }
        return sum;
    }

    /**
     * Bernoulli M-step for component {@code k} with add-one smoothing:
     * {@code theta[x] = (1 + sum_n g(n,k) * a(n,x)) / (2 + sum_n g(n,k))}.
     *
     * <p>The denominator is the expected number of documents in the component
     * ({@code sum_gnk}), not the number of documents that contain the feature; this is
     * the value that maximises the expected Bernoulli log-likelihood used in the E-step.
     */
    private static void updateBernoulliTheta(int[][] data, double[][] g_n_k, int k, double sum_gnk,
            double[] theta) {
        double fenmu = 2 + sum_gnk;
        for (int x = 0; x < theta.length; x++) {
            double fenzi = 1;
            for (int i = 0; i < g_n_k.length; i++) {
                if (data[i][x] != 0) {
                    fenzi += g_n_k[i][k];
                }
            }
            theta[x] = fenzi / fenmu;
        }
    }

    /**
     * Returns a new matrix holding {@code log(v)} for every entry {@code v} of {@code m},
     * or {@code log(1 - v)} when {@code complement} is true.
     */
    private static double[][] elementwiseLog(double[][] m, boolean complement) {
        double[][] logs = new double[m.length][];
        for (int r = 0; r < m.length; r++) {
            logs[r] = new double[m[r].length];
            for (int c = 0; c < m[r].length; c++) {
                logs[r][c] = Math.log(complement ? 1 - m[r][c] : m[r][c]);
            }
        }
        return logs;
    }

    /**
     * Converts unnormalised log scores into probabilities in place, subtracting the
     * maximum before exponentiating so that the largest term is exactly 1.
     */
    private static void normalizeLogScores(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < scores.length; k++) {
            if (max < scores[k]) {
                max = scores[k];
            }
        }
        double sum = 0;
        for (int k = 0; k < scores.length; k++) {
            scores[k] = Math.exp(scores[k] - max);
            sum += scores[k];
        }
        for (int k = 0; k < scores.length; k++) {
            scores[k] /= sum;
        }
    }

    /**
     * Reads a comma-separated matrix from a UTF-8 text file.
     *
     * <p>Lines starting with {@code @} and lines with fewer than two fields are skipped.
     * Every field is parsed as a double and truncated to int; rows shorter than the
     * longest row are padded with zeros on the right.
     *
     * @param path file to read
     * @return one row per accepted line
     * @throws IOException if the file cannot be opened or read
     */
    static int[][] getData(String path) throws IOException {
        System.out.println("start getData " + path + " ...");
        List<String[]> ls = new ArrayList<String[]>();
        int maxV = 0;
        BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(path)), "utf-8"));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                if (line.startsWith("@")) {
                    continue;
                }
                String[] es = line.split(",");
                if (es.length > 1) {
                    if (es.length > maxV) {
                        maxV = es.length;
                    }
                    ls.add(es);
                }
            }
        } finally {
            // Release the file handle even when reading fails.
            br.close();
        }

        int[][] matrix = new int[ls.size()][maxV];
        for (int i = 0; i < ls.size(); i++) {
            String[] es = ls.get(i);
            for (int j = 0; j < es.length; j++) {
                matrix[i][j] = (int) Double.parseDouble(es[j]);
            }
        }
        System.out.println("end getData " + path + " ...");
        return matrix;
    }

    /** Returns {@code K} random mixing weights normalised to sum to 1. */
    static double[] initPIK(int K) {
        double[] pi_k = new double[K];
        double sum = 0;
        for (int i = 0; i < K; i++) {
            pi_k[i] = Math.random();
            sum += pi_k[i];
        }
        for (int i = 0; i < K; i++) {
            pi_k[i] /= sum;
        }
        return pi_k;
    }

    /** Returns a {@code K x len} matrix of random values whose rows each sum to 1. */
    static double[][] initTheta_kt(int K, int len) {
        double[][] theta_kt = new double[K][len];
        for (int i = 0; i < K; i++) { // topic
            double sum = 0;
            for (int j = 0; j < len; j++) { // feature
                theta_kt[i][j] = Math.random();
                sum += theta_kt[i][j];
            }
            for (int j = 0; j < len; j++) {
                theta_kt[i][j] /= sum;
            }
        }
        return theta_kt;
    }

    /** Returns a {@code D x K} matrix with every entry equal to {@code 1/K}. */
    static double[][] init_pdn_theta_k(int D, int K) {
        double[][] pdn_theta_k = new double[D][K];
        for (int i = 0; i < D; i++) { // document
            double val = (double) 1 / K;
            for (int k = 0; k < K; k++) { // topic
                pdn_theta_k[i][k] = val;
            }
        }
        return pdn_theta_k;
    }

    /** Returns a {@code D x K} matrix of random responsibilities whose rows each sum to 1. */
    static double[][] init_gnk(int D, int K) {
        double[][] g_n_k = new double[D][K];
        for (int i = 0; i < D; i++) { // document
            double sum = 0;
            for (int k = 0; k < K; k++) { // topic
                g_n_k[i][k] = Math.random();
                sum += g_n_k[i][k];
            }
            for (int k = 0; k < K; k++) {
                g_n_k[i][k] /= sum;
            }
        }
        return g_n_k;
    }

    /**
     * Prints every document's posterior, then a cross table of assigned component
     * (rows) against the class label stored in the last column of {@code matrixA}
     * (columns).
     *
     * <p>The table has at least {@code K} columns and is widened when a label is
     * {@code >= K}, so data with more classes than components no longer overruns it.
     *
     * @throws IllegalArgumentException if a class label is negative
     */
    static void out(int D, int K, int[][] matrixA, double[][] g_n_k) {
        int classes = K;
        for (int i = 0; i < D; i++) {
            int label = matrixA[i][matrixA[i].length - 1];
            if (label < 0) {
                throw new IllegalArgumentException("negative class label " + label + " in document " + i);
            }
            if (label >= classes) {
                classes = label + 1;
            }
        }
        int[][] retDistribution = new int[K][classes]; // component x class counts

        for (int i = 0; i < D; i++) { // document
            System.out.println();
            for (int k = 0; k < K; k++) { // topic
                System.out.print("," + g_n_k[i][k]);
            }
        }
        System.out.println();

        for (int i = 0; i < D; i++) { // document
            // Hard assignment: the component with the largest posterior (0 on ties at 0).
            double max_g_n_k = 0;
            int cur_k = 0;
            for (int k = 0; k < K; k++) {
                if (g_n_k[i][k] > max_g_n_k) {
                    max_g_n_k = g_n_k[i][k];
                    cur_k = k;
                }
            }
            retDistribution[cur_k][matrixA[i][matrixA[i].length - 1]]++;
        }

        // print the table
        System.out.println("=============================");
        System.out.print("\t");
        for (int j = 0; j < classes; j++) {
            System.out.print("cls" + j + "\t");
        }
        System.out.println();

        for (int i = 0; i < K; i++) {
            System.out.print("newCls" + i);
            for (int j = 0; j < classes; j++) {
                System.out.print("\t" + retDistribution[i][j]);
            }
            System.out.println();
        }
        System.out.println("=============================");
        System.out.println("the end.");
    }
}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值