Part 3: Implementing Spelling Correction
Data needed for this project:
- vocab.txt: a dictionary file used to decide whether a word is misspelled; any word that does not appear in the dictionary is treated as a spelling error.
- spell-errors.txt: records many misspellings users have made together with the corresponding correct words; from it we can list the misspellings of each correct word and estimate the probability of each misspelling (the rough per-line formats are sketched below).
- testdata.txt: documents containing misspelled words, used for the final test.
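Based on how these files are parsed in the code that follows, their per-line formats look roughly like this (the concrete words are hypothetical illustrations, not actual file contents):
# vocab.txt          one word per line, e.g.  apple
# spell-errors.txt   "correct: wrong1, wrong2, ...", e.g.  raining: rainning, raning
# testdata.txt       tab-separated columns, with the sentence to check in the third column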
Pipeline:
- Find the misspelled words: any word that is not in the dictionary is treated as misspelled.
- Generate candidate words within edit distance 2 of the misspelled word, and filter out candidates that are not in the dictionary.
- Use Bayes' formula to choose the most suitable candidate (the scoring rule is spelled out below).
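Concretely, this is the noisy channel model: for a misspelled word w, pick the candidate c that maximizes p(c | w) ∝ p(w | c) · p(c), where the error model p(w | c) is estimated from spell-errors.txt and the prior p(c) comes from a bigram language model. The code in Part 3.4 works in log space and maximizes log p(w | c) + log p(c).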
Part 3.1 Load the vocabulary file and generate the set of candidate words for a misspelled word
vocab = set([line.strip() for line in open('vocab.txt')])
def generate_candidates(wrong_word):
    """
    wrong_word: the given (misspelled) input
    Returns all valid candidates, i.e. those that appear in the vocabulary.
    """
    # Generate all words at edit distance 1 using three operations:
    # 1. insert  2. delete  3. replace
    # e.g. for "appl":
    #   replace: bppl, cppl, aapl, abpl...
    #   insert:  bappl, cappl, abppl, acppl...
    #   delete:  ppl, apl, app
    letters = 'abcdefghijklmnopqrstuvwxyz'
    splits = [(wrong_word[:i], wrong_word[i:]) for i in range(len(wrong_word) + 1)]
    inserts = [left + letter + right for left, right in splits for letter in letters]
    deletes = [left + right[1:] for left, right in splits]
    replaces = [left + letter + right[1:] for left, right in splits for letter in letters]
    candidates = set(inserts + deletes + replaces)
    # Filter out candidates that are not in the vocabulary
    return [candi for candi in candidates if candi in vocab]
# Generate words at edit distance 2 by applying the edit-distance-1 operations twice
def generate_edit_two(wrong_word):
    def generate_edit_one(wrong_word):
        letters = 'abcdefghijklmnopqrstuvwxyz'
        splits = [(wrong_word[:i], wrong_word[i:]) for i in range(len(wrong_word) + 1)]
        inserts = [left + letter + right for left, right in splits for letter in letters]
        deletes = [left + right[1:] for left, right in splits]
        replaces = [left + letter + right[1:] for left, right in splits for letter in letters]
        return set(inserts + deletes + replaces)
    candi_one = generate_edit_one(wrong_word)
    candi_list = []
    for candi in candi_one:
        candi_list.extend(generate_edit_one(candi))
    candi_two = set(candi_list)
    # Keep only candidates that exist in the vocabulary
    return [candi for candi in candi_two if candi in vocab]
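A quick sanity check of the two generators (a sketch; the exact output depends on what vocab.txt contains):
# Hypothetical example: assuming 'apple' and 'apply' appear in vocab.txt,
# both are a single insertion away from the misspelling 'appl'.
print(generate_candidates('appl'))   # e.g. ['apple', 'apply']
# The edit-distance-2 set also contains the edit-distance-1 words,
# since a replace may leave a letter unchanged.
print(generate_edit_two('aple'))     # e.g. ['apple', 'able', 'ample', ...]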
Part 3.2 Load the spell-error file and estimate, for each correct word, the probability of each of its misspellings
misspell_prob = {}
# Each line of spell-errors.txt has the form "correct: wrong1, wrong2, ..."
for line in open('spell-errors.txt'):
    items = line.split(':')
    correct = items[0].strip()
    misspells = [item.strip() for item in items[1].split(',')]
    # Treat every recorded misspelling of a correct word as equally likely
    misspell_prob[correct] = {}
    for misspell in misspells:
        misspell_prob[correct][misspell] = 1 / len(misspells)
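For example, given the hypothetical line "raining: rainning, raning", the loop above would yield:
# two misspellings are listed for 'raining', so each gets probability 1/2
# misspell_prob['raining'] == {'rainning': 0.5, 'raning': 0.5}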
Part 3.3 Load a corpus and count word and bigram occurrences; use a bigram language model, so each word is only scored against its previous and next word
from nltk.corpus import reuters
# Load the Reuters corpus as training data for the language model
categories = reuters.categories()
corpus = reuters.sents(categories=categories)
# Build the bigram language model: count each word as the left context of a bigram,
# and count every bigram "word next_word"
term_count = {}
bigram_term_count = {}
for doc in corpus:
    doc = ['<s>'] + doc  # prepend a sentence-start token
    for i in range(len(doc) - 1):
        term = doc[i]
        bigram_term = ' '.join(doc[i:i+2])
        if term in term_count:
            term_count[term] += 1
        else:
            term_count[term] = 1
        if bigram_term in bigram_term_count:
            bigram_term_count[bigram_term] += 1
        else:
            bigram_term_count[bigram_term] = 1
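These counts are later turned into add-one (Laplace) smoothed bigram probabilities; written as a small stand-alone helper, the computation looks like this (a sketch only, the code in Part 3.4 inlines the same logic):
def bigram_prob(prev_word, word, V):
    # add-one smoothing: p(word | prev_word) = (count(prev_word word) + 1) / (count(prev_word) + V)
    bigram = ' '.join([prev_word, word])
    if prev_word in term_count:
        return (bigram_term_count.get(bigram, 0) + 1) / (term_count[prev_word] + V)
    return 1 / V  # unseen context word: fall back to a uniform estimate
Here V is the number of distinct context words, len(term_count), which is also how Part 3.4 sets it.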
Part 3.4 Load the test data, find the misspelled words, generate candidates, compute the probability of every candidate, and output the candidate with the highest probability as the correction
import numpy as np
V = len(term_count)  # number of distinct context words, used for add-one smoothing
with open('testdata.txt') as file:
    for line in file:
        items = line.split('\t')
        # the sentence to check is in the third tab-separated column
        word_list = items[2].split()
        # word_list = ["I", "like", "playing"]
        for index, word in enumerate(word_list):
            word = word.strip(',.')  # drop trailing punctuation before the vocabulary lookup
            if word not in vocab:
                candidates = generate_candidates(word)
                if len(candidates) == 0:
                    candidates = generate_edit_two(word)
                probs = []
                prob_dict = {}
                # For each candidate, compute its score
                #   p(correct) * p(mistake | correct), or in log space
                #   log p(correct) + log p(mistake | correct),
                # and return the candidate with the largest score.
                for candi in candidates:
                    prob = 0
                    # a. log p(mistake | correct), from the spell-error statistics
                    if candi in misspell_prob and word in misspell_prob[candi]:
                        prob += np.log(misspell_prob[candi][word])
                    else:
                        prob += np.log(0.0001)  # small constant for unseen (correct, mistake) pairs
                    # b. log p(correct), estimated with the bigram model and add-one smoothing
                    # First, log p(candidate | previous word)
                    pre_word = word_list[index - 1] if index > 0 else '<s>'
                    bigram_pre = ' '.join([pre_word, candi])
                    if pre_word in term_count and bigram_pre in bigram_term_count:
                        prob += np.log((bigram_term_count[bigram_pre] + 1) / (term_count[pre_word] + V))
                    elif pre_word in term_count:
                        prob += np.log(1 / (term_count[pre_word] + V))
                    else:
                        prob += np.log(1 / V)
                    # Then, log p(next word | candidate)
                    if index + 1 < len(word_list):
                        next_word = word_list[index + 1]
                        bigram_next = ' '.join([candi, next_word])
                        if candi in term_count and bigram_next in bigram_term_count:
                            prob += np.log((bigram_term_count[bigram_next] + 1) / (term_count[candi] + V))
                        elif candi in term_count:
                            prob += np.log(1 / (term_count[candi] + V))
                        else:
                            prob += np.log(1 / V)
                    probs.append(prob)
                    prob_dict[candi] = prob
                if probs:
                    max_idx = probs.index(max(probs))
                    print(word, candidates[max_idx])
                    print(prob_dict)
                else:
                    print(word, False)