1.1、The BM25 Algorithm
Formula:
score(q, d) = Σ_{t in q} idf(t) × freq(t, d) × (k1 + 1) / (freq(t, d) + k1 × (1 - b + b × dl / avgdl))
where idf(t) = log(1 + (N - n(t) + 0.5) / (n(t) + 0.5)); N is the total number of documents, n(t) the number of documents containing t, freq(t, d) the frequency of t in d, dl the length of d, and avgdl the average document length.
From the BM25 formula we can see that, for two documents of the same length, if the idf of each term in the query is also the same, then the higher the term frequency, the higher the similarity score.
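This behavior can be checked with a small standalone sketch of the textbook formula (illustrative Java, not Lucene code; k1, b, and the document statistics below are made-up values):

```java
// Standalone sketch of the textbook BM25 formula from section 1.1.
public class Bm25Sketch {
    static final double K1 = 1.2, B = 0.75;

    // idf per the BM25 variant Lucene uses: log(1 + (N - n + 0.5) / (n + 0.5))
    static double idf(long docFreq, long docCount) {
        return Math.log(1 + (docCount - docFreq + 0.5) / (docFreq + 0.5));
    }

    // score contribution of one term in one document
    static double termScore(double freq, double dl, double avgdl, double idf) {
        return idf * (freq * (K1 + 1)) / (freq + K1 * (1 - B + B * dl / avgdl));
    }

    public static void main(String[] args) {
        double idf = idf(100, 10_000);  // same idf for both documents
        double dl = 50, avgdl = 50;     // same document length
        // With length and idf fixed, a higher term frequency yields a higher score.
        System.out.println(termScore(1, dl, avgdl, idf) < termScore(3, dl, avgdl, idf)); // prints "true"
    }
}
```

Note that with length and idf held fixed, the score grows with freq but saturates toward idf × (k1 + 1), which is the "term saturation" role of k1.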
2、Lucene's Current BM25 Similarity Logic
This chapter first introduces the classes most relevant to the algorithm, then walks through the similarity computation with a concrete example, and finally covers the explain implementation.
2.0、The Lucene Similarity Interface
org.apache.lucene.search.similarities.Similarity
It has an inner class: org.apache.lucene.search.similarities.Similarity.SimScorer
This class provides two methods: one computes the similarity score, the other explains it.
public abstract float score(float freq, long norm);
freq: the term frequency
norm: the encoded normalization factor
returns: the score
public Explanation explain(Explanation freq, long norm) {
  return Explanation.match(
      score(freq.getValue().floatValue(), norm),
      "score(freq=" + freq.getValue() + "), with freq of:",
      Collections.singleton(freq));
}
2.1、Similarity Interface Implementation
org.apache.lucene.search.similarities.BM25Similarity is the implementation class of the similarity interface. It defines how the constants and variables in the formula are computed, such as k1, b, idf, and avgFieldLength (the avgdl in the formula). It also defines the idfExplain method, which explains the idf computation, and the scorer method, which constructs the scorer used for similarity scoring.
The weight BM25Similarity actually uses folds boost and idf together:
W = boost × idf
Basis: in org.apache.lucene.search.similarities.BM25Similarity.BM25Scorer#BM25Scorer, this.weight = boost * idf.getValue().floatValue();
BM25Scorer(float boost, float k1, float b, Explanation idf, float avgdl, float[] cache) {
  this.boost = boost;
  this.idf = idf;
  this.avgdl = avgdl;
  this.k1 = k1;
  this.b = b;
  this.cache = cache;
  this.weight = boost * idf.getValue().floatValue();
}
The score Lucene's BM25Similarity actually computes is:
score = W × freq / (freq + k1 × (1 - b + b × dl / avgdl))
Basis: org.apache.lucene.search.similarities.BM25Similarity.BM25Scorer#score
@Override
public float score(float freq, long encodedNorm) {
  double norm = cache[((byte) encodedNorm) & 0xFF];
  return weight * (float) (freq / (freq + norm));
}
Here k1 is a constant defaulting to 1.2, and b is a constant defaulting to 0.75.
Two conclusions follow:
- BM25Scorer computes the score of a single term; it is not responsible for the final summation over all terms.
- BM25Scorer's per-term computation does not multiply by (k1 + 1). However, Lucene also provides org.apache.lucene.search.similarity.LegacyBM25Similarity, a wrapper around BM25Similarity that moves the (k1 + 1) factor from the tf part into the weight, so LegacyBM25Similarity's BM25 scoring is equivalent to the formula in 1.1.
Basis: org.apache.lucene.search.similarity.LegacyBM25Similarity#scorer
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
  return bm25Similarity.scorer(boost * (1 + bm25Similarity.getK1()), collectionStats, termStats);
}
Next, the sources and computation of the key parameters:
Computing idf
Formula: log(1 + (docCount - docFreq + 0.5) / (docFreq + 0.5))
docCount is the total number of documents (in fact, the number of documents that have the field), and docFreq is the number of documents containing the term, also counted per field.
Basis: org.apache.lucene.search.similarities.BM25Similarity#idf
protected float idf(long docFreq, long docCount) {
  return (float) Math.log(1 + (docCount - docFreq + 0.5D) / (docFreq + 0.5D));
}
idf is computed when org.apache.lucene.search.similarities.BM25Similarity#scorer constructs the BM25Scorer; docFreq comes from termStats.docFreq and docCount from collectionStats.docCount.
public final SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
  Explanation idf = termStats.length == 1 ? idfExplain(collectionStats, termStats[0]) : idfExplain(collectionStats, termStats);
  float avgdl = avgFieldLength(collectionStats);
  float[] cache = new float[256];
  for (int i = 0; i < cache.length; i++) {
    cache[i] = k1 * ((1 - b) + b * LENGTH_TABLE[i] / avgdl);
  }
  if (!supportTp) {
    return new BM25Scorer(boost, k1, b, idf, avgdl, cache);
  } else {
    return new BM25TPScorer(boost, k1, b, idf, avgdl, cache);
  }
}
public Explanation idfExplain(CollectionStatistics collectionStats, TermStatistics termStats) {
  final long df = termStats.docFreq();
  final long docCount = collectionStats.docCount();
  final float idf = idf(df, docCount);
  return Explanation.match(idf, "idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:",
      Explanation.match(df, "n, number of documents containing term"),
      Explanation.match(docCount, "N, total number of documents with field"));
}
collectionStats and termStats are created when the TermWeight is constructed and passed on to BM25Scorer.
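The idf formula above can be sketched standalone (the document counts here are invented for illustration):

```java
// Standalone sketch of BM25Similarity#idf.
public class IdfSketch {
    static float idf(long docFreq, long docCount) {
        return (float) Math.log(1 + (docCount - docFreq + 0.5D) / (docFreq + 0.5D));
    }

    public static void main(String[] args) {
        // The rarer a term (smaller docFreq), the larger its idf.
        System.out.println(idf(10, 1000) > idf(500, 1000)); // prints "true"
        // Because of the +1 inside the log, idf stays positive even
        // when every document contains the term.
        System.out.println(idf(1000, 1000) > 0);            // prints "true"
    }
}
```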
Computing the average document length
Formula: sumTotalTermFreq / docCount
Basis: org.apache.lucene.search.similarities.BM25Similarity#avgFieldLength
protected float avgFieldLength(CollectionStatistics collectionStats) {
  return (float) (collectionStats.sumTotalTermFreq() / (double) collectionStats.docCount());
}
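A standalone sketch of this computation (the statistics are invented): the field's total token count across the index, divided by the number of documents carrying the field.

```java
// Standalone sketch of avgFieldLength with invented collection statistics.
public class AvgdlSketch {
    static float avgFieldLength(long sumTotalTermFreq, long docCount) {
        return (float) (sumTotalTermFreq / (double) docCount);
    }

    public static void main(String[] args) {
        // e.g. 16 tokens spread over 2 documents -> average field length 8
        System.out.println(avgFieldLength(16, 2)); // prints "8.0"
    }
}
```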
Obtaining the current document's length adjustment factor
org.apache.lucene.search.similarities.BM25Similarity.BM25Scorer#score shows that it comes from cache[], which is passed in when constructing org.apache.lucene.search.similarities.BM25Similarity.BM25Scorer#BM25Scorer:
public final SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
  Explanation idf = termStats.length == 1 ? idfExplain(collectionStats, termStats[0]) : idfExplain(collectionStats, termStats);
  float avgdl = avgFieldLength(collectionStats);
  float[] cache = new float[256];
  for (int i = 0; i < cache.length; i++) {
    cache[i] = k1 * ((1 - b) + b * LENGTH_TABLE[i] / avgdl);
  }
  return new BM25Scorer(boost, k1, b, idf, avgdl, cache);
}
cache[] caches, for every possible encoded length, the value of the length normalization factor k1 * (1 - b + b * dl / avgdl).
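The caching idea can be sketched standalone. Note that decodeLength below is a hypothetical stand-in for Lucene's LENGTH_TABLE (which decodes SmallFloat-encoded norms), used only to keep the example self-contained:

```java
// Sketch of the norm cache: precompute k1 * (1 - b + b * dl / avgdl) for every
// possible encoded length byte, so scoring is a table lookup.
public class NormCacheSketch {
    static final double K1 = 1.2, B = 0.75;

    // hypothetical decoder: treats the encoded byte directly as the length
    // (Lucene instead uses LENGTH_TABLE, built from SmallFloat decoding)
    static int decodeLength(int encoded) { return encoded; }

    static double[] buildCache(double avgdl) {
        double[] cache = new double[256];
        for (int i = 0; i < cache.length; i++) {
            cache[i] = K1 * ((1 - B) + B * decodeLength(i) / avgdl);
        }
        return cache;
    }

    public static void main(String[] args) {
        double[] cache = buildCache(50);
        long encodedNorm = 40;                            // pretend this came from the index
        double norm = cache[((byte) encodedNorm) & 0xFF]; // lookup at score time
        System.out.println(Math.abs(norm - 1.02) < 1e-9); // prints "true"
    }
}
```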
Obtaining the term frequency
TermScorer is the scorer for one term of one field. It holds a LeafSimScorer and a PostingsEnum; LeafSimScorer wraps BM25Scorer, and PostingsEnum's freq() method returns the frequency of the term in a field of a document.
2.2、The Summation Implementation
The accumulation of BM25Scorer.score happens in org.apache.lucene.search.WANDScorer#score, which iterates over the scorers and adds up their scores, implementing the summation over query terms from the formula in 1.1.
public float score() throws IOException {
  // we need to know about all matches
  advanceAllTail();
  double score = 0;
  for (DisiWrapper s = lead; s != null; s = s.next) {
    score += s.scorer.score();
  }
  return (float) score;
}
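A minimal standalone mirror of this loop, with invented per-term scores standing in for s.scorer.score():

```java
// Mirrors WANDScorer#score's accumulation: sum the per-clause scores,
// accumulating in double and returning float, as Lucene does.
public class SumSketch {
    static float sum(double[] perTermScores) {
        double score = 0;
        for (double s : perTermScores) {
            score += s;
        }
        return (float) score;
    }

    public static void main(String[] args) {
        double[] perTermScores = {0.18, 0.05, 0.31}; // hypothetical per-term BM25 scores
        System.out.printf("%.2f%n", sum(perTermScores)); // prints "0.54"
    }
}
```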
2.3、The Similarity Computation Flow
2.3.1、Create two documents via Elasticsearch
[
  {
    "_index": "bm25index",
    "_type": "_doc",
    "_id": "aBURUn0BIDdwD_i4XRGW",
    "_score": 1,
    "_source": {
      "field1": "Tom has two children named Kate and Jerry"
    }
  },
  {
    "_index": "bm25index",
    "_type": "_doc",
    "_id": "aRUSUn0BIDdwD_i4EhHd",
    "_score": 1,
    "_source": {
      "field1": "Tom and Jerry is a classic comedy cartoon"
    }
  }
]
- ES query DSL:
POST http://localhost:9200/bm25index/_search
{
  "query": {
    "match": {
      "field1": "Tom and Jerry"
    }
  }
}
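Before stepping through the internals, the expected score can be computed by hand. The sketch below does so outside Lucene, assuming the default BM25 settings (k1 = 1.2, b = 0.75), the LegacyBM25Similarity boost of (1 + k1) shown earlier, and exact document lengths; both field values tokenize to 8 terms, so the lossy norm encoding does not matter here.

```java
// Hand-computes the expected BM25 score of the query "Tom and Jerry"
// against either of the two example documents.
public class WorkedExample {
    static final double K1 = 1.2, B = 0.75;

    static double idf(long docFreq, long docCount) {
        return Math.log(1 + (docCount - docFreq + 0.5) / (docFreq + 0.5));
    }

    static double termScore(double boost, double idf, double freq, double dl, double avgdl) {
        double norm = K1 * ((1 - B) + B * dl / avgdl);
        return boost * idf * freq / (freq + norm);
    }

    static double queryScore() {
        long docCount = 2;      // two documents in the index
        double avgdl = 8;       // both field values tokenize to 8 terms
        double boost = 1 + K1;  // LegacyBM25Similarity folds (k1 + 1) into the boost
        // "Tom", "and", "Jerry" each occur in both documents: docFreq = 2, freq = 1, dl = 8
        double score = 0;
        for (int i = 0; i < 3; i++) {
            score += termScore(boost, idf(2, docCount), 1, 8, avgdl);
        }
        return score;
    }

    public static void main(String[] args) {
        System.out.printf("%.4f%n", queryScore()); // prints "0.5470"
    }
}
```

Both documents get the same score: each query term occurs once in each document and both documents have exactly the average length, so each per-term score collapses to idf = ln(1.2) ≈ 0.1823, summed over three terms.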
2.3.2、Generating the Query: org.elasticsearch.search.SearchService#createAndPutContext calls SearchContext context = createContext(request); which turns "field1": "Tom and Jerry" from the request into the BooleanQuery field1:Tom OR field1:and OR field1:Jerry; this BooleanQuery is passed to the subsequent search interfaces.
2.3.3、Lucene's search entry point: org.apache.lucene.search.IndexSearcher#search(org.apache.lucene.search.Query, org.apache.lucene.search.Collector)
This method calls query.createWeight to build the Weight tree from the query; Weight nodes construct Scorers, whose score methods do the actual scoring. The Query for the DSL above is a BooleanQuery with three clauses: field1:Tom, field1:and, and field1:Jerry. The clauses are iterated and a child Weight is created for each; see org.apache.lucene.search.BooleanWeight#BooleanWeight.
In this example all three clause.query objects are TermQuery, so the Weights created from them are all TermWeight; see org.apache.lucene.search.TermQuery.TermWeight#TermWeight.
While constructing a TermWeight, similarity.scorer is called to create the simScorer used for similarity scoring. The similarity in this example is BM25Similarity, so the simScorer created is a BM25Scorer; see section 2 for BM25Similarity and BM25Scorer. TermWeight.scorer first wraps the simScorer in a LeafSimScorer, then uses the LeafSimScorer to create a TermScorer; see org.apache.lucene.search.TermQuery.TermWeight#scorer:
public Scorer scorer(LeafReaderContext context) throws IOException {
  assert termStates == null || termStates.wasBuiltFor(ReaderUtil.getTopLevelContext(context)) : "The top-reader used to create Weight is not the same as the current reader's top-reader (" + ReaderUtil.getTopLevelContext(context);
  final TermsEnum termsEnum = getTermsEnum(context);
  if (termsEnum == null) {
    return null;
  }
  LeafSimScorer scorer = new LeafSimScorer(simScorer, context.reader(), term.field(), scoreMode.needsScores());
  if (scoreMode == ScoreMode.TOP_SCORES) {
    return new TermScorer(this, termsEnum.impacts(PostingsEnum.FREQS), scorer);
  } else {
    return new TermScorer(this, termsEnum.postings(null, scoreMode.needsScores() ? PostingsEnum.FREQS : PostingsEnum.NONE), scorer);
  }
}
As TermWeight.scorer shows, constructing a TermScorer first calls termsEnum.impacts to create an ImpactsEnum, or termsEnum.postings to create a PostingsEnum. ImpactsEnum is a subclass of PostingsEnum, and PostingsEnum is an iterator over postings: PostingsEnum.nextPosition iterates the positions of a term within a document, and PostingsEnum.freq returns the term's frequency in a document.
After the Weight tree is built, org.elasticsearch.search.internal.ContextIndexSearcher#searchInternal calls weight.bulkScorer to create a BulkScorer, which scores a range of documents.
While creating the BulkScorer, org.apache.lucene.search.Boolean2ScorerSupplier#opt is called to create a WANDScorer.
The WANDScorer organizes the three BM25Scorers created while building the Weights; org.apache.lucene.search.WANDScorer#score adds up the scores returned by the three scorers' score methods, completing the BM25 similarity computation.
2.4、The explain Implementation
BM25Similarity$BM25Scorer implements the explain method, which mainly explains tf, idf, boost, and the k1, b, dl, and avgdl values used to compute tf.
public Explanation explain(Explanation freq, long encodedNorm) {
  List<Explanation> subs = new ArrayList<>(explainConstantFactors());
  Explanation tfExpl = explainTF(freq, encodedNorm);
  subs.add(tfExpl);
  return Explanation.match(weight * tfExpl.getValue().floatValue(),
      "score(freq=" + freq.getValue() + "), product of:", subs);
}
The freq parameter is an Explanation describing the term frequency; it is passed into explainTF to build the tf Explanation.
protected Explanation explainTF(Explanation freq, long norm) {
  List<Explanation> subs = new ArrayList<>();
  subs.add(freq);
  subs.add(Explanation.match(k1, "k1, term saturation parameter"));
  float doclen = LENGTH_TABLE[((byte) norm) & 0xff];
  subs.add(Explanation.match(b, "b, length normalization parameter"));
  if ((norm & 0xFF) > 39) {
    subs.add(Explanation.match(doclen, "dl, length of field (approximate)"));
  } else {
    subs.add(Explanation.match(doclen, "dl, length of field"));
  }
  subs.add(Explanation.match(avgdl, "avgdl, average length of field"));
  float normValue = k1 * ((1 - b) + b * doclen / avgdl);
  return Explanation.match(
      (float) (freq.getValue().floatValue() / (freq.getValue().floatValue() + (double) normValue)),
      "tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:", subs);
}