统计分词法 统计分词原理与实战

class HMM(object):

    def __init__(self):

        self.state_list = ['B','M','E','S']

        self.start_p = {}

        self.trans_p = {}

        self.emit_p = {}

        self.model_file = 'hmm_model.pkl'

        self.trained = False

    def train(self,datas,model_path=None):

        if model_path == None:

            model_path = self.model_file

        #统计状态频数

        state_dict = {}

        def init_parameters():

            for state in self.state_list:

                self.start_p[state] = 0.0

                self.trans_p[state] = {s:0.0 for s in self.state_list}

                self.emit_p[state] = {}

                state_dict[state] = 0

        def make_label(text):

            out_text = []

            if len(text) == 1:

                out_text = ['S']

            else :

                out_text += ['B']+['M']*(len(text)-2)+['E']

            return out_text

        init_parameters()

        line_nb = 0

        #监督学习方法求解参数

        for line in datas:

            line = line.strip()

            if not line:

                continue

            line_nb += 1

            word_list = [w for w in line if w != ' ']

            line_list = line.split()

            line_state = []

            for w in line_list:

                line_state.extend(make_label(w))

            assert len(line_state) == len(word_list)

            for i,v in enumerate(line_state):

                state_dict[v] += 1

                if i == 0:

                    self.start_p[v] += 1

                else :

                    self.trans_p[line_state[i-1]][v] += 1

                    self.emit_p[line_state[i]][word_list[i]] = self.emit_p[line_state[i]].get(word_list[i],0)+1.0

        self.start_p = {k: v*1.0/line_nb for k,v in self.start_p.items()}

        self.trans_p = {k:{k1: v1/state_dict[k1] for k1,v1 in v0.items()} for k,v0 in self.trans_p.items()}

        self.emit_p = {k:{k1: (v1+1)/state_dict.get(k1,1.0) for k1,v1 in v0.items()} for k,v0 in self.emit_p.items()}

        with open(model_path,'wb') as f:

            import pickle

            pickle.dump(self.start_p,f)

            pickle.dump(self.trans_p,f)

            pickle.dump(self.emit_p,f)

        self.trained = True

        print('model train done,parameters save to ',model_path)

    #读取参数模型

    def load_model(self,path):

        import pickle

        with open(path,'rb') as f:

            self.start_p = pickle.load(f)

            self.trans_p = pickle.load(f)

            self.emit_p = pickle.load(f)

        self.trained = True

        print('model parameters load done!')

    #维特比算法求解最优路径 

    def __viterbi(self,text,states,start_p,trans_p,emit_p):

        V = [{}]

        path = {}

        for y in states:

            V[0][y] = start_p[y]*emit_p[y].get(text[0],1.0)

            path[y] = [y]

        for t in range(1,len(text)):

            V.append({})

            new_path = {}

            for y in states:

                emitp = emit_p[y].get(text[t],1.0)

                (prob , state) = max([(V[t - 1][y0] * trans_p[y0].get(y, 0) * emitp, y0) \

                                      for y0 in states if V[t - 1][y0] > 0])

                V[t][y] = prob

                new_path[y] = path[state]+[y]

            path = new_path

        if emit_p['M'].get(text[-1],0) > emit_p['S'].get(text[-1],0):

            (prob,state) = max([(V[len(text)-1][y],y) for y in ('E',"M")])

        else :

            (prob,state) = max([(V[len(text)-1][y],y) for y in states])

        return (prob,path[state])


 

    def cut(self,text):

        if not self.trained:

            print('Error:please pre train or load model parameters')

            return

        prob,pos_list = self.__viterbi(text,self.state_list,self.start_p,self.trans_p,self.emit_p)

        begin_,next_ = 0,0

        #任务:完成 HMM 中文分词算法

        # ********* Begin *********# 

        for i, char in enumerate(text):

            pos = pos_list[i]

            if pos == 'B':

                begin_ = i

            elif pos == 'E':

                yield text[begin_:i+1]

                next_ = i+1

            elif pos == 'S':

                yield char

                next_ = i+1

        if next_ < len(text):

            yield text[next_:]

               

      

        # ********* Begin *********#

if __name__ == '__main__':

    text = input()

    train_data = 'pku_training.utf8'

    model_file = 'hmm_model.pkl'

    hmm = HMM()

    hmm.train(open(train_data, 'r', encoding='utf-8'), model_file)

    hmm.load_model(model_file)

    print('/'.join(hmm.cut(text)))

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值