手痒,自己实现了一下,UT已经通过。在lucene4基础上实现,加上接口不到300行代码。
package com.dp.junhao.jhsegmenter;
import gnu.trove.iterator.TByteIterator;
import gnu.trove.list.array.TByteArrayList;
import gnu.trove.procedure.TByteProcedure;
import gnu.trove.set.TByteSet;
import gnu.trove.set.hash.TByteHashSet;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.UnicodeUtil;
import java.nio.charset.Charset;
import java.util.*;
/**
* Created by junhao.zhang on 17/7/27.
*/
/**
 * Read-only double-array trie (DAT) over UTF-8 byte sequences.
 *
 * <p>Layout of the two parallel arrays:
 * <ul>
 *   <li>Node 0 is the root. The child of {@code node} for byte {@code b} lives at
 *       {@code base[node] + b}, where {@code b} is used as a <em>signed</em> Java byte.</li>
 *   <li>Unused {@code base} slots hold 0, unused {@code check} slots hold -1.</li>
 *   <li>A slot is occupied iff {@code check[i] != -1}; the low 31 bits of {@code check[i]}
 *       hold the parent node.</li>
 *   <li>The highest bit of {@code check[i]} is 0 when the node ends a complete term and 1
 *       when the node is only an inner prefix.</li>
 * </ul>
 */
public class DAT {
    /** Strips the "prefix only" flag (the sign bit) from a check entry, leaving the parent node. */
    private static final int PARENT_MASK = 0x7fffffff;
    /** OR-ed into a check entry when the node is a prefix but not a complete term. */
    private static final int NON_TERM_FLAG = 0x80000000;
    /** Content of an unoccupied check slot. */
    private static final int EMPTY_CHECK = -1;

    private final int[] base;
    private final int[] check;

    /**
     * Wraps pre-built arrays; the arrays are used as-is (no defensive copy).
     *
     * @param base  base array, see the class comment for the layout
     * @param check check array, see the class comment for the layout
     */
    public DAT(int[] base, int[] check) {
        this.base = base;
        this.check = check;
    }

    // for test.
    int[] getBaseArray() {
        return base;
    }

    // for test.
    int[] getCheckArray() {
        return check;
    }

    /**
     * Tells whether {@code term} (encoded as UTF-8) is a term or a prefix of a term.
     * The empty string always matches, because it denotes the root.
     */
    public boolean containsNode(String term) {
        return containsNode(toUtf8(term));
    }

    /**
     * Tells whether the given byte range is a term or a prefix of a term.
     *
     * @param bytes  buffer holding the key
     * @param offset index of the first key byte
     * @param length number of key bytes
     */
    public boolean containsNode(byte[] bytes, int offset, int length) {
        return walk(bytes, offset, length) >= 0;
    }

    /** Same as {@link #containsNode(byte[], int, int)} for the range described by {@code bytes}. */
    public boolean containsNode(BytesRef bytes) {
        return containsNode(bytes.bytes, bytes.offset, bytes.length);
    }

    /** Tells whether {@code term} (encoded as UTF-8) was stored as a complete term. */
    public boolean containsTerm(String term) {
        return containsTerm(toUtf8(term));
    }

    /** Same as {@link #containsTerm(byte[], int, int)} for the range described by {@code bytes}. */
    public boolean containsTerm(BytesRef bytes) {
        return containsTerm(bytes.bytes, bytes.offset, bytes.length);
    }

    /**
     * Tells whether the given byte range was stored as a complete term.
     *
     * @param bytes  buffer holding the key
     * @param offset index of the first key byte
     * @param length number of key bytes
     */
    public boolean containsTerm(byte[] bytes, int offset, int length) {
        int node = walk(bytes, offset, length);
        // A cleared sign bit marks a complete term; the root (check[0] == -1) never qualifies.
        return node >= 0 && check[node] >= 0;
    }

    /**
     * Follows the key from the root.
     *
     * @return the node reached after consuming all bytes, or -1 when some transition is missing
     */
    private int walk(byte[] bytes, int offset, int length) {
        int node = 0;
        for (int i = offset; i < offset + length; ++i) {
            // A node beyond the end of base was never given children.
            if (node >= base.length) {
                return -1;
            }
            int child = base[node] + bytes[i];
            // base values are frequently negative, so an unknown byte can point below slot 0.
            if (child < 0 || child >= check.length || (check[child] & PARENT_MASK) != node) {
                return -1;
            }
            node = child;
        }
        return node;
    }

    /** Returns the fraction of check slots that are occupied. */
    public double getCompactRate() {
        int count = 0;
        for (int e : check) {
            if (e != EMPTY_CHECK) {
                ++count;
            }
        }
        return (double) count / check.length;
    }

    /** Encodes {@code text} as UTF-8 into a fresh {@link BytesRef}. */
    private static BytesRef toUtf8(String text) {
        BytesRef bytes = new BytesRef();
        UnicodeUtil.UTF16toUTF8(text, 0, text.length(), bytes);
        return bytes;
    }

    /**
     * Collects terms and lays them out into a {@link DAT}.
     */
    public static class Builder {
        private static final BytesRef EMPTY_BYTESREF = new BytesRef(BytesRef.EMPTY_BYTES, 0, 0);
        private static final int INITIAL_CAPACITY = 32;
        private static final int[] table2pow = new int[]{0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
                0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000,
                0x10000, 0x20000, 0x40000, 0x80000, 0x100000, 0x200000, 0x400000, 0x800000,
                0x1000000, 0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000, 0x40000000}; // The highest bit is reserved as the term flag.

        /** All added terms, ordered by {@link BytesRef}'s natural ordering. */
        private final TreeSet<BytesRef> terms = new TreeSet<BytesRef>();

        /**
         * Adds a term; empty strings are ignored.
         *
         * @return this builder, for chaining
         */
        public Builder addTerm(String term) {
            if (term.isEmpty()) {
                return this;
            }
            terms.add(toUtf8(term));
            return this;
        }

        /**
         * Returns the smallest byte sequence that is greater than every sequence starting with
         * {@code bytes}: trailing 0xFF bytes are dropped and the last remaining byte is incremented.
         * Must not be called when {@link #isTopmost(BytesRef)} is true.
         */
        private static BytesRef nextBytesRef(BytesRef bytes) {
            int end = bytes.length;
            // A trailing 0xFF cannot be incremented; carry into the byte before it by truncating.
            while (end > 0 && (bytes.bytes[bytes.offset + end - 1] & 0xff) == 0xff) {
                --end;
            }
            byte[] next = new byte[end];
            System.arraycopy(bytes.bytes, bytes.offset, next, 0, end);
            ++next[end - 1];
            return new BytesRef(next);
        }

        /** Tells whether every byte is 0xFF (vacuously true for the empty sequence). */
        private static boolean isTopmost(BytesRef bytes) {
            for (int i = 0; i < bytes.length; ++i) {
                // Compare as unsigned: a signed byte is always below 255.
                if ((bytes.bytes[bytes.offset + i] & 0xff) < 255) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Finds the first base value, scanning upwards, for which the slots of all sibling bytes
         * are free.
         *
         * @param check      check array built so far (may grow while probing)
         * @param parentNode node whose children are being placed
         * @param list       distinct child bytes, non-empty
         */
        private static int getBestBaseValue(AutoExpandIntArray check, int parentNode, TByteArrayList list) {
            // Siblings arrive in term-set order but are added as signed bytes, so the first
            // element is not necessarily the smallest offset. Start from the signed minimum so
            // that no child slot can fall at or below the parent (in particular, below 0).
            byte smallest = list.get(0);
            for (int k = 1; k < list.size(); ++k) {
                if (list.get(k) < smallest) {
                    smallest = list.get(k);
                }
            }
            for (int candidate = parentNode + 1 - smallest; ; ++candidate) {
                if (allSlotsFree(check, candidate, list)) {
                    return candidate;
                }
            }
        }

        /** Tells whether {@code candidate + b} is an unoccupied check slot for every sibling byte. */
        private static boolean allSlotsFree(AutoExpandIntArray check, int candidate, TByteArrayList list) {
            TByteIterator itr = list.iterator();
            while (itr.hasNext()) {
                if (check.get(candidate + itr.next(), EMPTY_CHECK) != EMPTY_CHECK) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Returns the smallest entry of {@link #table2pow} that is at least {@code minSize}.
         *
         * @throws IllegalArgumentException if {@code minSize} exceeds the largest table entry
         */
        private static int getProper2PowNum(int minSize) {
            if (minSize > table2pow[table2pow.length - 1]) {
                throw new IllegalArgumentException(String.format("Array size: %d too large", minSize));
            }
            int low = 0, high = table2pow.length - 1;
            while (low <= high) {
                int mid = (low + high) >>> 1;
                if (minSize == table2pow[mid]) {
                    return table2pow[mid];
                } else if (minSize < table2pow[mid]) {
                    high = mid - 1;
                } else {
                    low = mid + 1;
                }
            }
            return table2pow[low];
        }

        /** Returns the terms that start with {@code bytes} and are strictly longer than it. */
        private SortedSet<BytesRef> getTermsWithPrefix(BytesRef bytes) {
            if (isTopmost(bytes)) {
                // No upper bound exists (this is also the path taken for the empty root prefix).
                return terms.tailSet(bytes, false);
            }
            return terms.subSet(bytes, false, nextBytesRef(bytes), false);
        }

        /** BFS work item: a prefix together with the node it was assigned. */
        private static class Entry {
            BytesRef bytes;
            int node;

            Entry(BytesRef bytes, int node) {
                this.bytes = bytes;
                this.node = node;
            }
        }

        /** int array that grows to a power-of-two size whenever an index past its end is touched. */
        private static class AutoExpandIntArray {
            int[] data;

            AutoExpandIntArray(int initialSize, int fillValue) {
                data = new int[initialSize];
                if (fillValue != 0) {
                    Arrays.fill(data, fillValue);
                }
            }

            /** Grows to hold at least {@code minSize} elements; new slots get {@code fillValue}. */
            void expand(int minSize, int fillValue) {
                int[] newData = new int[getProper2PowNum(minSize)];
                System.arraycopy(data, 0, newData, 0, data.length);
                if (fillValue != 0) {
                    Arrays.fill(newData, data.length, newData.length, fillValue);
                }
                data = newData;
            }

            /** Stores {@code value} at {@code pos}, growing first if needed, and returns it. */
            int set(int pos, int value, int fillValue) {
                if (pos >= data.length) {
                    expand(pos + 1, fillValue);
                }
                return data[pos] = value;
            }

            /** Reads {@code pos}; note that reading past the end grows the array as well. */
            int get(int pos, int defaultValue) {
                if (pos >= data.length) {
                    expand(pos + 1, defaultValue);
                }
                return data[pos];
            }
        }

        /**
         * Lays out all added terms breadth-first and returns the resulting trie.
         *
         * @throws IllegalArgumentException if the arrays would need more than 0x40000000 slots
         */
        public DAT build() {
            // BFS to iterate all nodes.
            final Queue<Entry> queue = new ArrayDeque<Entry>();
            final AutoExpandIntArray base = new AutoExpandIntArray(INITIAL_CAPACITY, 0); // TODO: calc proper capacity.
            final AutoExpandIntArray check = new AutoExpandIntArray(INITIAL_CAPACITY, EMPTY_CHECK);
            queue.add(new Entry(EMPTY_BYTESREF, 0));
            while (!queue.isEmpty()) {
                final Entry elem = queue.poll();
                final int depth = elem.bytes.length;

                // Distinct next bytes below this prefix; equal bytes are adjacent in the sorted set.
                TByteArrayList siblings = new TByteArrayList();
                // Next bytes that complete a term.
                TByteSet termEnds = new TByteHashSet();
                for (BytesRef term : getTermsWithPrefix(elem.bytes)) {
                    byte b = term.bytes[term.offset + depth];
                    if (term.length == depth + 1) {
                        termEnds.add(b);
                    }
                    if (siblings.isEmpty() || siblings.get(siblings.size() - 1) != b) {
                        siblings.add(b);
                    }
                }
                if (siblings.isEmpty()) {
                    continue;
                }

                final int parentNode = elem.node;
                final int baseValue = base.set(parentNode, getBestBaseValue(check, parentNode, siblings), 0);
                for (int k = 0; k < siblings.size(); ++k) {
                    byte b = siblings.get(k);
                    byte[] childBytes = new byte[depth + 1];
                    System.arraycopy(elem.bytes.bytes, elem.bytes.offset, childBytes, 0, depth);
                    childBytes[depth] = b;
                    int childNode = baseValue + b;
                    queue.add(new Entry(new BytesRef(childBytes), childNode));
                    int checkValue = termEnds.contains(b) ? parentNode : (parentNode | NON_TERM_FLAG);
                    check.set(childNode, checkValue, EMPTY_CHECK);
                }
            }
            // check also grows on probes and leaves never write to base, so base can end up
            // shorter; pad it so that every occupied node has a base slot.
            int[] baseData = base.data;
            if (baseData.length < check.data.length) {
                baseData = Arrays.copyOf(baseData, check.data.length);
            }
            return new DAT(baseData, check.data);
        }
    }
}
UT部分:
package com.dp.junhao.jhsegmenter;
import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.UnicodeUtil;
import org.testng.annotations.Test;
import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
/**
* Created by junhao.zhang on 17/8/10.
*/
/**
 * Unit tests for {@link DAT} and {@link DAT.Builder}.
 *
 * <p>Note on assertion imports: {@code assertEquals}/{@code assertFalse} come from
 * {@code junit.framework.Assert} (argument order: expected, actual) while {@code assertTrue}
 * comes from {@code org.testng.Assert}.
 */
public class DATTest {
    /** Terms stored in the sample trie. */
    private static final String[] TERMS = {"ABC", "AC", "ACE", "ACFF", "AD", "BBC", "CD", "CF", "ZQ"};
    /** Flag carried by check entries of nodes that are prefixes only. */
    private static final int NON_TERM = 0x80000000;

    /** Builds the trie shared by all tests. */
    private static DAT buildSampleDat() {
        DAT.Builder builder = new DAT.Builder();
        for (String term : TERMS) {
            builder.addTerm(term);
        }
        return builder.build();
    }

    @Test
    public void testSimpleDAT() {
        DAT dat = buildSampleDat();
        for (String term : TERMS) {
            assertTrue(dat.containsTerm(term));
        }
        // "BB" is only a prefix of "BBC", not a term.
        assertFalse(dat.containsTerm("BB"));
    }

    @Test
    public void testSimpleDATWhiteBox() {
        DAT dat = buildSampleDat();
        int[] base = dat.getBaseArray();
        int[] check = dat.getCheckArray();

        int[] expectedBase = {
                -64, -62, -59, -60, -58, -58, 0, -54,   // nodes 0-7
                0, 0, 0, 0, -56, 0, 0, 0,               // nodes 8-15
                0, 0, 0, 0, 0, 0, 0, 0,                 // nodes 16-23
                0, 0, -54, 0, 0, 0, 0, 0                // nodes 24-31
        };
        int[] expectedCheck = {
                -1,             // 0: root
                NON_TERM | 0,   // 1: A
                NON_TERM | 0,   // 2: B
                NON_TERM | 0,   // 3: C
                NON_TERM | 1,   // 4: AB
                1,              // 5: AC
                1,              // 6: AD
                NON_TERM | 2,   // 7: BB
                3,              // 8: CD
                4,              // 9: ABC
                3,              // 10: CF
                5,              // 11: ACE
                NON_TERM | 5,   // 12: ACF
                7,              // 13: BBC
                12,             // 14: ACFF
                -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,   // 15-25: unused
                NON_TERM | 0,   // 26: Z
                26,             // 27: ZQ
                -1, -1, -1, -1  // 28-31: unused
        };

        assertEquals(32, base.length);
        assertEquals(32, check.length);
        for (int i = 0; i < expectedBase.length; ++i) {
            assertEquals("base[" + i + "]", expectedBase[i], base[i]);
        }
        for (int i = 0; i < expectedCheck.length; ++i) {
            assertEquals("check[" + i + "]", expectedCheck[i], check[i]);
        }
        // 16 of the 32 check slots are occupied.
        assertEquals(0.5, dat.getCompactRate(), 0.0);
    }
}