# A simple spell-correction core (grammar errors are not considered). Many
# details remain to be polished; in particular, expanding the edit distance
# when searching for candidates is still too slow. Feedback is welcome.
import numpy as np
from nltk.corpus import reuters
# Load the dictionary of known-good words (one word per line).
# `with` closes the file handle (the original leaked it).
with open('vocab.txt') as f:
    vocab = {line.rstrip() for line in f}
# Load Reuters sentences as training data for the bigram language model.
categories = reuters.categories()
corpus = reuters.sents(categories=categories)
#生成所有的候选集合
def generate_candiated(word):
"""
给定的输入(错误的输入)
返回所有(valid)候选集合
"""
#生成编辑距离为1的单词
letters = 'abcdefghijklmnopqrstuvwxyz'
splits = [(word[:i],word[i:])for i in range(len(word)+1)]
#insert操作
inserts = [L+c+R for L,R in splits for c in letters]
#delect操作
delects = [L+R[1:] for L,R in splits if R]
#replace操作
replaces = [L+c+R[1:]for L,R in splits for c in letters]
candiates = set(inserts+delects+replaces)
return [word for word in candiates if word in vocab],candiates
# 构建语言模型
def generate_LM(docs=None):
    """Build unigram and bigram counts for the language model.

    Each sentence is prefixed with the ``'<s>'`` start marker. Bigram keys
    are space-joined word pairs, e.g. ``'<s> the'``.

    Args:
        docs: iterable of token lists; defaults to the module-level
            Reuters ``corpus`` (backward-compatible addition).

    Returns:
        ``(term_count, bigram_count)`` dicts of raw frequencies.
    """
    if docs is None:
        docs = corpus
    term_count = {}
    bigram_count = {}
    for doc in docs:
        doc = ['<s>'] + list(doc)
        # Count every token — the original stopped at len(doc)-1 and so
        # never counted the final token of each sentence as a unigram.
        for term in doc:
            term_count[term] = term_count.get(term, 0) + 1
        # Adjacent pairs form the bigrams.
        for i in range(len(doc) - 1):
            bigram = doc[i] + ' ' + doc[i + 1]
            bigram_count[bigram] = bigram_count.get(bigram, 0) + 1
    return term_count, bigram_count
#用户打错的概率统计
def mis_probs(path='spell-errors.txt'):
    """Build the channel model P(mistake | correct) from a spell-errors file.

    Each line has the form ``correct: mis1, mis2, ...``. Every observed
    misspelling of a word is assigned uniform probability.

    Args:
        path: spell-errors file location (backward-compatible addition,
            defaults to ``'spell-errors.txt'``).

    Returns:
        Nested dict ``{correct: {mistake: probability}}``.
    """
    channel_prob = {}
    # `with` closes the handle; the original leaked it.
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # blank lines crashed the original (IndexError)
            # partition(':') splits on the FIRST colon only, so values
            # containing ':' are kept intact.
            correct, _, rest = line.partition(':')
            correct = correct.strip()
            mistakes = [m.strip() for m in rest.split(',') if m.strip()]
            if not mistakes:
                continue
            # Uniform distribution over the observed misspellings.
            p = 1.0 / len(mistakes)
            channel_prob[correct] = {m: p for m in mistakes}
    return channel_prob
def main():
    """Spell-check each sentence in testdata.txt.

    For every out-of-vocabulary word, choose the candidate maximizing
    ``log P(word | candidate) + log P(candidate | previous word)`` — a
    noisy-channel model with an add-one-smoothed bigram LM — and print the
    typo together with its best correction.
    """
    # Language model (unigram/bigram counts) and channel model.
    term_count, bigram_count = generate_LM()
    channel_prob = mis_probs()
    V = len(term_count)  # vocabulary size for add-one smoothing

    # `with` closes the handle (the original leaked it and shadowed `file`).
    with open('testdata.txt') as f:
        for raw in f:
            items = raw.rstrip().split('\t')
            # The third tab-separated field holds the sentence text.
            line = items[2].strip('.').split()
            # enumerate gives the true position — the original used
            # line.index(word), which is wrong for duplicate words.
            for pos, word in enumerate(line):
                if word in vocab:
                    continue
                # Candidates at edit distance 1; if none are valid, expand
                # once to distance 2. This is a bounded expansion — the
                # original `while` loop could spin forever when distance 2
                # also produced nothing.
                candidates, dist1 = generate_candiated(word)
                if not candidates:
                    widened = set()
                    for w2 in dist1:
                        valid, _ = generate_candiated(w2)
                        widened.update(valid)
                    candidates = list(widened)
                if not candidates:
                    continue  # nothing usable; original crashed on max([])

                # Previous word for the bigram context; use the start
                # marker for the first word (the original indexed line[-1],
                # i.e. the LAST word of the sentence).
                prev = line[pos - 1] if pos > 0 else '<s>'

                probs = []
                for candi in candidates:
                    # a. channel probability P(word | candi), with a small
                    #    floor for unobserved (candi, word) pairs.
                    if candi in channel_prob and word in channel_prob[candi]:
                        prob = np.log(channel_prob[candi][word])
                    else:
                        prob = np.log(0.0001)
                    # b. bigram LM with add-one smoothing:
                    #    P(candi | prev) = (count(prev candi) + 1) / (count(prev) + V)
                    #    The original misplaced the logs and conditioned on
                    #    the candidate's count instead of the previous word's.
                    pair = bigram_count.get(prev + ' ' + candi, 0)
                    prob += np.log((pair + 1.0) / (term_count.get(prev, 0) + V))
                    probs.append(prob)

                # Report the highest-scoring correction.
                print(word, candidates[probs.index(max(probs))])


if __name__ == '__main__':
    main()