DAT的实现

手痒,自己实现了一下,UT已经通过。在lucene4基础上实现,加上接口不到300行代码。
package com.dp.junhao.jhsegmenter;

import gnu.trove.iterator.TByteIterator;
import gnu.trove.list.array.TByteArrayList;
import gnu.trove.procedure.TByteProcedure;
import gnu.trove.set.TByteSet;
import gnu.trove.set.hash.TByteHashSet;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.UnicodeUtil;

import java.nio.charset.Charset;
import java.util.*;


/**
 * Created by junhao.zhang on 17/7/27.
 */
public class DAT {
    private final int[] base;
    private final int[] check;

    // base数组默认填为0,check数组默认填为-1
    // 看节点是否存在,只看check[i]==-1
    // 是否是tail节点,看check[i]最高位是否为0
    public DAT(int[] base, int[] check) {
        this.base = base;
        this.check = check;
    }

    // for test.
    int[] getBaseArray() { return base; }

    // for test.
    int[] getCheckArray() { return check; }

    public boolean containsNode(String term) {
        BytesRef bytes = new BytesRef();
        UnicodeUtil.UTF16toUTF8(term, 0, term.length(), bytes);
        return containsNode(bytes);
    }

    public boolean containsNode(byte[] bytes, int offset, int length) {
        int node = 0, parent;
        for (int i = offset; i < offset + length; ++i) {
            parent = node;
            node = base[node] + bytes[i];
            if (node >= check.length || (check[node] & 0x7fffffff) != parent) {
                return false;
            }
        }
        return true;
    }

    public boolean containsNode(BytesRef bytes) {
        return containsNode(bytes.bytes, bytes.offset, bytes.length);
    }

    public boolean containsTerm(String term) {
        BytesRef bytes = new BytesRef();
        UnicodeUtil.UTF16toUTF8(term, 0, term.length(), bytes);
        return containsTerm(bytes);
    }

    public boolean containsTerm(BytesRef bytes) {
        return containsTerm(bytes.bytes, bytes.offset, bytes.length);
    }

    public boolean containsTerm(byte[] bytes, int offset, int length) {
        int node = 0, parent;
        for (int i = offset; i < offset + length; ++i) {
            parent = node;
            node = base[node] + bytes[i];
            if (node >= check.length || (check[node] & 0x7fffffff) != parent) {
                return false;
            }
        }
        return check[node] >= 0;
    }

    public double getCompactRate() {
        int count = 0;
        for (int e : check) {
            if (e != -1) {
                ++count;
            }
        }
        return (double) count / check.length;
    }

    public static class Builder {
        private static final BytesRef EMPTY_BYTESREF = new BytesRef(BytesRef.EMPTY_BYTES, 0, 0);
        private static final int INITIAL_CAPACITY = 32;
        private static final int[] table2pow = new int[]{0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
                0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000,
                0x10000, 0x20000, 0x40000, 0x80000, 0x100000, 0x200000, 0x400000, 0x800000,
                0x1000000, 0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000, 0x40000000};  // 最高位留作词标志位.

        private TreeSet<BytesRef> terms = new TreeSet<BytesRef>();

        public Builder addTerm(String term) {
            if (term.isEmpty()) {
                return this;
            }

            BytesRef bytesRef = new BytesRef();
            UnicodeUtil.UTF16toUTF8(term, 0, term.length(), bytesRef);
            terms.add(bytesRef);
            return this;
        }

        private static BytesRef nextBytesRef(BytesRef bytes) {
            BytesRef newBytes = new BytesRef();
            newBytes.copyBytes(bytes);
            int advance = 1;
            for (int i = newBytes.length - 1; i >= 0; --i) {
                if ((newBytes.bytes[i] & 0xff) + advance <= 255) {
                    newBytes.bytes[i] += advance;
                    break;
                }
                newBytes.bytes[i] = 0;
                advance = 1;
            }
            return newBytes;
        }

        private static boolean isTopmost(BytesRef bytes) {
            for (int i = 0; i < bytes.length; ++i) {
                if (bytes.bytes[bytes.offset + i] < 255) {
                    return false;
                }
            }
            return true;
        }

        private static int getBestBaseValue(AutoExpandIntArray check, int parentNode, TByteArrayList list) {
            for (int i = parentNode + 1 - list.get(0); ;++i) {
                TByteIterator itr = list.iterator();
                int countdown = list.size();
                while (itr.hasNext()) {
                    if (check.get(i + itr.next(), -1) != -1) {
                        break;
                    }
                    --countdown;
                }
                if (countdown == 0) {
                    return i;
                }
            }
        }

        private static int getProper2PowNum(int minSize) {
            if (minSize > table2pow[table2pow.length - 1]) {
                throw new IllegalArgumentException(String.format("Array size: %d too large", minSize));
            }

            int low = 0, high = table2pow.length - 1;
            while (low <= high) {
                int mid = (low + high) / 2;
                if (minSize == table2pow[mid]) {
                    return table2pow[mid];
                } else if (minSize < table2pow[mid]) {
                    high = mid - 1;
                } else {
                    low = mid + 1;
                }
            }
            return table2pow[low];
        }

        private SortedSet<BytesRef> getTermsWithPrefix(BytesRef bytes) {
            if (isTopmost(bytes)) {
                return terms.tailSet(bytes, false);
            }

            return terms.subSet(bytes, false, nextBytesRef(bytes), false);
        }

        private static class Entry {
            BytesRef bytes;
            int node;

            Entry(BytesRef bytes, int node) {
                this.bytes = bytes;
                this.node = node;
            }
        }

        private static class AutoExpandIntArray {
            int[] data;

            AutoExpandIntArray(int initialSize, int fillValue) {
                data = new int[initialSize];
                if (fillValue != 0) {
                    Arrays.fill(data, fillValue);
                }
            }

            void expand(int minSize, int fillValue) {
                int[] newData = new int[getProper2PowNum(minSize)];
                System.arraycopy(data, 0, newData, 0, data.length);
                if (fillValue != 0) {
                    Arrays.fill(newData, data.length, newData.length, fillValue);
                }
                data = newData;
            }

            int set(int pos, int value, int fillValue) {
                if (pos >= data.length) {
                    expand(pos + 1, fillValue);
                }
                return data[pos] = value;
            }

            int get(int pos, int defaultValue) {
                if (pos >= data.length) {
                    expand(pos + 1, defaultValue);
                }
                return data[pos];
            }
        }

        public DAT build() {
            // BFS to iterate all nodes.
            final Queue<Entry> queue = new LinkedList<Entry>();
            final AutoExpandIntArray base = new AutoExpandIntArray(INITIAL_CAPACITY, 0);  // TODO: calc proper capacity.
            final AutoExpandIntArray check = new AutoExpandIntArray(INITIAL_CAPACITY, -1);

            queue.add(new Entry(EMPTY_BYTESREF, 0));
            while (!queue.isEmpty()) {
                final Entry elem = queue.poll();
                SortedSet<BytesRef> tailSet = getTermsWithPrefix(elem.bytes);
                final int length = elem.bytes.length;
                Iterator<BytesRef> itr = tailSet.iterator();
                TByteArrayList siblings = new TByteArrayList();
                final TByteSet termSet = new TByteHashSet();
                while (itr.hasNext()) {
                    BytesRef term = itr.next();
                    byte b = term.bytes[length];
                    if (term.length == length + 1) {
                        termSet.add(b);
                    }
                    if (!siblings.isEmpty() && siblings.get(siblings.size() - 1) == b) {
                        continue;
                    }
                    siblings.add(b);
                }
                if (siblings.isEmpty()) {
                    continue;
                }
                final int baseValue = base.set(elem.node, getBestBaseValue(check, elem.node, siblings), 0);
                final int parentNode = elem.node;
                siblings.forEach(new TByteProcedure() {
                    @Override
                    public boolean execute(byte b) {
                        byte[] newBytes = new byte[length + 1];
                        System.arraycopy(elem.bytes.bytes, elem.bytes.offset, newBytes, 0, length);
                        newBytes[length] = b;
                        queue.add(new Entry(new BytesRef(newBytes), baseValue + b));
                        if (termSet.contains(b)) {
                            check.set(baseValue + b, parentNode, -1);
                        } else {
                            check.set(baseValue + b, (parentNode | 0x80000000), -1);
                        }
                        return true;
                    }
                });
            }

            return new DAT(base.data, check.data);
        }
    }
}


UT部分:

package com.dp.junhao.jhsegmenter;

import org.apache.lucene.util.CharsRef;
import org.apache.lucene.util.UnicodeUtil;
import org.testng.annotations.Test;

import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

/**
 * Created by junhao.zhang on 17/8/10.
 */
public class DATTest {
    @Test
    public void testSimpleDAT() {
        DAT.Builder builder = new DAT.Builder();
        builder.addTerm("ABC").addTerm("AC").addTerm("ACE")
                .addTerm("ACFF").addTerm("AD").addTerm("BBC").addTerm("CD").addTerm("CF").addTerm("ZQ");
        DAT dat = builder.build();
        assertTrue(dat.containsTerm("ABC"));
        assertTrue(dat.containsTerm("AC"));
        assertTrue(dat.containsTerm("ACE"));
        assertTrue(dat.containsTerm("ACFF"));
        assertTrue(dat.containsTerm("AD"));
        assertTrue(dat.containsTerm("BBC"));
        assertTrue(dat.containsTerm("CD"));
        assertTrue(dat.containsTerm("CF"));
        assertTrue(dat.containsTerm("ZQ"));
        assertFalse(dat.containsTerm("BB"));
    }

    @Test
    public void testSimpleDATWhiteBox() {
        DAT.Builder builder = new DAT.Builder();
        builder.addTerm("ABC").addTerm("AC").addTerm("ACE")
                .addTerm("ACFF").addTerm("AD").addTerm("BBC").addTerm("CD").addTerm("CF").addTerm("ZQ");
        DAT dat = builder.build();

        int[] base = dat.getBaseArray();
        int[] check = dat.getCheckArray();

        assertEquals(base.length, 32);
        assertEquals(check.length, 32);

        assertEquals(base[0], -64);
        assertEquals(base[1], -62);
        assertEquals(base[2], -59);
        assertEquals(base[3], -60);
        assertEquals(base[4], -58);
        assertEquals(base[5], -58);
        assertEquals(base[6], 0);
        assertEquals(base[7], -54);
        assertEquals(base[8], 0);
        assertEquals(base[9], 0);
        assertEquals(base[10], 0);
        assertEquals(base[11], 0);
        assertEquals(base[12], -56);
        assertEquals(base[13], 0);
        assertEquals(base[14], 0);
        assertEquals(base[15], 0);
        assertEquals(base[16], 0);
        assertEquals(base[17], 0);
        assertEquals(base[18], 0);
        assertEquals(base[19], 0);
        assertEquals(base[20], 0);
        assertEquals(base[21], 0);
        assertEquals(base[22], 0);
        assertEquals(base[23], 0);
        assertEquals(base[24], 0);
        assertEquals(base[25], 0);
        assertEquals(base[26], -54);
        assertEquals(base[27], 0);
        assertEquals(base[28], 0);
        assertEquals(base[29], 0);
        assertEquals(base[30], 0);
        assertEquals(base[31], 0);

        assertEquals(check[0], -1);
        assertEquals(check[1], (0 | 0x80000000));
        assertEquals(check[2], (0 | 0x80000000));
        assertEquals(check[3], (0 | 0x80000000));
        assertEquals(check[4], (1 | 0x80000000));
        assertEquals(check[5], 1);  // ABC
        assertEquals(check[6], 1);  // AD
        assertEquals(check[7], (2 | 0x80000000));
        assertEquals(check[8], 3);  // CD
        assertEquals(check[9], 4);  // ABC
        assertEquals(check[10], 3);  // CF
        assertEquals(check[11], 5);  // ACE
        assertEquals(check[12], (5 | 0x80000000));
        assertEquals(check[13], 7);  // BBC
        assertEquals(check[14], 12);  // ACFF
        assertEquals(check[15], -1);
        assertEquals(check[16], -1);
        assertEquals(check[17], -1);
        assertEquals(check[18], -1);
        assertEquals(check[19], -1);
        assertEquals(check[20], -1);
        assertEquals(check[21], -1);
        assertEquals(check[22], -1);
        assertEquals(check[23], -1);
        assertEquals(check[24], -1);
        assertEquals(check[25], -1);
        assertEquals(check[26], (0 | 0x80000000));
        assertEquals(check[27], 26);  // ZQ
        assertEquals(check[28], -1);
        assertEquals(check[29], -1);
        assertEquals(check[30], -1);
        assertEquals(check[31], -1);

        assertEquals(dat.getCompactRate(), 0.5);
    }
}


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理 dat批处理
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值