HIVE UDF —— matchNWords

最近因为业务需求,需要设计一个UDF去统计两个字符串中单词的匹配次数。其中单词(word)的定义是这样的:中文以单个字作为word,英文是以空格分割连续的一串字母作为word。

一、第一步:分词

**
     * Split String to Words. 英文以一个单词作为word, 中文以单个字作为word.
     *
     * @param src
     * @return ArrayList<String>
     */
    public static ArrayList<String> splitWords(String src) {
        // Step1: Split String to Character. Noted: Chinese takes two bytes.
        char[] charArray = src.toCharArray();

        ArrayList<String> result = new ArrayList<>();

        // Step2: split Character to words.
        for (int i = 0, j = 0, e = 0; i < charArray.length; ) {
            // 如果是空格(全角以及半角)
            if (charArray[i] == '\u0020' || charArray[i] == '\u3000') {
                i++;
                j++;
                e = i;
            }
            // 如果是英文字符
            else if (charArray[i] >= 0x0000 && charArray[i] <= 0x00FF) {
                j++;
                // 数组下标越界处理
                if (j >= charArray.length) {
                    StringBuffer sb = new StringBuffer();
                    for (int t = e; t < j; t++) {
                        sb.append(charArray[t]);
                    }
                    result.add(sb.toString());
                    i++;
                    e = i;
                }
                // 如果下一位也是英文字符, 除了英文空格
                else if (charArray[j] != '\u0020' && charArray[j] >= 0x0000 && charArray[j] <= 0x00FF) {
                    i++;
                }
                // 其他情况: 中文或空格. 到此,单个英文单词分词完毕
                else {
                    StringBuffer sb = new StringBuffer();
                    for (int t = e; t < j; t++) {
                        sb.append(charArray[t]);
                    }
                    result.add(sb.toString());
                    e = i;
                    i++;
                }
            }
            // 其他情况:中文
            else {
                StringBuffer sb = new StringBuffer();
                sb.append(charArray[i]);
                result.add(sb.toString());
                i++;
                j++;
                e = i;
            }
        }
        return result;
    }

这里实现思路如下:

  1. 变量i用来记录字符串扫描位置,变量j用来记录变量i的下一位,变量e用来记录英文单词的起始位。
  2. 若i位是英文字符,则j++,判断j位的字符,如果是英文,则i++(变量i用来记录英文单词结束位),如果是其他的,则单个英文分词结束,取e到j作为单词,然后e=i,i++
  3. 若i位是中文,则直接将值存储,i++,j++,e=i
  4. 若i位是空格,不存值,i++,j++,e=i

测试如下:

    @Test
    public void testSplitWords() {
        ArrayList<String> result = UDFMatchNWords.splitWords("China 中国电力公司CN");
        for (String s : result) {
            System.out.println(s);
        }
        Assert.assertEquals(8, result.size());

        System.out.println("---------------------------");

        ArrayList<String> result2 = UDFMatchNWords.splitWords("中国\u3000China电力公司 之 上海分公司 BB");
        for (String s : result2) {
            System.out.println(s);
        }
        Assert.assertEquals(14, result2.size());
    }

二、第二步,Array转HashMap

因为对于重复单词(word),需要取最小匹配次数。因此这里直接使用Collectors.toMap方法,将word作为key,出现次数作为value,然后当key冲突时,将value相加。

参考文档:Collectors.toMap 使用技巧 (List 转 Map超方便)

    /**
     * Convert ArrayList To HashMap. The Key is word, and the value is the count of word.
     *
     * @param list
     * @return HashMap<String, Integer>
     */
    public static HashMap<String, Integer> convertListToHashMap(List<String> list) {
        HashMap<String, Integer> map = list.stream().collect(Collectors.toMap(
                String::toString, v -> 1, (v1, v2) -> (v1 + v2), HashMap::new
        ));
        return map;
    }

测试

    @Test
    public void testConvertListToHashMap() {
        List<String> list = Arrays.asList("阿", "巴", "里", "巴", "巴");
        Map<String, Integer> map = UDFMatchNWords.convertListToHashMap(list);
        map.forEach((k, v) -> System.out.println("word:" + k + ", count:" + v));
        Assert.assertEquals(new Integer(3), map.get("巴"));
        Assert.assertEquals(new Integer(1), map.get("阿"));
    }

 三、匹配

当实现了splitWords和convertListToHashMap方法后,evaluate里只需调用这两个方法,然后做匹配即可。这里直接放完整代码。

package com.scb.dss.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;

@Description(name = "matchNWords",
        value = "_FUNC_(str1, str2, n) - Return TRUE if there are n words matched between str1 and str2")
public class UDFMatchNWords extends UDF {

    /**
     * Return TRUE if there are n words matched between str1 and str2
     * @param str1
     * @param str2
     * @param n
     * @return
     */
    public boolean evaluate(String str1, String str2, Integer n) {
        // Step1: split words
        ArrayList<String> s1 = splitWords(str1);
        ArrayList<String> s2 = splitWords(str2);

        // Step2: match n words
        // 如果str1或者str2的单词小于n个, 则直接返回false
        if (s1.size() < n || s2.size() < n) {
            return false;
        }

        // convert to HashMap
        HashMap<String, Integer> map1 = convertListToHashMap(s1);
        HashMap<String, Integer> map2 = convertListToHashMap(s2);

        int matchCnt = 0;

        for (String s : map1.keySet()) {
            if (map2.containsKey(s)) {
                matchCnt += Math.min(map1.get(s), map2.get(s));
                // 短路原则,优化算法
                if (matchCnt == n) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * return the count of match words between str1 and str2
     * @param str1
     * @param str2
     * @return match words count
     */
    public Integer evaluate(String str1, String str2) {
        if (str1 == null || str2 == null) {
            return null;
        }

        // Step1: split words
        ArrayList<String> s1 = splitWords(str1);
        ArrayList<String> s2 = splitWords(str2);

        // Step2: match n words
        // convert to HashMap
        HashMap<String, Integer> map1 = convertListToHashMap(s1);
        HashMap<String, Integer> map2 = convertListToHashMap(s2);

        int matchCnt = 0;

        for (String s : map1.keySet()) {
            if (map2.containsKey(s)) {
                matchCnt += Math.min(map1.get(s), map2.get(s));
            }
        }
        return matchCnt;
    }

    /**
     * Convert ArrayList To HashMap. The Key is word, and the value is the count of word.
     *
     * @param list
     * @return HashMap<String, Integer>
     */
    public static HashMap<String, Integer> convertListToHashMap(List<String> list) {
        HashMap<String, Integer> map = list.stream().collect(Collectors.toMap(
                String::toString, v -> 1, (v1, v2) -> (v1 + v2), HashMap::new
        ));
        return map;
    }

    /**
     * Split String to Words. 英文以一个单词作为word, 中文以单个字作为word.
     *
     * @param src
     * @return ArrayList<String>
     */
    public static ArrayList<String> splitWords(String src) {
        // Step1: Split String to Character. Noted: Chinese takes two bytes.
        char[] charArray = src.toCharArray();

        ArrayList<String> result = new ArrayList<>();

        // Step2: split Character to words.
        for (int i = 0, j = 0, e = 0; i < charArray.length; ) {
            // 如果是空格(全角以及半角)
            if (charArray[i] == '\u0020' || charArray[i] == '\u3000') {
                i++;
                j++;
                e = i;
            }
            // 如果是英文字符
            else if (charArray[i] >= 0x0000 && charArray[i] <= 0x00FF) {
                j++;
                // 数组下标越界处理
                if (j >= charArray.length) {
                    StringBuffer sb = new StringBuffer();
                    for (int t = e; t < j; t++) {
                        sb.append(charArray[t]);
                    }
                    result.add(sb.toString());
                    i++;
                    e = i;
                }
                // 如果下一位也是英文字符, 除了英文空格
                else if (charArray[j] != '\u0020' && charArray[j] >= 0x0000 && charArray[j] <= 0x00FF) {
                    i++;
                }
                // 其他情况: 中文或空格. 到此,单个英文单词分词完毕
                else {
                    StringBuffer sb = new StringBuffer();
                    for (int t = e; t < j; t++) {
                        sb.append(charArray[t]);
                    }
                    result.add(sb.toString());
                    e = i;
                    i++;
                }
            }
            // 其他情况:中文
            else {
                StringBuffer sb = new StringBuffer();
                sb.append(charArray[i]);
                result.add(sb.toString());
                i++;
                j++;
                e = i;
            }
        }
        return result;
    }
}

测试类

package com.scb.dss.udf;

import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class UDFMatchNWordsTest {

    private final UDFMatchNWords matchNWords = new UDFMatchNWords();

    @Test
    public void testSplitWords() {
        ArrayList<String> result = UDFMatchNWords.splitWords("China 中国电力公司CN");
        for (String s : result) {
            System.out.println(s);
        }
        Assert.assertEquals(8, result.size());

        System.out.println("---------------------------");

        ArrayList<String> result2 = UDFMatchNWords.splitWords("中国\u3000China电力公司 之 上海分公司 BB");
        for (String s : result2) {
            System.out.println(s);
        }
        Assert.assertEquals(14, result2.size());
    }

    @Test
    public void testConvertListToHashMap() {
        List<String> list = Arrays.asList("阿", "巴", "里", "巴", "巴");
        Map<String, Integer> map = UDFMatchNWords.convertListToHashMap(list);
        map.forEach((k, v) -> System.out.println("word:" + k + ", count:" + v));
        Assert.assertEquals(new Integer(3), map.get("巴"));
        Assert.assertEquals(new Integer(1), map.get("阿"));
    }

    @Test
    public void testEvaluate() {
        // 当word匹配上时,取最小count作为匹配次数
        Assert.assertEquals(true, matchNWords.evaluate("阿里巴巴", "巴土巴士", 2));
        Assert.assertEquals(true, matchNWords.evaluate("阿里巴巴", "巴士公司", 1));

        // 如果str1或者str2的单词小于n个, 则直接返回false
        Assert.assertEquals(false, matchNWords.evaluate("阿里巴巴", "阿里巴巴有限公司", 5));

        // 复杂匹配: 含中英文
        Assert.assertEquals(true, matchNWords.evaluate("China 国网天津电力公司 Co., Ltd.", "国网电力公司 China 天津分公司 Co., Ltd.", 11));
    }

    @Test
    public void testEvaluate2() {
        Assert.assertEquals(new Integer(2), matchNWords.evaluate("hello hi world", "hi world"));
    }
}

四、发布UDF

参考上一节 Hive UDF<用户自定义函数>入门 将UDF发布到HIVE上。

select matchNWords('Hello Hi World', 'Hello World');

select matchNWords('Hello Hi World', 'Hello World', 2);

 Hive UDF 的 evaluate 是支持重载的,具体执行过程参考:HiveUDF的evaluate方法使用分析

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值