LibRec Study Notes (9): How to Implement Your Own Recommendation Algorithm on Top of the LibRec Library

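In my view, LibRec, the open-source recommender-system library from Prof. Guibing Guo of Northeastern University (China), has been a great help to academic beginners like me. It ships implementations of many methods from published papers and makes it much easier to reproduce top-conference papers ourselves. So what should a beginner like me do when I want to write my own recommendation algorithm? The answer is that we can build directly on this open-source library!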

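Below, based on my own limited understanding and the official documentation, I give a guide to implementing your own recommendation algorithm with LibRec.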

I. The recommendation workflow in LibRec

Before explaining how to implement your own recommendation algorithm, I want to introduce the overall recommendation workflow in LibRec:

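  • First, we need the data the algorithm requires, such as user behavior data (most commonly user-item-rating data, or implicit user feedback) and side information (social network data, geographic location data, item content data, and so on).
  • Next, we process the data: converting the data format (for example text or arff), splitting the dataset as required (by ratio, leave-one-out, k-fold, or with a pre-specified test set and training set), handling side information, and so on.
  • Only then can we train the recommendation algorithm on the training set and evaluate the result on the test set to judge how good the algorithm is. Algorithms that rely on similarity also need the similarity module (more than ten measures, such as Euclidean distance).
  • Finally, the results are filtered if needed and then saved.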
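To make the data-splitting step concrete, here is a minimal sketch of a ratio split over user-item-rating records. It is not LibRec's splitter: the class RatioSplitSketch, the Rating record, and the seed parameter are all names I made up for illustration.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper, not part of LibRec: splits rating records by ratio. */
class RatioSplitSketch {

    /** One (user, item, rating) record, the most common input format. */
    static final class Rating {
        final String user;
        final String item;
        final double value;

        Rating(String user, String item, double value) {
            this.user = user;
            this.item = item;
            this.value = value;
        }
    }

    /**
     * Shuffles a copy of the records with a fixed seed and returns [train, test],
     * where train holds round(trainRatio * n) records.
     */
    static List<List<Rating>> split(List<Rating> ratings, double trainRatio, long seed) {
        if (trainRatio < 0.0 || trainRatio > 1.0) {
            throw new IllegalArgumentException("trainRatio must be in [0, 1]");
        }
        List<Rating> shuffled = new ArrayList<>(ratings);
        Collections.shuffle(shuffled, new Random(seed));
        int cut = (int) Math.round(trainRatio * shuffled.size());
        List<List<Rating>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, cut)));
        parts.add(new ArrayList<>(shuffled.subList(cut, shuffled.size())));
        return parts;
    }

    public static void main(String[] args) {
        List<Rating> ratings = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            ratings.add(new Rating("u" + (i % 3), "i" + i, 1 + i % 5));
        }
        List<List<Rating>> parts = split(ratings, 0.8, 42L);
        System.out.println("train=" + parts.get(0).size() + ", test=" + parts.get(1).size());
    }
}
```

Fixing the seed keeps the split reproducible between runs, which matters when comparing algorithms on the same data. In LibRec itself this step is driven by configuration items rather than hand-written code.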

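The key step of the whole workflow is training, that is, which recommendation algorithm is used. The other steps are identical across all algorithms, so we do not need to reinvent those wheels: we just set the corresponding configuration items and use them as they are.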

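So implementing a recommendation algorithm mostly comes down to writing its training logic, i.e. the trainModel() method that train() calls. Convenient, isn't it!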
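The idea that the workflow is fixed and only the training part changes is the classic template-method pattern. The sketch below shows that shape with a toy framework of my own; MiniRecommender and GlobalMeanRecommender are hypothetical names and are not LibRec classes.

```java
/** Hypothetical mini framework that mirrors the shape of a train() template method. */
abstract class MiniRecommender {
    private final StringBuilder callOrder = new StringBuilder();

    /** Fixed workflow: subclasses do not override this method. */
    public final void train() {
        setup();
        callOrder.append("setup;");
        trainModel();
        callOrder.append("train;");
        cleanup();
        callOrder.append("cleanup;");
    }

    /** Shared initialisation; overriding is optional. */
    protected void setup() { }

    /** The only method a subclass must write. */
    protected abstract void trainModel();

    /** Optional hook that runs after training. */
    protected void cleanup() { }

    /** Returns the order in which the steps ran, for inspection. */
    public String callOrder() {
        return callOrder.toString();
    }
}

/** Toy algorithm: "training" just computes the mean of the observed ratings. */
class GlobalMeanRecommender extends MiniRecommender {
    private final double[] ratings;
    double mean;

    GlobalMeanRecommender(double[] ratings) {
        this.ratings = ratings;
    }

    @Override
    protected void trainModel() {
        double sum = 0.0;
        for (double r : ratings) {
            sum += r;
        }
        mean = ratings.length == 0 ? 0.0 : sum / ratings.length;
    }

    public static void main(String[] args) {
        GlobalMeanRecommender rec = new GlobalMeanRecommender(new double[]{1, 2, 3, 4});
        rec.train();
        System.out.println(rec.callOrder() + " mean=" + rec.mean);
    }
}
```

The subclass never touches the workflow itself; it only fills in the one abstract step, which is why a new algorithm needs so little code.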

II. The 7 abstract classes in LibRec

LibRec provides 7 abstract base classes that we can inherit from to implement different types of recommendation algorithms:

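  • Abstract Recommender: the abstract recommender
  • Matrix Recommender: the matrix-based abstract recommender
  • Matrix Probabilistic Graphical Recommender: the abstract recommender for matrix-based probabilistic graphical models
  • Matrix Factorization Recommender: the abstract recommender for matrix factorization
  • Factorization Machine Recommender: the abstract recommender for factorization machines
  • Social Recommender: the abstract social recommender
  • Tensor Recommender: the abstract tensor recommender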

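More than 70 recommendation algorithms in LibRec are currently built on these base classes. They fall into 7 categories: baseline algorithms, collaborative filtering algorithms, content-based algorithms, context-aware algorithms, deep learning algorithms, hybrid algorithms, and other extended algorithms.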

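So if you want to implement your own algorithm in LibRec, first inherit from the abstract class that matches your algorithm's category and implement the abstract methods it requires. You can also override other methods of the abstract class as needed.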

III. Implementing your own recommendation algorithm

Take inheriting from the AbstractRecommender abstract class as an example. Here is the code of that class (a quick skim is enough for now):

package net.librec.recommender;
import com.google.common.collect.BiMap;
import net.librec.common.LibrecException;
import net.librec.conf.Configuration;
import net.librec.data.DataModel;
import net.librec.job.progress.ProgressBar;
import net.librec.recommender.item.*;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Abstract Recommender Methods
 *
 * @author WangYuFeng and Wang Keqiang
 */
public abstract class AbstractRecommender implements Recommender {
    /**
     * LOG
     */
    protected final Log LOG = LogFactory.getLog(this.getClass());

    /**
     * is ranking or rating
     */
    protected boolean isRanking;

    /**
     * topN
     */
    protected int topN;

    /**
     * conf
     */
    protected Configuration conf;

    /**
     * RecommenderContext
     */
    protected RecommenderContext context;

    /**
     * early-stop criteria
     */
    protected boolean earlyStop;

    /**
     * verbose
     */
    protected static boolean verbose = true;

    /**
     * objective loss
     */
    protected double loss, lastLoss = 0.0d;

    /**
     * whether to adjust learning rate automatically
     */
    protected boolean isBoldDriver;

    /**
     * decay of learning rate
     */
    protected float decay;

    /**
     * report the training progress
     */
    protected ProgressBar progressBar;

    /**
     * user Mapping Data
     */
    public BiMap<String, Integer> userMappingData;

    /**
     * item Mapping Data
     */
    public BiMap<String, Integer> itemMappingData;

    /**
     * setup
     *
     * @throws LibrecException if error occurs during setup
     */
    protected void setup() throws LibrecException {
        conf = context.getConf(); // load all configuration items from the RecommenderContext into conf; they come from
        // librec-default.properties and the algorithm-specific file, e.g. sbpr-test.properties
        isRanking = conf.getBoolean("rec.recommender.isranking"); // whether this is a ranking task; top-N tasks set this item
        if (isRanking) {
            topN = conf.getInt("rec.recommender.ranking.topn", 10); // top-N size, default 10
            if (this.topN <= 0) {
                throw new IndexOutOfBoundsException("rec.recommender.ranking.topn should be more than 0!");
            }
        }
        earlyStop = conf.getBoolean("rec.recommender.earlystop", false); // whether to use early stopping, default false
        verbose = conf.getBoolean("rec.recommender.verbose", true); // whether to print progress information to the console, default true

        userMappingData = getDataModel().getUserMappingData(); // mapping from original user IDs to internal user indices
        itemMappingData = getDataModel().getItemMappingData(); // mapping from original item IDs to internal item indices

        if (verbose) { // when output is enabled, create the progress bar
            progressBar = new ProgressBar(100, 100);
        }
    }

    /**
     * train Model
     *
     * @throws LibrecException if error occurs during training model
     */
    protected abstract void trainModel() throws LibrecException;

    /**
     * recommend
     *
     * @param context recommender context
     * @throws LibrecException if error occurs during recommending
     */
    public void train(RecommenderContext context) throws LibrecException {
        this.context = context;
        setup(); // mostly reads configuration items
        LOG.info("Job Setup completed.");
        trainModel(); // calls the training method of the concrete recommender
        LOG.info("Job Train completed.");
        cleanup();
    }

    /**
     * cleanup
     *
     * @throws LibrecException if error occurs during cleanup
     */
    protected void cleanup() throws LibrecException {

    }

    /**
     * (non-Javadoc)
     *
     * @see net.librec.recommender.Recommender#loadModel(String)
     */
    @Override
    public void loadModel(String filePath) {

    }

    /**
     * (non-Javadoc)
     *
     * @see net.librec.recommender.Recommender#saveModel(String)
     */
    @Override
    public void saveModel(String filePath) {

    }

    /**
     * get Context
     *
     * @return recommender context
     */
    protected RecommenderContext getContext() {
        return context;
    }

    /**
     * set Context
     *
     * @param context recommender context
     */
    public void setContext(RecommenderContext context) {
        this.context = context;
    }

    /**
     * get Data Model
     *
     * @return data model
     */
    public DataModel getDataModel() {
        return context.getDataModel();
    }

    /**
     * get Recommended List
     *
     * @return Recommended List
     */
    // returns the recommendation results, with internal indices mapped back to the original user and item IDs
    public List<RecommendedItem> getRecommendedList(RecommendedList recommendedList) {

        if (recommendedList != null && recommendedList.size() > 0) {
            List<RecommendedItem> userItemList = new ArrayList<>();
            Iterator<ContextKeyValueEntry> recommendedEntryIter = recommendedList.iterator();
            if (userMappingData != null && userMappingData.size() > 0 && itemMappingData != null && itemMappingData.size() > 0) {
                BiMap<Integer, String> userMappingInverse = userMappingData.inverse();
                BiMap<Integer, String> itemMappingInverse = itemMappingData.inverse();
                while (recommendedEntryIter.hasNext()) {
                    ContextKeyValueEntry contextKeyValueEntry = recommendedEntryIter.next();
                    if (contextKeyValueEntry != null) {
                        String userId = userMappingInverse.get(contextKeyValueEntry.getContextIdx());
                        String itemId = itemMappingInverse.get(contextKeyValueEntry.getKey());
                        if (StringUtils.isNotBlank(userId) && StringUtils.isNotBlank(itemId)) {
                            userItemList.add(new GenericRecommendedItem(userId, itemId, contextKeyValueEntry.getValue()));
                        }
                    }
                }
                return userItemList;
            }
        }
        return null;
    }

    /**
     * Post each iteration, we do things:
     * <ol>
     * <li>print debug information</li>
     * <li>check if converged</li>
     * <li>if not, adjust learning rate</li>
     * </ol>
     *
     * @param iter current iteration
     * @return boolean: true if it is converged; false otherwise
     * @throws LibrecException if error occurs
     */
    protected boolean isConverged(int iter) throws LibrecException {
        float delta_loss = (float) (lastLoss - loss);
        // print the loss information when verbose is true
        if (verbose) {
            String recName = getClass().getSimpleName();
            String info = recName + " iter " + iter + ": loss = " + loss + ", delta_loss = " + delta_loss;
            LOG.info(info);
        }
        // check for an abnormal loss value
        if (Double.isNaN(loss) || Double.isInfinite(loss)) {
        	//LOG.error("Loss = NaN or Infinity: current settings does not fit the recommender! Change the settings and try again!");
            throw new LibrecException("Loss = NaN or Infinity: current settings does not fit the recommender! Change the settings and try again!");
        }
        // check for convergence
        return Math.abs(delta_loss) < 1e-5;
    }

    public void updateProgress(int currentPoint) {
        if (verbose) {
            conf.setInt("train.current.progress", currentPoint);
            progressBar.showBarByPoint(conf.getInt("train.current.progress"));
        }
    }
}
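The two mapping fields are easier to picture with a small round trip. In the listing above, getRecommendedList calls inverse() on the BiMap to turn internal indices back into the original string IDs. The sketch below imitates that with two plain HashMaps so it runs without Guava. IdMappingSketch is a hypothetical class, and handing out indices in order of first appearance is my assumption, not something the listing shows.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical stand-in for a BiMap<String, Integer> mapping such as userMappingData. */
class IdMappingSketch {
    private final Map<String, Integer> idToIndex = new HashMap<>();
    private final Map<Integer, String> indexToId = new HashMap<>();

    /** Returns the internal index of a raw ID, assigning the next free index on first sight (assumption). */
    int indexOf(String rawId) {
        Integer idx = idToIndex.get(rawId);
        if (idx == null) {
            idx = idToIndex.size();
            idToIndex.put(rawId, idx);
            indexToId.put(idx, rawId);
        }
        return idx;
    }

    /** Inverse lookup, the role played by inverse() in getRecommendedList; null if the index is unknown. */
    String rawIdOf(int index) {
        return indexToId.get(index);
    }

    public static void main(String[] args) {
        IdMappingSketch users = new IdMappingSketch();
        int alice = users.indexOf("alice");
        int bob = users.indexOf("bob");
        System.out.println(alice + " -> " + users.rawIdOf(alice));
        System.out.println(bob + " -> " + users.rawIdOf(bob));
    }
}
```

Working with dense integer indices internally is what lets the algorithms use matrix rows and columns directly, while the inverse lookup keeps the final output readable.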

If our recommendation algorithm inherits from this AbstractRecommender abstract class, the rough procedure for implementing it is as follows:

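  1. Override the setup method
    The setup method mainly initializes the algorithm's member variables; for example, reading parameters from the configuration file belongs here. For details, see the setup method discussed in an earlier blog post. This step is optional, but if you do override setup, call the base-class version first with super.setup() on the first line so that the basic parameters are initialized.

  2. Implement the trainModel method
    The trainModel method trains the model, for example by minimizing the loss function with gradient descent. This is where we write the model itself! In the base class AbstractRecommender this method is declared abstract, so every subclass must provide its own implementation.

  3. Implement the predict method
    The predict method uses the trained model to make predictions. For a rating prediction algorithm, for example, it has to predict every rating in the test set: given a user index and an item index, it uses the model to predict the rating between them. In the USG code below, this role is played by predictScore together with recommendRank.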
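The three steps can be sketched with a deliberately tiny model: a single global bias learned by gradient descent, with a stopping test modeled on isConverged from the listing above. Everything here is hypothetical (the class GlobalBiasSketch, its hard-coded hyper-parameters, the averaged gradient), and it does not extend the real AbstractRecommender so that it stays runnable on its own.

```java
/** Hypothetical toy model, not LibRec code: learns one global bias by gradient descent. */
class GlobalBiasSketch {
    private final double[] ratings; // observed training ratings
    private double learnRate;
    private int maxIterations;
    private double mu;              // the single model parameter
    private double loss;
    private double lastLoss;

    GlobalBiasSketch(double[] ratings) {
        this.ratings = ratings;
    }

    /** Step 1: initialise member variables (a real recommender would read them from conf). */
    void setup() {
        learnRate = 0.2;
        maxIterations = 1000;
        mu = 0.0;
        loss = 0.0;
        lastLoss = 0.0;
    }

    /** Step 2: minimise 0.5 * sum((mu - r)^2) and stop once the loss barely changes. */
    void trainModel() {
        if (ratings.length == 0) {
            return;
        }
        for (int iter = 1; iter <= maxIterations; iter++) {
            loss = 0.0;
            double gradient = 0.0;
            for (double r : ratings) {
                double error = mu - r;
                loss += 0.5 * error * error;
                gradient += error;
            }
            // averaged gradient, so the step size does not depend on the number of ratings
            mu -= learnRate * gradient / ratings.length;
            if (isConverged()) {
                break;
            }
            lastLoss = loss;
        }
    }

    /** Same test as in the listing above: fail on NaN or infinity, converge when |delta loss| < 1e-5. */
    boolean isConverged() {
        if (Double.isNaN(loss) || Double.isInfinite(loss)) {
            throw new IllegalStateException("loss is NaN or infinite; try a smaller learning rate");
        }
        return Math.abs(lastLoss - loss) < 1e-5;
    }

    /** Step 3: predict the rating for a (user index, item index) pair; this toy model ignores both. */
    double predict(int userIdx, int itemIdx) {
        return mu;
    }

    public static void main(String[] args) {
        GlobalBiasSketch model = new GlobalBiasSketch(new double[]{1, 2, 3, 4});
        model.setup();
        model.trainModel();
        System.out.println("predicted rating: " + model.predict(0, 0));
    }
}
```

For this loss the optimum is simply the mean rating, so it is easy to check that training ends close to it. A real algorithm would replace mu with its own parameters and keep the same three-method structure.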

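Yes, it is as simple as it looks: writing those three methods is enough. Below is a recommender that directly extends AbstractRecommender, the USG algorithm. (The code is a bit long and has few comments, so there is no need to read it closely. The point is to see how the three methods are written in USG. You will notice other methods besides those three: some override other base-class methods, and some are helpers for parts of the three methods above.)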

So once you have chosen the base class to inherit from, be sure to read through how it is written and what methods it has.

package net.librec.recommender.poi;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.primitives.Ints;
import net.librec.common.LibrecException;
import net.librec.data.convertor.appender.LocationDataAppender;
import net.librec.data.structure.AbstractBaseDataEntry;
import net.librec.data.structure.LibrecDataList;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DataSet;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.recommender.AbstractRecommender;
import net.librec.recommender.item.KeyValue;
import net.librec.recommender.item.RecommendedList;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.*;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.*;

/**
 * Ye M, Yin P, Lee W C, et al. Exploiting geographical influence for collaborative point-of-interest recommendation[C]//
 * International ACM SIGIR Conference on Research and Development in Information Retrieval. ACM, 2011:325-334.
 * @author Yuanyuan Jin
 *
 * ###special notes###
 * 1. prediction for all user, please set:
 * data.testset.path = poi/Gowalla/checkin/Gowalla_test.txt
 * and delete the para setting for "rec.limit.userNum" in usg.properties
 *
 * 2. prediction for small user set like userids in [0, 100],
 * in usg.properties, please set:
 * data.testset.path = poi/Gowalla/checkin/testDataFor101users.txt
 * rec.limit.userNum = 101
 * In EntropyEvaluator and NoveltyEvaluator, you also need to reset the variable "numUsers" = your limited userNum
 */
public class USGRecommender extends AbstractRecommender {
    private SequentialAccessSparseMatrix socialSimilarityMatrix;
    private SequentialAccessSparseMatrix userSimilarityMatrix;
    private SequentialAccessSparseMatrix socialMatrix;
    private SequentialAccessSparseMatrix trainMatrix;
    private SequentialAccessSparseMatrix testMatrix;
    /**
     * weight of the social score part
     */
    private double alpha;

    /**
     * weight of the geographical score part
     */
    private double beta;

    /**
     * tuning parameter in social similarity
     */
    private double eta;

    /**
     * linear coefficients for modeling the "log-log scale" power-law distribution
     */
    private double w0;
    private double w1;

    /**
     * number of pois
     */
    private int numPois;

    /**
     * number of users
     */
    private int numUsers;

    /**
     * for limiting test user cardinality
     */
    private int limitUserNum;

    private static final int BSIZE = 1024 * 1024;
    private String socialPath;
    private KeyValue<Double, Double>[] locationCoordinates;

    @Override
    protected void setup() throws LibrecException {
        super.setup();

        BiMap<Integer, String> userIds = this.userMappingData.inverse();
        BiMap<Integer, String> itemIds = this.itemMappingData.inverse();

        numPois = itemMappingData.size();
        numUsers = userMappingData.size();

        trainMatrix = (SequentialAccessSparseMatrix) getDataModel().getTrainDataSet();
        testMatrix = (SequentialAccessSparseMatrix) getDataModel().getTestDataSet();

        alpha = conf.getDouble("rec.alpha", 0.1d);
        beta = conf.getDouble("rec.beta", 0.1d);
        eta = conf.getDouble("rec.eta", 0.05d);
        //default value is numUsers
        limitUserNum = conf.getInt("rec.limit.userNum", numUsers);
        locationCoordinates = ((LocationDataAppender) getDataModel().getDataAppender()).getLocationAppender();
        userSimilarityMatrix = context.getSimilarity().getSimilarityMatrix().toSparseMatrix();
        socialPath = conf.get("dfs.data.dir") + "/" + conf.get("data.social.path");

        // for AUCEvaluator and nDCGEvaluator
        int[] numDroppedItemsArray = new int[numUsers];
        int maxNumTestItemsByUser = 0;
        for (int userIdx = 0; userIdx < numUsers; ++userIdx) {
            numDroppedItemsArray[userIdx] = numPois - trainMatrix.row(userIdx).getNumEntries();
            int numTestItemsByUser = testMatrix.row(userIdx).getNumEntries();
            maxNumTestItemsByUser = maxNumTestItemsByUser < numTestItemsByUser ? numTestItemsByUser : maxNumTestItemsByUser;
        }
        conf.setInts("rec.eval.auc.dropped.num", numDroppedItemsArray);
        conf.setInt("rec.eval.key.test.max.num", maxNumTestItemsByUser);

        // for EntropyEvaluator
        conf.setInt("rec.eval.item.num", testMatrix.columnSize());

        // for NoveltyEvaluator
        int[] itemPurchasedCount = new int[numPois];
        for (int itemIdx = 0; itemIdx < numPois; ++itemIdx) {
            int userNum = 0;
            int[] userArray = trainMatrix.column(itemIdx).getIndices();
            for (int userIdx : userArray) {
                if (userIdx >= 0 && userIdx < limitUserNum) {
                    userNum++;
                }
            }
            userArray = testMatrix.column(itemIdx).getIndices();
            for (int userIdx : userArray) {
                if (userIdx >= 0 && userIdx < limitUserNum) {
                    userNum++;
                }
            }
            itemPurchasedCount[itemIdx] = userNum;
        }
        conf.setInts("rec.eval.item.purchase.num", itemPurchasedCount);
    }

    @Override
    protected void trainModel() throws LibrecException {
        LOG.info("start buliding socialmatrix" + new Date());
        try {
            buildSocialMatrix(socialPath);
        } catch (IOException e) {
            e.printStackTrace();
        }

        LOG.info("start buliding socialSimilarityMatrix" + new Date());
        buildSocialSimilarity();

        LOG.info("start fitting the powerlaw distribution" + new Date());
        fitPowerLaw();
    }

    public double[] predictScore(int userIdx, int itemIdx) {
        //score array for three aspects: user preference, social influence and  geographical influence
        double[] predictScore = new double[]{0.0d, 0.0d, 0.0d};

        int[] userArray = trainMatrix.column(itemIdx).getIndices();
        List<Integer> userList = Ints.asList(userArray);

        /*---------start user preference socre calculation--------*/
        //iterator to iterate other similar users for each user
        Iterator<Vector.VectorEntry> userSimIter = userSimilarityMatrix.row(userIdx).iterator();

        //similarities between userIdx and its neighbors
        List<Double> neighborSimis = new ArrayList<>();
        while (userSimIter.hasNext()) {
            Vector.VectorEntry userRatingEntry = userSimIter.next();
            int similarUserIdx = userRatingEntry.index();
            if (!userList.contains(similarUserIdx)) {
                continue;
            }
            neighborSimis.add(userRatingEntry.get());
        }
        if (neighborSimis.size() == 0) {
            predictScore[0] = 0.0d;
        } else {
            double sum = 0.0d;
            for (int i = 0; i < neighborSimis.size(); i++) {
                sum += neighborSimis.get(i);
            }
            predictScore[0] = sum;
        }
        /*---------end user preference socre calculation--------*/

        /*---------start social influence socre calculation--------*/
        //social similarities between userIdx and its social neighbors
        List<Double> socialNeighborSimis = new ArrayList<>();
        Iterator<Vector.VectorEntry> friendIter = socialSimilarityMatrix.row(userIdx).iterator();
        while (friendIter.hasNext()) {
            Vector.VectorEntry userRatingEntry = friendIter.next();
            int similarUserIdx = userRatingEntry.index();
            if (!userList.contains(similarUserIdx)) {
                continue;
            }
            socialNeighborSimis.add(userRatingEntry.get());
        }
        if (socialNeighborSimis.size() == 0) {
            predictScore[1] = 0.0d;
        } else {
            double sum = 0.0d;
            for (int i = 0; i < socialNeighborSimis.size(); i++) {
                sum += socialNeighborSimis.get(i);
            }
            predictScore[1] = sum;
        }
        /*---------end social influence socre calculation--------*/

        /*---------start geo influence socre calculation--------*/
        double geoScore = 1.0d;
        int[] itemList = trainMatrix.row(userIdx).getIndices();
        if (itemList.length == 0) {
            geoScore = 0.0d;
        } else {
            for (int visitedPOI : itemList) {
                double distance = getDistance(locationCoordinates[visitedPOI].getKey(), locationCoordinates[visitedPOI].getValue(),
                        locationCoordinates[itemIdx].getKey(), locationCoordinates[itemIdx].getValue());
                if (distance < 0.01) {
                    distance = 0.01;
                }
                geoScore *= w0 * Math.pow(distance, w1);
            }
        }
        predictScore[2] = geoScore;
        /*---------end geo influence socre calculation--------*/

        return predictScore;
    }

    public void buildSocialSimilarity() {
        Table<Integer, Integer, Double> socialSimilarityTable = HashBasedTable.create();
        for (int userIdx = 0; userIdx < numUsers; userIdx++) {
            SequentialSparseVector userVector = trainMatrix.row(userIdx);
            if (userVector.getNumEntries() == 0) {
                continue;
            }
            int[] socialNeighborList = socialMatrix.column(userIdx).getIndices();
            for (int socialNeighborIdx : socialNeighborList) {
                if (userIdx < socialNeighborIdx) {
                    SequentialSparseVector socialVector = trainMatrix.row(socialNeighborIdx);
                    int[] friendList = socialMatrix.column(socialNeighborIdx).getIndices();
                    if (socialVector.getNumEntries() == 0 || friendList.length == 0) {
                        continue;
                    }
                    if (getCorrelation(userVector, socialVector) > 0.0 && getCorrelation(socialNeighborList, friendList) > 0.0) {
                        double sim = (1 - eta) * getCorrelation(userVector, socialVector) + eta * getCorrelation(socialNeighborList, friendList);
                        if (!Double.isNaN(sim) && sim != 0.0) {
                            socialSimilarityTable.put(userIdx, socialNeighborIdx, sim);
                        }
                    }
                }
            }
        }
        socialSimilarityMatrix = new SequentialAccessSparseMatrix(numUsers, numUsers, socialSimilarityTable);
    }

    /**
     * fit the "log-log" scale power law distribution
     */
    public void fitPowerLaw() {
        Map<Integer, Double> distanceMap = new HashMap<>();
        Map<Double, Double> logdistanceMap = new HashMap<>();
        int pairNum = 0;

        for (int userIdx = 0; userIdx < numUsers; userIdx++) {
            int[] itemList = trainMatrix.row(userIdx).getIndices();
            if (itemList.length == 0) {
                continue;
            }

            for (int i = 0; i < itemList.length - 1; i++) {
                for (int j = i + 1; j < itemList.length; j++) {
                    double distance = getDistance(locationCoordinates[itemList[i]].getKey(), locationCoordinates[itemList[i]].getValue(),
                            locationCoordinates[itemList[j]].getKey(), locationCoordinates[itemList[j]].getValue());
                    if ((int) distance > 0) {
                        int intDistance = (int) distance;
                        if (!distanceMap.containsKey(intDistance)) {
                            distanceMap.put(intDistance, 0.0d);
                        }
                        distanceMap.put(intDistance, distanceMap.get(intDistance) + 1.0d);
                    }
                    pairNum++;
                }
            }
        }

        for (Map.Entry<Integer, Double> distanceEntry : distanceMap.entrySet()) {
            logdistanceMap.put(Math.log10(distanceEntry.getKey()), Math.log10(distanceEntry.getValue() * 1.0 / pairNum));
        }

        /*-------start gradient descent--------*/
        w0 = Randoms.random();
        w1 = Randoms.random();
        //regularization coefficient
        double reg = 0.1;
        //learn rate
        double lrate = 0.00001;
        //max number of iterations
        int maxIterations = 2000;

        for (int i = 0; i < maxIterations; i++) {
            //gradients of w0 and w1
            double w0Gradient = 0.0d;
            double w1Gradient = 0.0d;

            for (Map.Entry<Double, Double> distanceEntry : logdistanceMap.entrySet()) {
                double distance = distanceEntry.getKey();
                double probability = distanceEntry.getValue();
                w0Gradient += (w0 + w1 * distance - probability);
                w1Gradient += (w0 + w1 * distance - probability) * distance;
            }
            w0 -= lrate * (w0Gradient + reg * w0);
            w1 -= lrate * (w1Gradient + reg * w1);
        }
        /*-------end gradient descent--------*/

        w0 = Math.pow(10, w0);
    }

    /**
     * calculate the spherical distance between location(lat1, long1) and location (lat2, long2)
     * @param lat1
     * @param long1
     * @param lat2
     * @param long2
     * @return
     */
    protected double getDistance(Double lat1, Double long1, Double lat2, Double long2) {
        if (Math.abs(lat1 - lat2) < 1e-6 && Math.abs(long1 - long2) < 1e-6) {
            return 0.0d;
        }
        double degreesToRadius = Math.PI / 180.0;
        double phi1 = (90.0 - lat1) * degreesToRadius;
        double phi2 = (90.0 - lat2) * degreesToRadius;
        double theta1 = long1 * degreesToRadius;
        double theta2 = long2 * degreesToRadius;
        double cos = (Math.sin(phi1) * Math.sin(phi2) * Math.cos(theta1 - theta2) +
                Math.cos(phi1) * Math.cos(phi2));
        double arc = Math.acos(cos);
        double earthRadius = 6371;
        return arc * earthRadius;
    }

    public double getCorrelation(SequentialSparseVector thisVector, SequentialSparseVector thatVector) {
        // compute jaccard similarity
        Set<Integer> elements = unionArrays(thisVector.getIndices(), thatVector.getIndices());
        int numAllElements = elements.size();
        int numCommonElements = thisVector.getIndices().length + thatVector.getIndices().length - numAllElements;
        return (numCommonElements + 0.0) / numAllElements;
    }

    public Set<Integer> unionArrays(int[] arr1, int[] arr2) {
        Set<Integer> set = new HashSet<>();
        for (int num : arr1) {
            set.add(num);
        }
        for (int num : arr2) {
            set.add(num);
        }
        return set;
    }

    public double getCorrelation(int[] thisList, int[] thatList) {
        // compute jaccard similarity
        Set<Integer> elements = new HashSet<Integer>();
        for (int num : thisList) {
            elements.add(num);
        }
        for (int num : thatList) {
            elements.add(num);
        }
        int numAllElements = elements.size();
        int numCommonElements = thisList.length + thatList.length
                - numAllElements;
        return (numCommonElements + 0.0) / numAllElements;
    }

    @Override
    public RecommendedList recommendRating(DataSet predictDataSet) throws LibrecException {
        return null;
    }

    @Override
    public RecommendedList recommendRating(LibrecDataList<AbstractBaseDataEntry> dataList) throws LibrecException {
        return null;
    }

    @Override
    public RecommendedList recommendRank() throws LibrecException {
        LOG.info("Eveluate for users from id 0 to id\t" + (limitUserNum-1));
        RecommendedList recommendedList = new RecommendedList(numUsers);
        for (int userIdx = 0; userIdx < numUsers; ++userIdx) {
            recommendedList.addList(new ArrayList<>());
        }

        List<Integer> userList = new ArrayList<>();
        for (int userIdx = 0; userIdx < limitUserNum; ++userIdx) {
            userList.add(userIdx);
        }

        userList.parallelStream().forEach((Integer userIdx) -> {
            List<Integer> itemList = Ints.asList(trainMatrix.row(userIdx).getIndices());
            List<KeyValue<Integer, double[]>> tempItemValueList = new ArrayList<>();
            double[] maxScore = new double[]{0.0d, 0.0d, 0.0d};
            for (int itemIdx = 0; itemIdx < numPois; ++itemIdx) {
                if (!itemList.contains(itemIdx)) {
                    double[] predictRating = predictScore(userIdx, itemIdx);
                    if (predictRating[0] >= maxScore[0]) {
                        maxScore[0] = predictRating[0];
                    }
                    if (predictRating[1] >= maxScore[1]) {
                        maxScore[1] = predictRating[1];
                    }
                    if (predictRating[2] >= maxScore[2]) {
                        maxScore[2] = predictRating[2];
                    }
                    tempItemValueList.add(new KeyValue<>(itemIdx, new double[]{predictRating[0], predictRating[1], predictRating[2]}));
                }
            }

            List<KeyValue<Integer, Double>> itemValueList = new ArrayList<>();

            //normalize scores
            for (KeyValue<Integer, double[]> entry : tempItemValueList) {
                double[] scores = entry.getValue();
                if (maxScore[0] != 0.0d) {
                    scores[0] = scores[0] / maxScore[0];
                }
                if (maxScore[1] != 0.0d) {
                    scores[1] = scores[1] / maxScore[1];
                }
                if (maxScore[2] != 0.0d) {
                    scores[2] = scores[2] / maxScore[2];
                }
                double predictRating = (1 - alpha - beta) * scores[0] + alpha * scores[1]
                        + beta * scores[2];
                itemValueList.add(new KeyValue<>(entry.getKey(), predictRating));
            }

            recommendedList.setList(userIdx, itemValueList);
            recommendedList.topNRankByIndex(userIdx, topN);
        });
        if (recommendedList.size() == 0) {
            throw new IndexOutOfBoundsException("No item is recommended, there is something error in the recommendation algorithm! Please check it!");
        }
        LOG.info("end recommendation");
        return recommendedList;
    }

    @Override
    public RecommendedList recommendRank(LibrecDataList<AbstractBaseDataEntry> dataList) throws LibrecException {
        return null;
    }

    /**
     * load social relation data
     * @param inputDataPath
     * @throws IOException
     */
    private void buildSocialMatrix(String inputDataPath) throws IOException {
        LOG.info("Now loading users' social relation data success! " + socialPath);
        Table<Integer, Integer, Double> dataTable = HashBasedTable.create();
        final List<File> files = new ArrayList<File>();
        final ArrayList<Long> fileSizeList = new ArrayList<Long>();
        SimpleFileVisitor<Path> finder = new SimpleFileVisitor<Path>() {
            @Override
            public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
                fileSizeList.add(file.toFile().length());
                files.add(file.toFile());
                return super.visitFile(file, attrs);
            }
        };
        Files.walkFileTree(Paths.get(inputDataPath), finder);
        long allFileSize = 0;
        for (Long everyFileSize : fileSizeList) {
            allFileSize = allFileSize + everyFileSize.longValue();
        }
        for (File dataFile : files) {
            FileInputStream fis = new FileInputStream(dataFile);
            FileChannel fileRead = fis.getChannel();
            ByteBuffer buffer = ByteBuffer.allocate(BSIZE);
            int len;
            String bufferLine = new String();
            byte[] bytes = new byte[BSIZE];
            while ((len = fileRead.read(buffer)) != -1) {
                buffer.flip();
                buffer.get(bytes, 0, len);
                bufferLine = bufferLine.concat(new String(bytes, 0, len)).replaceAll("\r", "\n");
                String[] bufferData = bufferLine.split("(\n)+");
                boolean isComplete = bufferLine.endsWith("\n");
                int loopLength = isComplete ? bufferData.length : bufferData.length - 1;
                for (int i = 0; i < loopLength; i++) {
                    String line = new String(bufferData[i]);
                    String[] data = line.trim().split("[ \t,]+");
                    String userA = data[0];
                    String userB = data[1];
                    Double rate = (data.length >= 3) ? Double.valueOf(data[2]) : 1.0;
                    if (this.userMappingData.containsKey(userA) && this.userMappingData.containsKey(userB)) {
                        int row = this.userMappingData.get(userA);
                        int col = this.userMappingData.get(userB);
                        dataTable.put(row, col, rate);
                        dataTable.put(col, row, rate);
                    }
                }
                if (!isComplete) {
                    bufferLine = bufferData[bufferData.length - 1];
                }
                buffer.clear();
            }
            fileRead.close();
            fis.close();
        }
        int numRows = this.userMappingData.size(), numCols = this.userMappingData.size();
        socialMatrix = new SequentialAccessSparseMatrix(numRows, numCols, dataTable);
        dataTable = null;
        LOG.info("Load users' social relation data success! " + socialPath);
    }
}
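Two small pieces of the USG listing are worth isolating, because they carry most of the idea: the Jaccard similarity in getCorrelation and the normalize-then-blend step at the end of recommendRank. The sketch below restates them as standalone functions. UsgScoreSketch and its method names are mine, and the empty-input guard in jaccard is an addition the listing does not have.

```java
import java.util.HashSet;
import java.util.Set;

/** Hypothetical restatement of two USG building blocks as standalone functions. */
class UsgScoreSketch {

    /** Jaccard similarity (intersection size over union size) of two index arrays without duplicates. */
    static double jaccard(int[] a, int[] b) {
        Set<Integer> union = new HashSet<>();
        for (int x : a) {
            union.add(x);
        }
        for (int x : b) {
            union.add(x);
        }
        if (union.isEmpty()) {
            return 0.0; // guard added in this sketch to avoid dividing by zero
        }
        int common = a.length + b.length - union.size();
        return (double) common / union.size();
    }

    /**
     * Divides each of the three scores (user preference, social, geographical) by its
     * per-user maximum when that maximum is non-zero, then blends them linearly:
     * (1 - alpha - beta) * s0 + alpha * s1 + beta * s2.
     */
    static double fuse(double[] scores, double[] maxScores, double alpha, double beta) {
        double[] s = new double[3];
        for (int k = 0; k < 3; k++) {
            s[k] = maxScores[k] != 0.0 ? scores[k] / maxScores[k] : scores[k];
        }
        return (1 - alpha - beta) * s[0] + alpha * s[1] + beta * s[2];
    }

    public static void main(String[] args) {
        System.out.println(jaccard(new int[]{1, 2, 3}, new int[]{2, 3, 4}));
        System.out.println(fuse(new double[]{2.0, 1.0, 0.5}, new double[]{4.0, 2.0, 1.0}, 0.1, 0.1));
    }
}
```

Normalizing by the per-user maximum puts the three scores on a comparable scale before alpha and beta weigh them, which is why recommendRank tracks maxScore while it scans the candidate items.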

IV. Testing your own recommendation algorithm

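Once the code above is written, how do we test how well our algorithm performs?
Taking the USG method above as an example, we run the RecommenderJob class with the USG configuration file and then look at the final experimental results.
RecommenderJob is a class that wraps the whole recommendation workflow described earlier, including dataset processing, splitting, training, prediction, and evaluation. Just pass in the configuration (such as rec/poi/usg-test.properties; see my earlier posts if you are unfamiliar with it) and specify the recommender to run (here usg), and it runs the whole experiment for you.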
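To show what "passing in configuration items" amounts to, here is a sketch that parses a few USG-related keys from properties text and reads them with defaults, the way USGRecommender.setup() does with conf.getDouble. The keys are the ones that appear in the listings above; the values are illustrative, the real usg-test.properties may look different, and the sketch uses java.util.Properties rather than LibRec's own Configuration class.

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.Properties;

/** Hypothetical helper: reads USG-style settings with defaults from properties text. */
class UsgConfigSketch {

    /** Parses key=value lines into a Properties object. */
    static Properties parse(String text) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(text));
        } catch (IOException e) {
            throw new IllegalStateException("cannot parse properties text", e);
        }
        return props;
    }

    /** Returns the value for key as a double, or defaultValue when the key is absent. */
    static double getDouble(Properties props, String key, double defaultValue) {
        String raw = props.getProperty(key);
        return raw == null ? defaultValue : Double.parseDouble(raw.trim());
    }

    public static void main(String[] args) {
        // illustrative values only; not copied from the real usg-test.properties
        String text = String.join("\n",
                "rec.recommender.isranking=true",
                "rec.recommender.ranking.topn=10",
                "rec.alpha=0.2",
                "data.testset.path=poi/Gowalla/checkin/Gowalla_test.txt");
        Properties props = parse(text);
        System.out.println("alpha = " + getDouble(props, "rec.alpha", 0.1));
        System.out.println("beta  = " + getDouble(props, "rec.beta", 0.1)); // absent, falls back to the default
        System.out.println("test set = " + props.getProperty("data.testset.path"));
    }
}
```

Keeping every tunable value in a properties file means an experiment can be changed without recompiling, which is what makes a single job class able to run any algorithm.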

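That sounds a bit complicated when described this way. When I have time, I will post a code walkthrough of the RecommenderJob class to make it clearer.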

V. A small suggestion

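All of the code above is in the LibRec source. If what I have written here is unclear, go read the corresponding source code alongside this post and think it over. I am sure you will learn faster than I did at first!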
