搜索排序LambdaMART中Lambda的计算过程java版本

本文详细介绍如何使用Java实现Lambda表达式计算文档排序的DCG(Discounted Cumulative Gain)和NDCG(Normalized Discounted Cumulative Gain),涉及query_dcg、ideal_dcg、query_swap_dcg和deltaNDCG等关键函数,以及lambda值的动态调整算法。
摘要由CSDN通过智能技术生成

Lambdajava实现

这里只告诉说明Lambda的计算,后面的mart大家随便用其他的都可以,这里详细写了Lambda是如何计算得来,java版本的实现。代码如下:
样本的格式如下:
在这里插入图片描述

public class LambdaCalculate {

/**
 * @param position:doc在一次query序列中的位置
 * @param lable:doc在一次query中的等级,点击?加购?下单?
 * @return :返回这个文档的dcg值
 * */
public static Double doc_dcg(Integer position,Integer lable){
    return (Math.pow(2, lable)-1)/(Math.log(position+1)/Math.log(2));
}

/**
 * @param lists:query的一个序列
 * @return :query的整体dcg,就是各个doc的dcg累加
 * */
public static Double query_dcg(List<QueryUnit<String,Integer,Integer,Double>> lists){
    Double query_dcg = 0.0D;
    for(int i=0;i<lists.size();i++){
        QueryUnit<String,Integer,Integer,Double> unit = lists.get(i);
        query_dcg +=doc_dcg(i+1, unit.getLabelValue());
    }
    return query_dcg;
}

/**
 * @param lists:query的一个序列
 * @return :计算理想序列下的DCG
 * */
public static Double ideal_dcg(List<QueryUnit<String,Integer,Integer,Double>> lists){
    List<QueryUnit<String,Integer,Integer,Double>> lists2 = new ArrayList<>();
    lists2.addAll(lists);
    Collections.sort(lists2, new Comparator<QueryUnit<String,Integer,Integer,Double>>() {
        @Override
        public int compare(QueryUnit<String,Integer,Integer,Double> o1, QueryUnit<String,Integer,Integer,Double> o2) {
            return o2.getLabelValue()-o1.getLabelValue();
        }
    });
    Double ideal_dcg = query_dcg(lists2);
    return ideal_dcg;
}

/**
 * @param lists:query的一个序列
 * @return :计算NDCG,首先要计算理想状态下dcg1,然后计算现实中的dcg2,ndcg = dcg2/dcg1
 * */
public static Double query_ndcg(List<QueryUnit<String,Integer,Integer,Double>> lists){
    //现实中的query dcg
    Double real_dcg = query_dcg(lists);
    //针对lists数据,按照lable进行排序
    Double ideal_dcg = ideal_dcg(lists);
    return real_dcg/ideal_dcg;
}

/**
 * @param lists:query的一个序列
 * @param  swapMap:交换位置的两个doc的位置数据比如(1,4)(4,1) 1与4互换位置
 * @return :返回互换后的query dcg值
 * */
public static Double query_swap_dcg(List<QueryUnit<String,Integer,Integer,Double>> lists, Map<Integer,Integer> swapMap){
    Double query_swap_dcg = 0.0D;
    for(int i=0;i<lists.size();i++){
        Integer swap_position = swapMap.get(i+1);
        if(swap_position!=null){
            Integer swap_lable = lists.get(swap_position-1).getLabelValue();
            query_swap_dcg += doc_dcg(i+1, swap_lable);
        }else{
            query_swap_dcg+=doc_dcg(i+1, lists.get(i).getLabelValue());
        }
    }
    return query_swap_dcg;
}

/**
 * @param lists:query的一个序列
 * @param swapMap:交换位置的两个doc的位置数据比如(1,4)(4,1) 1与4互换位置
 * @return :计算NDCG,首先要计算理想状态下dcg1,然后计算现实中的dcg2,ndcg = dcg2/dcg1
 * */
public static Double query_swap_ndcg(List<QueryUnit<String,Integer,Integer,Double>> lists,Map<Integer,Integer> swapMap){
    //交换位置后的现实dcg
    Double real_swap_dcg = query_swap_dcg(lists, swapMap);
    //理想状态下的dcg
    Double ideal_dcg = ideal_dcg(lists);
    return real_swap_dcg/ideal_dcg;
}

/**
 * @param lists:query的一个序列
 * @param swapMap:交换位置的两个doc的位置数据比如(1,4)(4,1) 1与4互换位置
 * @return :返回deltaNDCG
 * */
public static Double deltaNDCG(List<QueryUnit<String,Integer,Integer,Double>> lists,Map<Integer,Integer> swapMap){

    Double query_swap_ndcg = query_swap_ndcg(lists, swapMap);

    Double query_ndcg = query_ndcg(lists);

    return Math.abs(query_ndcg-query_swap_ndcg);
}

/**
 * @param si:doci预测得分,一般在第一次模型没有的时候,都是0
 * @param sj:docj预测得分,一般在第一次模型没有的时候,都是0
 * @param sigma:这个值只是影响曲线的陡峭度,默认这里我选1
 * @return :返回值betaij ,表示doci比docj差的概率
 * */
public static Double betaij(Integer sigma,Double si,Double sj){
    if(sigma == null){
        sigma = 1;
    }
    return 1/(1+Math.pow(Math.E, sigma*(si-sj)));
}

/**
 * @param lists:query的一个序列
 * @param currentIndex:当前要交换位置的那个doc的位置
 * @param swapMap: 交换位置的两个doc的位置数据比如(1,4)(4,1) 1与4互换位置
 * @param sigma:这个值只是影响曲线的陡峭度,默认这里我选1
 * @return :返回该doc交换一次后的lambda值.
 * */
public static Double lambdaij(List<QueryUnit<String,Integer,Integer,Double>> lists,
                              Integer currentIndex,Map<Integer,Integer> swapMap,Integer sigma){
    Integer targetIndex = swapMap.get(currentIndex);
    QueryUnit<String,Integer,Integer,Double> currentUnit = lists.get(currentIndex-1);
    QueryUnit<String,Integer,Integer,Double> targetUnit = lists.get(targetIndex-1);
    Integer currentLable = currentUnit.getLabelValue();
    Integer targetLable = targetUnit.getLabelValue();
    Double lambdaij = 0.0D;
    if(currentLable<=targetLable){
        lambdaij = -betaij(sigma,currentUnit.getDocscore(),targetUnit.getDocscore())*deltaNDCG(lists, swapMap);
    }else{
        lambdaij = betaij(sigma,currentUnit.getDocscore(),targetUnit.getDocscore())*deltaNDCG(lists, swapMap);
    }
    return lambdaij;
}

/**
 * @param lists:query的一个序列
 * @param currentIndex:当前要交换位置的那个doc的位置
 * @param sigma:这个值只是影响曲线的陡峭度,默认这里我选1
 * @return :返回一个doc与各个位置都交互完后的lambda
 * */
public static Double doc_lambdaij(List<QueryUnit<String,Integer,Integer,Double>> lists,
                                  Integer currentIndex,Integer sigma){
    Double doc_lambdaij = 0.0D;
    for(int i=0;i<lists.size();i++){
        Integer iindex = i+1;
        QueryUnit<String,Integer,Integer,Double> currentUint = lists.get(currentIndex-1);
        QueryUnit<String,Integer,Integer,Double> targetUint = lists.get(i);
        Integer currentLable = currentUint.getLabelValue();
        Integer tergetLable = targetUint.getLabelValue();
        if(currentLable ==tergetLable){
            doc_lambdaij+=0.0D;
        }else{
            Map<Integer,Integer> swapMap = new HashMap<>();
            swapMap.put(currentIndex, iindex);
            swapMap.put(iindex, currentIndex);
            doc_lambdaij += lambdaij(lists, currentIndex, swapMap, sigma);
        }
    }
    return doc_lambdaij;
}

/**
 * @param lists:query的一个序列,序列的单元是QueryUnit,第一个是query_id,第二是:doc位置,从1开始.第三个是:当前doc的lable。第四个是:当前文档分数
 * @param sigma: 这个值只是影响曲线的陡峭度,默认这里我选1
 * @return :返回了一个query下整个序列的lambda
 * */
public static Map<Integer,Double> query_lambdaij(List<QueryUnit<String,Integer,Integer,Double>> lists,Integer sigma){
    Map<Integer,Double> query_lambdaMap = new HashMap<>();
    for(int i=0;i<lists.size();i++){
        QueryUnit<String,Integer,Integer,Double> unit = lists.get(i);
        Integer position = unit.getPosition();
        Double doc_lambda =  doc_lambdaij(lists, i+1, sigma);
        query_lambdaMap.put(position, doc_lambda);
    }
    return query_lambdaMap;
}

/**
 * @param dataList:样本训练集的数据集
 * @param querySize:样本query一个批次是多少条样本
 * @return :返回更新了lambda值后的list
 * */
public static List<List<String>> updateLambda(List<List<String>> dataList,Integer querySize){
    Integer rowLength = dataList.get(0).size();
    List<QueryUnit<String,Integer,Integer,Double>> queryList = new ArrayList<>();
    List<Integer> idList = new ArrayList<>();
    for(int i=0;i<dataList.size();i++){
        List<String> row = dataList.get(i);
        idList.add(i);
        Integer id = Integer.parseInt(row.get(0));
        QueryUnit<String,Integer,Integer,Double> unit = new QueryUnit<>();
        unit.setQueryId(row.get(2));
        Integer position = id%10==0?10:id%10;
        unit.setPosition(position);
        unit.setLabelValue(Integer.parseInt(row.get(1)));
        unit.setDocscore(Double.parseDouble(row.get(rowLength-2)));
        queryList.add(unit);
        if(id%querySize==0){//每到一个批次结束的时候,就开始进行计算lambda,并进行lambda更新
            Map<Integer,Double> ndcg = LambdaCalculate.query_lambdaij(queryList, 1);
            for(int j=0;j<ndcg.size();j++){
                Integer idIndex = idList.get(j);
                dataList.get(idIndex).set(rowLength-1,ndcg.get(j+1).toString());
            }
            idList.clear();
            queryList.clear();
        }

    }
    return dataList;
}

/**
 * @param dataList:传入的是训练数据集,特征列数15列,0是样本ID,1是样本标签 2是样本query_id,3-12是特征,13是样本得分 14存lambda值
 * @param learningRate:学习率
 * @param node:传入叶子节点
 * @return :返回数据训练集,并把数据集的score列进行更新完毕。
 * */
public static List<List<String>> updateScore(TreeNode node,List<List<String>> dataList,Double learningRate){
    List<LeafUnit<Integer,Double>> valueList = node.getList();
    Integer rowLength = dataList.get(0).size();
    Double sum =0.0D;
    Integer count = 0;
    for(LeafUnit<Integer,Double> unit:valueList){
        sum += unit.getValue();
        count++;
    }
    Double gama = sum/count;
    Double increment = gama*learningRate;
    for(LeafUnit<Integer,Double> unit:valueList){
        //下面这句话的意思,就是我们把新创建的这颗树叶子节点的lambda均值乘以一个学习率后,加到上一颗树的得分上,第一棵树上一个树是0,所以他们的得分值都是0
        Double newValue = Double.parseDouble(dataList.get(unit.getIndex()).get(rowLength-2))+increment;
        dataList.get(unit.getIndex()).set(rowLength-2,newValue.toString());
    }
    return dataList;
}

}

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值