Katz平滑的实现

Katz平滑的原理讲解可参见：https://zhuanlan.zhihu.com/p/100256789 。本文实现了该文章中的示例。关于Katz平滑的具体原理，有时间我会再写一写自己的理解；阅读下面的代码也有助于理解那篇文章的内容。

import tensorflow as tf
import collections
class Katz():
    """Bigram language model with Katz back-off smoothing.

    Typical use calls, in order: ``example`` (or set ``data_sep`` and
    ``test_data`` directly), ``unigram``, ``unigram_ml``, ``bigram``,
    ``bigram_ml``, ``bigram_nr``, ``bigram_A``, ``bigram_discount``, ``bow``
    and finally ``test``.

    Bigrams seen fewer than ``T`` times get a Good-Turing based discount;
    bigrams seen ``T`` or more times keep their maximum-likelihood
    probability. Unseen bigrams back off to the unigram distribution scaled
    by a per-history back-off weight (``bow_dict``).
    """

    def __init__(self, T=3):
        """Create an empty model.

        Args:
            T: count threshold of the Katz discount; bigrams whose count is
                below ``T`` are discounted. Defaults to 3.
        """
        self.words_uni = collections.defaultdict(int)    # word -> count
        self.words_bi = collections.defaultdict(int)     # "w1 w2" -> count
        self.ml_uni = collections.defaultdict(float)     # word -> ML unigram probability
        self.nr_bi = collections.defaultdict(int)        # r -> number of bigram types seen r times
        self.rhat_bi = collections.defaultdict(float)    # "w1 w2" -> Good-Turing count r*
        self.r_bi = collections.defaultdict(float)       # "w1 w2" -> ML bigram probability c(w1 w2)/c(w1)
        self.dr_bi = collections.defaultdict(float)      # "w1 w2" -> discount ratio d_r
        self.ml_bi_discount = collections.defaultdict(float)  # "w1 w2" -> smoothed probability of a seen bigram
        self.head_bi = collections.defaultdict(float)    # w1 -> sum of ml_bi_discount over seen successors
        self.tail_bi = collections.defaultdict(float)    # w1 -> sum of ml_uni over seen successors
        self.bow_dict = collections.defaultdict(float)   # w1 -> back-off weight
        self.T = T
        self.A = 0

    def replace_end(self, input, pattern, rewrite):
        """Replace every match of regex *pattern* in *input* with *rewrite*."""
        return tf.strings.regex_replace(input, pattern, rewrite)

    def split_ngram(self, input, width):
        """Split *input* on whitespace and return its *width*-grams as a string tensor."""
        words = tf.strings.split(input)
        return tf.strings.ngrams(words, width)

    def sen_bi(self, sen):
        """Return the unsmoothed bigram probability of sentence *sen*.

        Multiplies c(w1 w2) / c(w1) over every bigram of *sen*. Lookups use
        ``.get`` so that querying never inserts zero-count entries into the
        count tables (which would otherwise pollute ``bigram_nr``).

        Raises:
            ZeroDivisionError: if a bigram's first word was never counted.
        """
        p = 1.
        for gram in self.split_ngram(sen, 2):
            bigram = gram.numpy().decode()
            head = bigram.split()[0]
            p = p * self.words_bi.get(bigram, 0) / self.words_uni.get(head, 0)
        return p

    def add_head_tail(self, s):
        """Return a copy of sentences *s* with ``<s> `` prepended and `` </s>`` appended."""
        return ["<s> " + sentence + " </s>" for sentence in s]

    def example(self):
        """Load the built-in toy corpus into ``data_sep`` and the test set into ``test_data``.

        Both sets are wrapped with sentence-boundary markers.
        """
        data = ["dogs chase cats",
                "dogs bark",
                "cats meow",
                "dogs chase birds",
                "cats chase birds",
                "dogs chase the cats",
                "the birds chirp"
                ]
        self.test_data = [
            "cats meow",
            "dogs chase the birds",
            "birds chirp",
            "Wang dogs Zhi Guo"
        ]
        self.data_sep = self.add_head_tail(data)
        self.test_data = self.add_head_tail(self.test_data)

    def unigram(self):
        """Count word frequencies over ``data_sep``, including ``<s>`` and ``</s>``."""
        self.words_uni = collections.defaultdict(int)
        # tf.strings.split yields one row of whitespace-separated tokens per
        # sentence; every token is counted once.
        for sentence in tf.strings.split(self.data_sep):
            for token in sentence:
                self.words_uni[token.numpy().decode()] += 1

    def bigram(self):
        """Count bigram frequencies over ``data_sep`` and record the sentence count."""
        self.total_bi_sentence = 0
        self.words_bi = collections.defaultdict(int)
        for item in self.data_sep:
            self.total_bi_sentence += 1
            for gram in self.split_ngram(item, 2):
                self.words_bi[gram.numpy().decode()] += 1

    def log10(self, data):
        """Return the base-10 logarithm of *data* as a TensorFlow tensor."""
        return tf.math.log(data) / tf.math.log(10.)

    def unigram_ml(self):
        """Compute maximum-likelihood unigram probabilities into ``ml_uni``.

        ``<s>`` is excluded from the denominator; every counted word,
        ``<s>`` included, still gets an entry.
        """
        total_unigram = sum(self.words_uni.values()) - self.words_uni['<s>']
        for key, count in self.words_uni.items():
            self.ml_uni[key] = count / total_unigram

    def bigram_ml(self):
        """Compute ML bigram probabilities c(w1 w2) / c(w1) into ``r_bi``."""
        for key, count in self.words_bi.items():
            head = key.split()[0]
            self.r_bi[key] = count / self.words_uni[head]

    def bigram_nr(self):
        """Build the count-of-counts table: ``nr_bi[r]`` = number of bigram types seen r times."""
        for count in self.words_bi.values():
            self.nr_bi[count] += 1

    def bigram_A(self):
        """Compute A = (T+1) * n_{T+1} / n_1, the constant term of the Katz discount.

        Raises:
            ZeroDivisionError: if no bigram type was seen exactly once.
        """
        self.A = (self.T + 1.) * self.nr_bi.get(self.T + 1, 0) / self.nr_bi.get(1, 0)

    def bigram_discount(self):
        """Compute smoothed probabilities of seen bigrams and per-history mass totals.

        For a bigram with count n < T:
            r* = (n+1) * n_{n+1} / n_n          (Good-Turing count)
            d  = (r*/n - A) / (1 - A)
            P  = r_bi * d
        Otherwise P = r_bi (no discount). For each history word w1 this also
        accumulates ``head_bi[w1]`` (sum of P over its seen successors) and
        ``tail_bi[w1]`` (sum of unigram probabilities of those successors),
        which ``bow`` combines into the back-off weight.
        """
        for key, n in self.words_bi.items():
            k0, k1 = key.split()[:2]
            if n < self.T:
                # NOTE(review): if no bigram type was seen n+1 times, r* is 0
                # and d becomes non-positive; the toy corpus avoids this.
                self.rhat_bi[key] = (n + 1) * self.nr_bi.get(n + 1, 0) / self.nr_bi[n]
                self.dr_bi[key] = (self.rhat_bi[key] / n - self.A) / (1 - self.A)
                self.ml_bi_discount[key] = self.r_bi[key] * self.dr_bi[key]
            else:
                self.ml_bi_discount[key] = self.r_bi[key]
            self.head_bi[k0] += self.ml_bi_discount[key]
            self.tail_bi[k0] += self.ml_uni[k1]

    def bow(self):
        """Compute back-off weights (1 - head_bi[w]) / (1 - tail_bi[w]) for every word but ``</s>``.

        ``<unk>`` gets a back-off weight of 0.0.
        """
        for key in self.words_uni:
            if key != "</s>":
                self.bow_dict[key] = (1 - self.head_bi[key]) / (1 - self.tail_bi[key])
        self.bow_dict["<unk>"] = 0.0

    def replace_nuk(self, input):
        """Return *input* with every word absent from ``words_uni`` replaced by ``<unk>``."""
        return " ".join(
            word if self.words_uni.get(word) else "<unk>" for word in input.split()
        )

    def test(self):
        """Print the log10 score of every sentence in ``test_data``.

        Unknown words are first mapped to ``<unk>``. For each bigram ``w1 w2``:
          * ``<s> <unk>`` and ``<unk> <unk>`` contribute nothing;
          * ``<unk> </s>`` contributes log10 ml_uni(</s>);
          * a bigram with a non-zero ``ml_bi_discount`` contributes its log10;
          * otherwise ``<unk> w2`` contributes nothing, ``w1 <unk>`` contributes
            log10 ml_uni(w1) (NOTE(review): the history word's probability,
            as in the original example; confirm intended);
          * otherwise the back-off term log10 bow(w1) + log10 ml_uni(w2) is
            added, skipping either factor when it is zero/missing.
        """
        for item in self.test_data:
            words = self.replace_nuk(item)
            # Start from a tensor so ``p.numpy()`` below also works when no
            # bigram contributes (a plain float would raise AttributeError).
            p = tf.constant(0.0)
            for index in self.split_ngram(words, 2):
                index = index.numpy().decode()
                if index in ("<s> <unk>", "<unk> <unk>"):
                    continue
                if index == "<unk> </s>":
                    p = p + self.log10(self.ml_uni["</s>"])
                    continue

                value = self.ml_bi_discount.get(index)
                if value:
                    p += self.log10(value)
                    continue

                k0, k1 = index.split()[:2]
                if k0 == "<unk>":
                    continue
                if k1 == "<unk>":
                    p = p + self.log10(self.ml_uni[k0])
                    continue
                bow_value = self.bow_dict.get(k0)
                if bow_value:
                    p += self.log10(bow_value)
                uni_value = self.ml_uni[k1]
                if uni_value:
                    p += self.log10(uni_value)
            print(item, ":", p.numpy())

# Run the whole Katz pipeline on the built-in toy corpus: load the data,
# count unigrams/bigrams, estimate discounts and back-off weights, then
# print a log10 score for every test sentence. Order matters: each step
# reads tables filled in by the ones before it.
k = Katz()
for step in (
    k.example,
    k.unigram,
    k.unigram_ml,
    k.bigram,
    k.bigram_ml,
    k.bigram_nr,
    k.bigram_A,
    k.bigram_discount,
    k.bow,
    k.test,
):
    step()


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值