HMM model 例子 Biased coins


输入

$ head toss1.txt
H
T
H
T
T
H
T
H
T
T
H for 正面, T for 反面。
每次抛掷用的可能是正常的硬币(fair coin),也可能是有偏的硬币(biased coin,即 H、T 的概率不是 0.5 / 0.5);两种硬币之间按一定的概率相互切换,而这些转移概率是未知的,需要由数据估计。


输出
TIME    TOSS    P(FAIR) P(BIAS) MLSTATE
1       H       0.3131  0.6869  BIASED
2       T       0.9991  0.0009  FAIR
3       H       0.9788  0.0212  FAIR
4       T       0.9978  0.0022  FAIR
5       T       0.9978  0.0022  FAIR
6       H       0.9788  0.0212  FAIR
7       T       0.9978  0.0022  FAIR
8       H       0.9788  0.0212  FAIR
9       T       0.9978  0.0022  FAIR
10      T       0.9978  0.0022  FAIR


根据 sph.umich  Biostatistics 615课程课件改编。


main.cpp

这里设定:一旦重新估计出的任一转移概率或发射概率小于 0.001 就停止训练,否则会过度训练;EM(Baum-Welch)算法最多运行 100 次。

#include <iomanip>
#include <iostream>
#include <string>
#include <vector>

#include "HMM615.h"

using namespace std;

// Reads a sequence of coin tosses ("H" / "T", whitespace separated) from
// standard input, fits a two-state HMM (state 0 = FAIR, state 1 = BIASED) with
// repeated Baum-Welch re-estimation, and prints per-toss posterior state
// probabilities, the Viterbi state, and the last re-estimated parameters.
//
// Returns 0 on success and -1 when the input is empty or contains a token
// other than "H" or "T".
int main(int argc, char** argv) {
	const double tran_threshold = 0.001; // stop once any re-estimated parameter drops below this
	const int numStates = 2;             // FAIR, BIASED
	const int numSymbols = 2;            // H, T
	const int maxRounds = 100;           // total number of Baum-Welch runs allowed

	vector<int> toss; string tok;
	while( cin >> tok ) {
		if ( tok == "H" ) toss.push_back(0);
		else if ( tok == "T" ) toss.push_back(1);
		else {
			cerr << "Cannot recognize input " << tok << endl;
			return -1;
		}
	}

	// With no observations the model would index outs[0] and row nTimes-1 of
	// zero-row tables, so reject empty input up front.
	if ( toss.empty() ) {
		cerr << "No tosses were read from standard input" << endl;
		return -1;
	}

	const int T = static_cast<int>(toss.size());
	HMM615 hmm(numStates, numSymbols, T);
	// Initial guess for the parameters.
	hmm.trans.data[0][0] = 0.95; hmm.trans.data[0][1] = 0.05;
	hmm.trans.data[1][0] = 0.2; hmm.trans.data[1][1] = 0.8;
	hmm.emis.data[0][0] = 0.5; hmm.emis.data[0][1] = 0.5;
	hmm.emis.data[1][0] = 0.9; hmm.emis.data[1][1] = 0.1;
	hmm.pis[0] = 0.5; hmm.pis[1] = 0.5; hmm.outs = toss;
	hmm.viterbi();
	hmm.BaumWelch();

	// The first Baum-Welch run happened above, hence the loop starts at 1.
	for (int round = 1; round < maxRounds; ++round) {
		bool belowThreshold = false;
		for (int i = 0; i < numStates; ++i) {
			for (int j = 0; j < numStates; ++j) {
				if ( hmm.TRANS.data[i][j] < tran_threshold ) belowThreshold = true;
			}
			for (int k = 0; k < numSymbols; ++k) {
				if ( hmm.EMIS.data[i][k] < tran_threshold ) belowThreshold = true;
			}
		}
		if ( belowThreshold ) {
			break;
		}

		// Adopt the re-estimated parameters as the current model.
		for (int i = 0; i < numStates; ++i) {
			for (int j = 0; j < numStates; ++j) {
				hmm.trans.data[i][j] = hmm.TRANS.data[i][j];
			}
			for (int k = 0; k < numSymbols; ++k) {
				hmm.emis.data[i][k] = hmm.EMIS.data[i][k];
			}
		}

		hmm.zero_sigma();
		hmm.viterbi();
		hmm.BaumWelch();
	}

	cout << "TIME\tTOSS\tP(FAIR)\tP(BIAS)\tMLSTATE" << endl;
	cout << setiosflags(ios::fixed) << setprecision(4);
	for(int t=0; t < T; ++t) {
		cout << t+1 << "\t" << (toss[t] == 0 ? "H" : "T") << "\t" << hmm.gammas.data[t][0] << "\t" << hmm.gammas.data[t][1] << "\t" << (hmm.path[t] == 0 ? "FAIR" : "BIASED" ) << endl;
	}
	// Last re-estimated transition rows, then emission rows.
	for (int i = 0; i < numStates; ++i) {
		cout << hmm.TRANS.data[i][0] << "\t" << hmm.TRANS.data[i][1] << endl;
	}
	for (int i = 0; i < numStates; ++i) {
		cout << hmm.EMIS.data[i][0]  << "\t" << hmm.EMIS.data[i][1]  << endl;
	}
	return 0;
}

HMM615.h 

#ifndef __HMM_615_H
#define __HMM_615_H
#include "Matrix615.h"
#include "MatrixTrible.h"
#include <cmath>

// Discrete hidden Markov model with storage for forward-backward, Baum-Welch
// re-estimation and Viterbi decoding. All members are public; callers fill in
// pis, trans, emis and outs before invoking the algorithms.
class HMM615 {
	public:
	// parameters
		int nStates; // n : number of possible states
		int nObs; // m : number of possible output values
		int nTimes; // t : number of time slots with observations
		std::vector<double> pis; // initial state probabilities, one per state
		std::vector<int> outs; // observed outcomes, indexed by time
		Matrix615<double> trans; // trans[i][j] corresponds to A_{ij}
		Matrix615<double> emis; // emis[i][k] : weight of output k in state i
		Matrix615<double> TRANS;	// re-estimated trans written by BaumWelch() (nStates x nStates)
		Matrix615<double> EMIS;	// re-estimated emis written by BaumWelch() (nStates x nObs)

	// storages for dynamic programming
	Matrix615<double> alphas, betas, gammas, deltas; // each nTimes x nStates
	Matrix615<int> phis; // Viterbi back-pointers, nTimes x nStates
	MatrixTrible<double> sigma; // pairwise (t, i, j) terms, nTimes x nStates x nStates
	std::vector<int> path; // decoded state per time slot, filled by viterbi()

	// Allocates every table for `states` hidden states, `obs` output symbols
	// and `times` observations. TRANS and EMIS get the same shape as trans and
	// emis, because BaumWelch() indexes them as [state][state] and
	// [state][output]. Members are initialised in declaration order.
	HMM615(int states, int obs, int times)
		: nStates(states), nObs(obs), nTimes(times),
		  trans(states, states, 0), emis(states, obs, 0),
		  TRANS(states, states, 0), EMIS(states, obs, 0),
		  alphas(times, states, 0), betas(times, states, 0),
		  gammas(times, states, 0), deltas(times, states, 0),
		  phis(times, states, 0),
		  sigma(times, states, states, 0)
	{
	   pis.resize(nStates);
	   path.resize(nTimes);
	}
	void forward(); // fills alphas
	void backward(); // fills betas
	void forwardBackward(); // fills gammas and sigma
	void BaumWelch(); // fills TRANS and EMIS
	void zero_sigma(); // resets sigma to zero
	void viterbi(); // fills deltas, phis and path
};
#endif // __HMM_615_H

// Resets sigma[t][i][j] to zero for t in [0, nTimes-1) and i, j in [0, nStates).
// The slice at t = nTimes-1 is left untouched.
void HMM615::zero_sigma(){
	const int sliceCount = nTimes - 1;
	for (int step = 0; step < sliceCount; ++step) {
		std::vector< std::vector<double> >& slice = sigma.data[step];
		for (int from = 0; from < nStates; ++from) {
			std::vector<double>& row = slice[from];
			for (int to = 0; to < nStates; ++to) {
				row[to] = 0;
			}
		}
	}
}

void HMM615::forward(){
   for(int i=0; i < nStates; ++i){
		double tmp = std::log(pis[i]) + std::log(emis.data[i][outs[0]]);
      alphas.data[0][i] = std::exp(tmp);
   }
   for(int t=1; t < nTimes; ++t){
      for(int i=0; i < nStates; ++i){
         alphas.data[t][i] = 0;
         for(int j=0; j < nStates; ++j) {
				double tmp = std::log( alphas.data[t-1][j] ) + std::log( trans.data[j][i] ) + std::log( emis.data[i][outs[t]]) ;
            alphas.data[t][i] += std::exp(tmp);
         }
      }
   }
}

void HMM615::backward() {
	for(int i=0; i < nStates; ++i) {
		betas.data[nTimes-1][i] = 1;
	}
	for(int t=nTimes-2; t >=0; --t) {
		for(int i=0; i < nStates; ++i) {
			betas.data[t][i] = 0;
			for(int j=0; j < nStates; ++j) {
				double tmp = std::log( betas.data[t+1][j] ) + std::log( trans.data[i][j] ) + std::log( emis.data[j][outs[t+1]]) ;
				betas.data[t][i] += std::exp(tmp);
			}
		}
	}
}

// Runs forward() and backward(), then fills
//   gammas[t][i]   = alphas[t][i] * betas[t][i], normalised over i, and
//   sigma[t][i][j] = alphas[t][i] * trans[i][j] * emis[j][outs[t+1]] * betas[t+1][j],
//                    normalised over (i, j), for t in [0, nTimes-1).
//
// sigma is assigned, not accumulated, so the function can be called
// repeatedly without clearing sigma first. Callers that do clear it first
// get the same values as before.
//
// If a normalising sum is zero the affected entries become NaN
// (log(0) - log(0)).
void HMM615::forwardBackward() {
	forward();
	backward();

	for(int t=0; t < nTimes; ++t) {
		double sum = 0;
		for(int i=0; i < nStates; ++i) {
			const double logJoint = std::log( alphas.data[t][i] ) + std::log( betas.data[t][i] );
			gammas.data[t][i] = logJoint; // parked here until the normaliser is known
			sum += std::exp(logJoint);
		}
		const double logSum = std::log( sum );
		for(int i=0; i < nStates; ++i) {
			gammas.data[t][i] = std::exp( gammas.data[t][i] - logSum );
		}
	}

	for(int t=0; t < nTimes-1; ++t) {
		double sum = 0;
		for(int i=0; i < nStates; ++i){
			for (int j=0; j < nStates; ++j){
				const double logJoint = std::log( alphas.data[t][i] ) + std::log( trans.data[i][j] ) + std::log( emis.data[j][outs[t+1]] ) + std::log( betas.data[t+1][j] );
				sigma.data[t][i][j] = logJoint; // parked here until the normaliser is known
				sum += std::exp(logJoint);
			}
		}
		const double logSum = std::log( sum );
		for(int i=0; i < nStates; ++i){
			for (int j=0; j < nStates; ++j){
				sigma.data[t][i][j] = std::exp( sigma.data[t][i][j] - logSum );
			}
		}
	}
}

void HMM615::BaumWelch() {
	forwardBackward();
	for (int i=0; i<nStates; ++i){
		double sum_gamma1 = 0;
		for (int t=0; t<nTimes-1; ++t){
			sum_gamma1 += gammas.data[t][i];
		}
		for (int j=0; j<nStates; ++j){
			double sum_sigma1 = 0;
			for (int t=0; t<nTimes-1; ++t){
				sum_sigma1 += sigma.data[t][i][j];
			}
			TRANS.data[i][j] = sum_sigma1 / sum_gamma1;
		}
	}

	for (int i=0; i<nStates; ++i){
		for (int k=0; k<nObs; ++k){
			double sum_gamma2 = 0;
			double sum_gamma3 = 0;
			for (int t=0; t<nTimes; ++t){
				sum_gamma2 += gammas.data[t][i];
				if (outs[t]==k){
					sum_gamma3 += gammas.data[t][i];
				}
			}
			EMIS.data[i][k]= sum_gamma3 / sum_gamma2;
		}
	}
}

// Viterbi decoding: fills deltas (best score of any state sequence ending in
// state i at time t), phis (the predecessor state achieving that score) and
// path (the decoded state for every time slot).
//
// Ties are resolved in favour of the lowest state index, because only a
// strictly larger score replaces the current best.
// Does nothing when nTimes <= 0.
void HMM615::viterbi() {
	if ( nTimes <= 0 ) {
		return; // no observations: outs[0] and row nTimes-1 do not exist
	}

	for(int i=0; i < nStates; ++i) {
		deltas.data[0][i] = pis[i] * emis.data[i][ outs[0] ];
	}

	for(int t=1; t < nTimes; ++t) {
		for(int i=0; i < nStates; ++i) {
			// The emission term is the same for every predecessor j.
			const double logEmit = std::log( emis.data[i][ outs[t] ] );
			int maxIdx = 0;
			double maxVal = std::exp( std::log( deltas.data[t-1][0] ) + std::log( trans.data[0][i] ) + logEmit );
			for(int j=1; j < nStates; ++j) {
				const double val = std::exp( std::log( deltas.data[t-1][j] ) + std::log( trans.data[j][i] ) + logEmit );
				if ( val > maxVal ) {
					maxIdx = j;
					maxVal = val;
				}
			}
			deltas.data[t][i] = maxVal;
			phis.data[t][i] = maxIdx;
		}
	}

	// Pick the best final state. The result always belongs in the last slot,
	// path[nTimes-1], whichever state wins.
	int bestLast = 0;
	for(int i=1; i < nStates; ++i){
		if ( deltas.data[nTimes-1][bestLast] < deltas.data[nTimes-1][i] ) {
			bestLast = i;
		}
	}
	path[nTimes-1] = bestLast;

	// Follow the back-pointers to recover the rest of the path.
	for(int t=nTimes-2; t >= 0; --t){
		path[t] = phis.data[t+1][ path[t+1] ];
	}
}

Matrix615.h

#ifndef __MATRIX_615_H
#define __MATRIX_615_H
#include <vector>

// Minimal dense 2-D matrix stored as a vector of rows.
// `data` is public and is indexed as data[row][col].
template <class T>
class Matrix615 {
   public:
      std::vector< std::vector<T> > data;

      // Builds an nrow x ncol matrix with every element set to val.
      // val defaults to a value-initialised T (0 for arithmetic types), so the
      // default also works for element types that cannot be built from 0.
      Matrix615(int nrow, int ncol, T val = T())
         : data(nrow, std::vector<T>(ncol, val)) {}

      // Number of rows.
      int rowNums() const { return (int)data.size(); }
      // Number of columns, taken from the first row; 0 when there are no rows.
      int colNums() const { return data.empty() ? 0 : (int)data[0].size(); }
};
#endif // __MATRIX_615_H

MatrixTrible.h

#ifndef __MATRIX_TRIBLE_H
#define __MATRIX_TRIBLE_H
#include <vector>

// Minimal dense 3-D array stored as nested vectors.
// `data` is public and is indexed as data[row][col][tri].
template <class T>
class MatrixTrible {
   public:
      std::vector< std::vector< std::vector<T> > > data;

      // Builds an nrow x ncol x ntri array with every element set to val.
      // val defaults to a value-initialised T (0 for arithmetic types).
      MatrixTrible(int nrow, int ncol, int ntri, T val = T())
         : data(nrow, std::vector< std::vector<T> >(ncol, std::vector<T>(ntri, val))) {}

      // Size of the first dimension.
      int rowNums() const { return (int)data.size(); }
      // Size of the second dimension, taken from data[0]; 0 when there are no rows.
      int colNums() const { return data.empty() ? 0 : (int)data[0].size(); }
      // Size of the third dimension, taken from data[0][0]; 0 when that element does not exist.
      int triNums() const { return ( data.empty() || data[0].empty() ) ? 0 : (int)data[0][0].size(); }
};
#endif // __MATRIX_TRIBLE_H
Makefile
# Builds the HMM biased-coin demo.
cc=g++
obj=main.o
exe=hmm
# main.cpp includes HMM615.h, which in turn includes the two matrix headers.
hdr=HMM615.h Matrix615.h MatrixTrible.h

# clean produces no file, so it must always run when requested.
.PHONY: clean

$(exe):$(obj)
	$(cc) -o $(exe) $(obj)

# Rebuild main.o whenever main.cpp or any of the headers changes.
main.o:main.cpp $(hdr)
	$(cc) -c main.cpp

clean:
	rm -rf *.o $(exe)



  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值