logprobs.py
import numpy as np
from math import exp, log
def sumLogProb2(log_prob1, log_prob2):
if(np.isinf(log_prob1) and np.isinf(log_prob2)):
return log_prob1
elif(log_prob1 > log_prob2):
return log_prob1 + log(1 + exp(log_prob2 - log_prob1))
else:
return log_prob2 + log(1 + exp(log_prob1 - log_prob2))
def sumLogProb1(log_probs):
_max = 0
for i in range(0, len(log_probs)):
if(i == 0 or log_probs[i] > _max):
_max = log_probs[i]
if(np.isinf(_max)):
return _max
p = 0.0
for i in range(0, len(log_probs)):
p += exp(log_probs[i] - _max)
if (p == 0):
return -1e20
else:
return _max + log(p)
str2idmap.py
class Str2IdMap():
def __init__(self):
self._to_id = {}
self._to_str = []
def getStr(self, id):
return self._to_str[id]
def getId(self, str):
if(self._to_id.has_key(str) == False):
id = len(self._to_id.items())
self._to_id[str] = id
self._to_str.append(str)
return id
else:
return self._to_id[str]
import str2idmap
import math
from logprobs import *
SMOOTHEDZEROCOUNT=-40
class OneDTable(dict):
def __init__(self):
self._smoothed_zero_count = SMOOTHEDZEROCOUNT
def smoothedZeroCount(self):
return self._smoothed_zero_count
def getValue(self, event):
if(self.has_key(event) == False):
return self.smoothedZeroCount()
else:
return self.get(event)
def add(self, event, count):
if(self.has_key(event) == False):
self[event] = count #
else:
self[event] = sumLogProb2(self.get(event), count)
def rand(self, next):
p = random.uniform(0, 0x7fff) / 0x7fff
total = 0
for key in self.keys:
total = total + exp(self.get(key))
if(total >= p):
next = key
return [next, True]
return [next, False]
tables.py
class TwoDTable(dict):
def __init__(self):
self._possibleContexts = {}
self._backoff = OneDTable()
def get(self, event, context):
if(self.has_key(context) == False):
return self._backoff.getValue(event)
else:
return self[context].getValue(event)
def add(self, context, event, count):
entry = OneDTable()
if(self.has_key(context) == False):
self[context] = entry
else:
entry = self[context]
entry.add(event, count)
#
possCntx = set()
if(self._possibleContexts.has_key(event) == False):
self._possibleContexts[event] = possCntx
else:
possCntx = self._possibleContexts[event]
possCntx.add(context)
def getCntx(self, event):
if(self._possibleContexts.has_key(event) == False):
return 0
else:
return self._possibleContexts[event]
def load(self, input_file, str2id):
while(True):
line = input_file.readline()
if(line == '\n' or line == ''):
break
eles = line.split(' ')
if(len(eles) < 3):
eles = line.split('\t')
p = []
for ele in eles:
if(ele != '\n'):
p.append(ele)
if(float(p[2]) > 0.0):
self.add(str2id.getId(p[0]), str2id.getId(p[1]), math.log(float(p[2])))
def save(self, output_file, str2id):
twodtable_keys = self.keys()
for t_key in twodtable_keys:
vals = self[t_key]
onedtable_keys = vals.keys()
for o_key in onedtable_keys:
output_file.write(str2id.getStr(t_key))
output_file.write(' ')
output_file.write(str2id.getStr(o_key))
output_file.write(' ')
output_file.write(str(exp(vals[o_key])))
output_file.write('\n')
def rand(self, curr, next):
if(self.has_key(curr) == False):
return False
val = self.get(curr)
return val.rand(next)
"""
Copyright ChaoZhong superclocks@163.com
"""
import math
import random
import numpy as np
from logprobs import * #from my code
from tables import OneDTable,TwoDTable#from my code
from str2idmap import Str2IdMap
from sys import stderr
#A transition between two Hmm nodes.
class Transition():
def __init__(self, from_node, to_node, obs ):
self._from = from_node
self._to = to_node
self._obs = obs
if(self._from and self._to):
self._from.outs().append(self)
self._to.ins().append(self)
#A node in an Hmm object.
class HmmNode():
def __init__(self, time, state, hmm):
self._time = time #The time slot for this node.
self._state = state #The hmm that this node belongs to
self._logAlpha = 0 #alpha_t(s) = P(e_1:t, x_t=s);
self._logBeta = 0 #beta_t(s) = P(e_t+1:T |x_t=s);
self._psi = 0 #the last transition of the most probable path that reaches this node
self._hmm = hmm
self._ins = list() #incoming transitions
self._outs = list() #out going transitions
def time(self):
return self._time
def state(self):
return self._state
def setLogAlpha(self, logAlpha):
self._logAlpha = logAlpha
def getLogAlpha(self):
return self._logAlpha
def setLogBeta(self, logBeta):
self._logBeta = logBeta
def getLogBeta(self):
return self._logBeta
def setPsi(self, psi):
self._psi = psi
def getPsi(self):
return self._psi
def ins(self):
return self._ins
def outs(self):
return self._outs
def _print(self):
print('HmmNode')
# Pseudo Counts
class PseudoCounts():
def __init__(self):
self._stateCount = OneDTable()
self._transCount = TwoDTable()
self._emitCount = TwoDTable()
def getStateCount(self):
return self._stateCount
def getTransCount(self):
return self._transCount
def getEmitCount(self):
return self._emitCount
def _print(self, str2id):
print('TRANSTION \n')
#self._transCount.save()
#The possible states at a particular time slot.
class TimeSlot(list): #save HmmNode
def __init__(self):
pass
hmm.py
class Hmm:
def __init__(self,init_state = 0, min_log_prob = 0.0000001 ):
self._init_state = init_state #the initial state
self._transition = TwoDTable()#transition probablity
self._emission = TwoDTable() #emission probablity
self._str2id = Str2IdMap() #mapping between strings and integers
self._time_slots = list() # the time slots
self._min_log_prob = min_log_prob
def loadProbs(self, path):
'''
Read the transition and emission probability tables from the
files NAME.trans and NAME.emit, where NAME is the value of the
variable name.
'''
trans_file_path = path + '.trans'
trans_prob_reader = file(trans_file_path, 'r')
init_state = trans_prob_reader.readline().split('\n')[0]
self._init_state = self._str2id.getId(init_state)
self._transition.load(trans_prob_reader, self._str2id)
emit_file_path = path + '.emit'
emit_prob_reader = file(emit_file_path, 'r')
self._emission.load(emit_prob_reader, self._str2id)
def readSeqs(self, input_file, sequences):
''' Read the training data from the input stream. Each line in the
input stream is an observation sequence.
'''
while(True):
ele_set = list()
line = input_file.readline()
if(line =='' or line == '\n'):
break
ele_in_line = line.split('\n')[0].split(' ')
for ele in ele_in_line:
if(ele != ' '):
ele_set.append(self.getId(ele))
sequences.append(ele_set)
'''
Conversion between the integer id and string form of states and
observations.
'''
def getId(self, str):
return self._str2id.getId(str)
def getStr(self, id):
return self._str2id.getStr(id)
def addObservation(self, o):
stateIds = list()
cntx = self._emission.getCntx(o)
if cntx == 0:
keys = self._emission.keys()
for key in keys:
stateIds.append(key)
else:
for ele in cntx:
stateIds.append(ele)
if (len(self._time_slots) == 0):
t0 = TimeSlot()
t0.append(HmmNode(0, self._init_state,self))
self._time_slots.append(t0)
ts = TimeSlot()
time = len(self._time_slots)
for i in range(0, len(stateIds)):
node = HmmNode(time, stateIds[i] , self)
ts.append(node)
prev = self._time_slots[time - 1]
for it in prev:
possibleSrc = self._transition.getCntx(node.state())
if(len(possibleSrc) > 0 and possibleSrc.__contains__(it.state())):
Transition(it , node, o)
self._time_slots.append(ts)
def getTransProb(self, trans):
return self._transition.get(trans._to.state(), trans._from.state())
def getEmitProb(self, trans):
return self._emission.get(trans._obs, trans._to.state())
#compute the forward probabilities P(e_1:t, X_t=s)
def forward(self):
#computer forward probabilities at time 0
t0 = self._time_slots[0] #TimeSlot
init = t0[0] #HmmNode
init.setLogAlpha(0.0)
#computer forward probabilities at time t using the alpha values for time t-1
for t in range(1, len(self._time_slots)):
ts = self._time_slots[t] #get TimeSlot object at time t
for it in ts: #it is list() type saved the HmmNode
ins = it.ins() #get Transition list
log_probs = list()
for trans in ins:
log_prob = trans._from.getLogAlpha() + self.getTransProb(trans) + self.getEmitProb(trans)
log_probs.append(log_prob)
it.setLogAlpha(sumLogProb1(log_probs))
#compute the backward probabilities P(e_t+1:T | X_t=s)
def backward(self):
T = len(self._time_slots) - 1
if(T < 1): #no observation
return
time = range(0, T + 1)
time.reverse()
for t in time:
ts = self._time_slots[t]
for it in ts:
node = it
if t ==T:
node.setLogBeta(0.0)
else:
outs = node.outs()
log_probs = list()
for i in range(0, len(outs)):
trans = outs[i]
log_prob = trans._to.getLogBeta() + \
self.getTransProb(trans) + self.getEmitProb(trans)
log_probs.append(log_prob)
node.setLogBeta(sumLogProb1(log_probs))
'''
Accumulate pseudo counts using the BaumWelch algorithm. The
return value is the probability of the observations according to
the current model.
'''
def getPseudoCounts(self, counts):
p_of_obs = self.obsProb()
self.backward()
# Compute the pseudo counts of transitions, emissions, and initializations
for t in range(0, len(self._time_slots)):
ts = self._time_slots[t] #get TimeSlot object at time t
'''
P(X_t=s|e_1:T) = alpha_s(t)*beta_s(t)/P(e_t+1:T|e_1:t)
The value sum below is log P(e_t+1:T|e_1:t)
'''
log_probs = list()
for it in ts: #it equ TimeSlot
log_probs.append(it.getLogAlpha() + it.getLogBeta())
_sum = sumLogProb1(log_probs)
#add the pseudo counts into counts
for it in ts:
node = it #HmmNode
#stateCount=P(X_t=s|e_1:T)
state_count = node.getLogAlpha() + node.getLogBeta() - _sum
counts.getStateCount().add(node.state(), state_count)
ins = node.ins() # vector<Transition*>
for k in range(0, len(ins)):
trans = ins[k] #Transition*
_from = trans._from #HmmNode
trans_count = _from.getLogAlpha() + self.getTransProb(trans) + self.getEmitProb(trans) + node.getLogBeta() - p_of_obs
counts.getEmitCount().add(node.state(),trans._obs,trans_count)
outs = node.outs() #vector<Transition*>
for k in range(0, len(outs)):
trans = outs[k] #Transition
to = trans._to #HmmNode
trans_count = node.getLogAlpha() + self.getTransProb(trans) + self.getEmitProb(trans)+to.getLogBeta() - p_of_obs;
counts.getTransCount().add(node.state(), to.state(), trans_count)
return p_of_obs
'''
Find the state sequence (a path) that has the maximum
probability given the sequence of observations:
max_{x_1:T} P(x_1:T | e_1:T);
The return value is the logarithm of the joint probability
of the state sequence and the observation sequence:
log P(x_1:T, e_1:T)
'''
def viterbi(self, path):
#set nodes at time 0 according to initial probabilities.
ts = self._time_slots[0] #TimeSlot
init = ts[0] #HmmNode
init.setLogAlpha(0.0)
#find the best path up to path t.
for t in range(1, len(self._time_slots)):
ts = self._time_slots[t] #ts is vector saved HmmNode
for it in ts: #it is a HmmNode
node = it
ins = node.ins() #ins is vector saved Transition
max_prob = -1e20
best_trans = 0 #best_trans is Transition object
for i in range(0, len(ins)):
trans = ins[i]
log_prob = trans._from.getLogAlpha() + \
self.getTransProb(trans) + self.getEmitProb(trans)
if(best_trans == 0 or max_prob < log_prob):
best_trans = trans
max_prob = log_prob
node.setLogAlpha(max_prob) #store the highest probability in logAlpha
node.setPsi(best_trans) #store the best transition in psi
#Find the best node at time T. It will be the last node in the best path
ts = self._time_slots[len(self._time_slots) - 1]
best = 0 #HmmNode*
for it in ts:
node = it
if (best == 0 or best.getLogAlpha() < node.getLogAlpha()):
best = node
#retrieve the nodes in the best path
nd = best
while(nd):
if(nd.getPsi()):
path.append(nd.getPsi())
nd = nd.getPsi()._from
else:
nd = 0
#reverse the path
i = 0
j = len(path) - 1
while(True):
tmp = path[i]
path[i] = path[j]
path[j] = tmp
if(i >= j):
break
i = i + 1
j = j - 1
return best.getLogAlpha()
def obsProb(self):
#return the logarithm of the observation sequence: log P(e_1:T)
if(len(self._time_slots) < 1):
return 1
self.forward()
last = self._time_slots[len(self._time_slots) - 1]
alphaT = list()
for it in last:
alphaT.append(it.getLogAlpha())
return sumLogProb1(alphaT)
#Clear all time slots to get ready to deal with another sequence.
def reset(self):
for t in range(0, len(self._time_slots)):
self._time_slots.pop()
def updateProbs(self, counts):
self._transition.clear()
self._emission.clear()
keys = counts.getTransCount().keys()
for i in keys:
_from = i
from_count = counts.getStateCount().getValue(_from)
cnts = counts.getTransCount()[i]
cnts_keys = cnts.keys()
for j in cnts_keys:
self._transition.add(_from, j, cnts[j] - from_count)
keys = counts.getEmitCount().keys()
for s in keys:
state = s
state_count = counts.getStateCount().get(state)
cnts = counts.getEmitCount()[s]
cnts_keys = cnts.keys()
for o in cnts_keys:
self._emission.add(state, o, cnts[o] - state_count)
'''
Train the model with the given observation sequences using the
Baum-Welch algorithm.
'''
def baumWelch(self, sequence, max_iterations):
'''Train the model with the given observation sequences using the
Baum-Welch algorithm.
'''
print 'Training with Baum-Welch for up to %d iterations, using %d sequences.' %(max_iterations, len(sequence))
prev_total_log_prob = 0
for k in range(0, max_iterations):
counts = PseudoCounts()
total_log_prob = 0
for i in range(0, len(sequence)):
seq = sequence[i]
for j in range(0, len(seq)):
self.addObservation(seq[j])
#accumulate the pseudo counts
total_log_prob += self.getPseudoCounts(counts)
self.reset()
if((i+1) % 1000 == 0):
print('Processed %d sequences' %(i+1) )
print('Iteration %d total_log_prob = %f' %(k, total_log_prob ))
if((prev_total_log_prob != 0) and (total_log_prob - prev_total_log_prob < 1)):
break
else:
prev_total_log_prob = total_log_prob
self.updateProbs(counts)
def saveProbs(self, name):
if(name == ''):
stderr.write('transition probalities: \n')
self._transition.save(stderr, self._str2id)
stderr.write('-----------------------\n')
stderr.write('emission probabilities: \n')
self._emission.save(stderr, self._str2id)
else:
s = name + '.trans'
trans_prob_writer = file(s, 'w')
trans_prob_writer.write(self._str2id.getStr(self._init_state))
trans_prob_writer.write('\n')
self._transition.save(trans_prob_writer,self._str2id)
s = name + '.emit'
emit_prob_writer = file(s, 'w')
self._emission.save(emit_prob_writer, self._str2id)
def testTraining():
hmm = Hmm()
hmm.loadProbs('./phone/phone-init1')
input_reader = file('./phone/phone.train', 'r')
sequences = list()
hmm.readSeqs(input_reader,sequences)
hmm.baumWelch(sequences, 10)
hmm.saveProbs('./phone/rphone-init1')
def testPrediction():
hmm = Hmm()
hmm.loadProbs('./phone/pos/pos')
input_reader = file('./phone/pos/phone.train', 'r')
#hmm.loadProbs('./phone/phone-init1')
#input_reader = file('./phone/phone.train', 'r')
sequences = list()
hmm.readSeqs(input_reader,sequences)
for i in range(0, len(sequences)):
seq = sequences[i]
for j in range(0, len(seq)):
hmm.addObservation(seq[j])
path = list()
joint_prob = hmm.viterbi(path)
print('P(path)=%f ' %exp(joint_prob - hmm.obsProb() ))
print('path: ')
for j in range(0, len(path)):
trans = path[j]
if(trans == 0):
continue
print('%s \t %s' %(hmm.getStr(trans._obs) , hmm.getStr(trans._to.state())))
hmm.reset()
if __name__ == "__main__":
testPrediction()
testTraining()