推荐算法之 Slope One：Java 及 PHP 实现

这两份代码似乎都出自原作者（Daniel Lemire）本人之手。


import java.util.*;

/**
 * A simple implementation of the weighted and the unweighted Slope One
 * algorithms in Java for item-based collaborative filtering, by Daniel Lemire.
 * Assumes Java 1.5.
 * 
 * <p>Ratings are kept as a map from each user to that user's item ratings. The
 * constructor builds two item-by-item matrices from them: mFreqMatrix(i, j) is
 * the number of users who rated both i and j, and mDiffMatrix(i, j) is the
 * average of rating(i) - rating(j) over those users.
 * 
 * <p>See the main function for an example.
 * 
 * <p>June 1st 2006. Revised by Marco Ponzi on March 29th 2007.
 */

public class SlopeOne {

  public static void main(String args[]) {
    // this is my data base
    Map<UserId, Map<ItemId, Float>> data = new HashMap<UserId, Map<ItemId, Float>>();
    // items
    ItemId item1 = new ItemId("       candy");
    ItemId item2 = new ItemId("         dog");
    ItemId item3 = new ItemId("         cat");
    ItemId item4 = new ItemId("         war");
    ItemId item5 = new ItemId("strange food");

    mAllItems = new ItemId[] { item1, item2, item3, item4, item5 };

    // I'm going to fill it in
    HashMap<ItemId, Float> user1 = new HashMap<ItemId, Float>();
    HashMap<ItemId, Float> user2 = new HashMap<ItemId, Float>();
    HashMap<ItemId, Float> user3 = new HashMap<ItemId, Float>();
    HashMap<ItemId, Float> user4 = new HashMap<ItemId, Float>();
    user1.put(item1, 1.0f);
    user1.put(item2, 0.5f);
    user1.put(item4, 0.1f);
    data.put(new UserId("Bob"), user1);
    user2.put(item1, 1.0f);
    user2.put(item3, 0.5f);
    user2.put(item4, 0.2f);
    data.put(new UserId("Jane"), user2);
    user3.put(item1, 0.9f);
    user3.put(item2, 0.4f);
    user3.put(item3, 0.5f);
    user3.put(item4, 0.1f);
    data.put(new UserId("Jo"), user3);
    user4.put(item1, 0.1f);
    // user4.put(item2,0.4f);
    // user4.put(item3,0.5f);
    user4.put(item4, 1.0f);
    user4.put(item5, 0.4f);
    data.put(new UserId("StrangeJo"), user4);
    // next, I create my predictor engine
    SlopeOne so = new SlopeOne(data);
    System.out.println("Here's the data I have accumulated...");
    so.printData();
    // then, I'm going to test it out...
    HashMap<ItemId, Float> user = new HashMap<ItemId, Float>();
    System.out.println("Ok, now we predict...");
    user.put(item5, 0.4f);
    System.out.println("Inputting...");
    SlopeOne.print(user);
    System.out.println("Getting...");
    SlopeOne.print(so.predict(user));
    //
    user.put(item4, 0.2f);
    System.out.println("Inputting...");
    SlopeOne.print(user);
    System.out.println("Getting...");
    SlopeOne.print(so.predict(user));
  }

  /** Ratings per user, as passed to the constructor. */
  Map<UserId, Map<ItemId, Float>> mData;
  /** mDiffMatrix.get(i).get(j): average of rating(i) - rating(j) over users who rated both. */
  Map<ItemId, Map<ItemId, Float>> mDiffMatrix;
  /** mFreqMatrix.get(i).get(j): number of users who rated both i and j. */
  Map<ItemId, Map<ItemId, Integer>> mFreqMatrix;

  /**
   * Column/row order used by printData(); main() sets it. When it is null,
   * printData() falls back to the items present in the diff matrix.
   */
  static ItemId[] mAllItems;

  /**
   * Creates a predictor over the given ratings and builds its matrices.
   *
   * @param data ratings per user; every rating value must be non-null
   */
  public SlopeOne(Map<UserId, Map<ItemId, Float>> data) {
    mData = data;
    buildDiffMatrix();
  }

  /**
   * Predicts the user's missing ratings with the weighted Slope One scheme.
   * 
   * <p>For every item k in the diff matrix the prediction is
   * sum_j (dev(k, j) + u(j)) * freq(k, j) / sum_j freq(k, j), taken over the
   * items j the user rated that were co-rated with k by at least one user.
   * Items with no such j get no prediction: they are absent from the result.
   * The user's own ratings are copied into the result unchanged and override
   * any computed value.
   * 
   * <p>A way to make this more scalable would be to consider only mDiffMatrix
   * entries having a large (>1) mFreqMatrix entry; this is not done here.
   *
   * @param user the ratings already known for the user
   * @return predicted ratings merged with the user's own ratings
   */
  public Map<ItemId, Float> predict(Map<ItemId, Float> user) {
    HashMap<ItemId, Float> predictions = new HashMap<ItemId, Float>();
    HashMap<ItemId, Integer> frequencies = new HashMap<ItemId, Integer>();
    for (ItemId j : mDiffMatrix.keySet()) {
      frequencies.put(j, 0);
      predictions.put(j, 0.0f);
    }
    for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
      ItemId j = rated.getKey();
      Float rating = rated.getValue();
      if (rating == null) {
        continue; // no value to extrapolate from
      }
      for (ItemId k : mDiffMatrix.keySet()) {
        Float diff = mDiffMatrix.get(k).get(j);
        Integer freq = mFreqMatrix.get(k).get(j);
        if (diff == null || freq == null) {
          continue; // k and j were never rated by the same user
        }
        float newval = (diff.floatValue() + rating.floatValue()) * freq.intValue();
        predictions.put(k, predictions.get(k) + newval);
        frequencies.put(k, frequencies.get(k) + freq.intValue());
      }
    }
    HashMap<ItemId, Float> cleanpredictions = new HashMap<ItemId, Float>();
    for (ItemId k : predictions.keySet()) {
      int total = frequencies.get(k).intValue();
      if (total > 0) {
        cleanpredictions.put(k, predictions.get(k).floatValue() / total);
      }
    }
    for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
      cleanpredictions.put(rated.getKey(), rated.getValue());
    }
    return cleanpredictions;
  }

  /**
   * Predicts the user's missing ratings with the unweighted Slope One scheme.
   * 
   * <p>For every item k in the diff matrix the prediction is the plain mean of
   * dev(k, j) + u(j) over the items j the user rated that were co-rated with k.
   * Items with no such j get no prediction: they are absent from the result,
   * matching predict(). The user's own ratings are copied into the result
   * unchanged and override any computed value.
   *
   * @param user the ratings already known for the user
   * @return predicted ratings merged with the user's own ratings
   */
  public Map<ItemId, Float> weightlesspredict(Map<ItemId, Float> user) {
    HashMap<ItemId, Float> predictions = new HashMap<ItemId, Float>();
    HashMap<ItemId, Integer> contributors = new HashMap<ItemId, Integer>();
    for (ItemId k : mDiffMatrix.keySet()) {
      predictions.put(k, 0.0f);
      contributors.put(k, 0);
    }
    for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
      Float rating = rated.getValue();
      if (rating == null) {
        continue;
      }
      for (ItemId k : mDiffMatrix.keySet()) {
        Float diff = mDiffMatrix.get(k).get(rated.getKey());
        if (diff == null) {
          // Never co-rated: the original dereferenced this and crashed.
          continue;
        }
        float newval = diff.floatValue() + rating.floatValue();
        predictions.put(k, predictions.get(k) + newval);
        contributors.put(k, contributors.get(k) + 1);
      }
    }
    HashMap<ItemId, Float> cleanpredictions = new HashMap<ItemId, Float>();
    for (ItemId k : predictions.keySet()) {
      // Average over the items that actually contributed, not over user.size().
      int count = contributors.get(k).intValue();
      if (count > 0) {
        cleanpredictions.put(k, predictions.get(k).floatValue() / count);
      }
    }
    for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
      cleanpredictions.put(rated.getKey(), rated.getValue());
    }
    return cleanpredictions;
  }

  /** Prints every user's ratings, then one row of both matrices per item. */
  public void printData() {
    for (UserId user : mData.keySet()) {
      System.out.println(user);
      print(mData.get(user));
    }
    ItemId[] items = allItems();
    for (int i = 0; i < items.length; i++) {
      System.out.print("\n" + items[i] + ":");
      printMatrixes(mDiffMatrix.get(items[i]), mFreqMatrix.get(items[i]));
    }
  }

  /** Returns mAllItems when main() has set it, else the diff-matrix items. */
  private ItemId[] allItems() {
    if (mAllItems != null) {
      return mAllItems;
    }
    return mDiffMatrix.keySet().toArray(new ItemId[mDiffMatrix.size()]);
  }

  /**
   * Prints one matrix row: for each item, the average diff and the co-rating
   * count. Pairs without co-rating data are formatted from a null value.
   *
   * @param ratings diff-matrix row, or null when the item has no row
   * @param frequencies frequency-matrix row, or null when the item has no row
   */
  private void printMatrixes(Map<ItemId, Float> ratings,
      Map<ItemId, Integer> frequencies) {
    ItemId[] items = allItems();
    for (int j = 0; j < items.length; j++) {
      Float diff = ratings == null ? null : ratings.get(items[j]);
      Integer freq = frequencies == null ? null : frequencies.get(items[j]);
      System.out.format("%10.3f", diff);
      System.out.print(" ");
      System.out.format("%10d", freq);
    }
    System.out.println();
  }

  /** Prints each entry of the map as " item --> rating". */
  public static void print(Map<ItemId, Float> user) {
    for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
      // String concatenation of the Float prints the same digits as
      // floatValue() and prints "null" instead of throwing for a null value.
      System.out.println(" " + rated.getKey() + " --> " + rated.getValue());
    }
  }

  /**
   * Rebuilds mDiffMatrix and mFreqMatrix from mData. For every ordered pair
   * (i, j) rated by the same user (including i == j) it counts the users and
   * averages rating(i) - rating(j) over them.
   */
  public void buildDiffMatrix() {
    mDiffMatrix = new HashMap<ItemId, Map<ItemId, Float>>();
    mFreqMatrix = new HashMap<ItemId, Map<ItemId, Integer>>();
    // first pass: accumulate counts and sums of differences per pair
    for (Map<ItemId, Float> user : mData.values()) {
      for (Map.Entry<ItemId, Float> entry : user.entrySet()) {
        Map<ItemId, Float> diffRow = mDiffMatrix.get(entry.getKey());
        Map<ItemId, Integer> freqRow = mFreqMatrix.get(entry.getKey());
        if (diffRow == null) {
          diffRow = new HashMap<ItemId, Float>();
          freqRow = new HashMap<ItemId, Integer>();
          mDiffMatrix.put(entry.getKey(), diffRow);
          mFreqMatrix.put(entry.getKey(), freqRow);
        }
        for (Map.Entry<ItemId, Float> entry2 : user.entrySet()) {
          Integer oldcount = freqRow.get(entry2.getKey());
          Float olddiff = diffRow.get(entry2.getKey());
          float observeddiff = entry.getValue() - entry2.getValue();
          freqRow.put(entry2.getKey(), (oldcount == null ? 0 : oldcount) + 1);
          diffRow.put(entry2.getKey(),
              (olddiff == null ? 0.0f : olddiff) + observeddiff);
        }
      }
    }
    // second pass: turn the sums into averages
    for (ItemId j : mDiffMatrix.keySet()) {
      Map<ItemId, Float> diffRow = mDiffMatrix.get(j);
      Map<ItemId, Integer> freqRow = mFreqMatrix.get(j);
      for (ItemId i : diffRow.keySet()) {
        float oldvalue = diffRow.get(i).floatValue();
        int count = freqRow.get(i).intValue();
        diffRow.put(i, oldvalue / count);
      }
    }
  }
}

class UserId {
  String content;

  public UserId(String s) {
    content = s;
  }

  public int hashCode() {
    return content.hashCode();
  }

  public String toString() {
    return content;
  }
}

class ItemId {
  String content;

  public ItemId(String s) {
    content = s;
  }

  public int hashCode() {
    return content.hashCode();
  }

  public String toString() {
    return content;
  }
}

 
# This is the plain-text code from the technical report:
#
# Daniel Lemire, Sean McGrath, Implementing a Rating-Based Item-to-Item
# Recommender System in PHP/SQL, Technical Report D-01, January 2005.
#
# http://www.ondelette.com/lemire/abstracts/TRD01.html
#
# This code is in the public domain; use it at your own risk.
# It is assumed that you looked at the report and know some SQL and PHP.
#
# Daniel Lemire, February 3rd 2005
#
# First part is sample SQL code.
#########CUT HERE####################

# One row per (user, item) rating. The primary key must span both columns:
# keying on userID alone would let each user store only a single rating,
# which makes item-to-item deviations impossible to compute.
CREATE TABLE rating (
    userID INT NOT NULL,
    itemID INT NOT NULL,
    ratingValue INT NOT NULL,
    datetimestamp TIMESTAMP NOT NULL,
    PRIMARY KEY (userID, itemID)
);


# Pairwise deviation table for Slope One. For a pair (itemID1, itemID2):
#   count = number of users who rated both items,
#   sum   = running total of rating(itemID1) - rating(itemID2) over those users,
# so sum / count is the average deviation of itemID1 relative to itemID2.
# The update code stores each pair in both directions with opposite signs,
# plus a diagonal row (item, item).
# NOTE(review): code may insert rows positionally (INSERT INTO dev VALUES ...),
# so keep this column order stable.
CREATE TABLE dev (
  itemID1 int(11) NOT NULL default '0',
  itemID2 int(11) NOT NULL default '0',
  count int(11) NOT NULL default '0',
  sum int(11) NOT NULL default '0',
  PRIMARY KEY  (itemID1,itemID2)
);


# Simple query to output the 10 items most liked, relative to item 1, by
# people who rated item 1 (only pairs co-rated by more than 2 users).
# dev.sum accumulates rating(itemID1) - rating(itemID2), so with itemID1 = 1
# the negated average -sum/count is how far itemID2 is rated above item 1.
# Sorting that descending puts the most-liked items first. (Sorting +sum/count
# descending would list the items rated furthest BELOW item 1.)
# The diagonal row (1, 1) is excluded so item 1 is not listed for itself.
SELECT itemID2, ( -sum / count ) AS average
FROM dev
WHERE count > 2 AND itemID1 = 1 AND itemID2 <> 1
ORDER BY average DESC
LIMIT 10;

# Next part is sample PHP code.
#########CUT HERE####################

// Update the dev table after a user rates an item.
// This code assumes $userID and $itemID are set to the user and the item
// that was just rated, and that $connection is an open mysql_* link.
// For every item the user has rated (including $itemID itself), the
// difference (rating of $itemID - rating of that item) is folded into dev in
// both directions: dev($itemID, other) gets +difference and
// dev(other, $itemID) gets -difference. The diagonal pair is stored once.
// NOTE(review): the mysql_* extension was removed in PHP 7.0; porting to
// mysqli or PDO requires changing how $connection is created elsewhere.

// Integer casts keep the ids from injecting SQL into the queries below.
$current_userID = (int) $userID;
$current_itemID = (int) $itemID;

// Get all of the user's rating pairs involving the item just rated.
$sql = "SELECT DISTINCT r.itemID, r2.ratingValue - r.ratingValue
            AS rating_difference
            FROM rating r, rating r2
            WHERE r.userID=$current_userID AND
                    r2.itemID=$current_itemID AND
                    r2.userID=$current_userID";
$db_result = mysql_query($sql, $connection);
if ($db_result === false) {
    trigger_error("dev update failed: " . mysql_error($connection), E_USER_WARNING);
} else {
    while ($row = mysql_fetch_assoc($db_result)) {
        $other_itemID = (int) $row["itemID"];
        // Numeric conversion of the DB value so it can be negated safely.
        $rating_difference = 0 + $row["rating_difference"];
        $negated_difference = -$rating_difference;
        // Atomic upsert on dev's primary key (itemID1, itemID2). This replaces a
        // SELECT followed by INSERT/UPDATE, which could lose an update when two
        // ratings were processed concurrently.
        mysql_query("INSERT INTO dev (itemID1, itemID2, `count`, `sum`)
            VALUES ($current_itemID, $other_itemID, 1, $rating_difference)
            ON DUPLICATE KEY UPDATE `count` = `count` + 1,
            `sum` = `sum` + ($rating_difference)", $connection);
        // We only want the mirrored row if the items are different.
        if ($current_itemID != $other_itemID) {
            mysql_query("INSERT INTO dev (itemID1, itemID2, `count`, `sum`)
                VALUES ($other_itemID, $current_itemID, 1, $negated_difference)
                ON DUPLICATE KEY UPDATE `count` = `count` + 1,
                `sum` = `sum` + ($negated_difference)", $connection);
        }
    }
    mysql_free_result($db_result);
}


/**
 * Weighted Slope One prediction of $userID's rating for $itemID.
 *
 * Over every item j the user rated (other than $itemID) that has a
 * dev($itemID, j) row, the prediction is
 *   sum(count * (sum/count + rating_j)) / sum(count)
 *   = sum(sum + count * rating_j) / sum(count).
 * A single aggregate query computes both totals, instead of one dev lookup
 * per rated item. This matches the formula used by predict_all().
 *
 * @return float|int the prediction, or 0 when no data is available or the
 *                   query fails
 */
function predict($userID, $itemID) {
    global $connection;
    // Integer casts keep the ids from injecting SQL.
    $userID = (int) $userID;
    $itemID = (int) $itemID;
    $sql = "SELECT SUM(d.count) AS denom,
                   SUM(d.sum + d.count * r.ratingValue) AS numer
            FROM rating r
            JOIN dev d ON d.itemID1 = $itemID AND d.itemID2 = r.itemID
            WHERE r.userID = $userID AND r.itemID <> $itemID";
    $db_result = mysql_query($sql, $connection);
    if ($db_result === false) {
        return 0;
    }
    $row = mysql_fetch_assoc($db_result);
    mysql_free_result($db_result);
    // SUM over no rows yields NULL, which casts to 0.
    $denom = $row ? (float) $row["denom"] : 0.0;
    if ($denom == 0)
        return 0;
    else
        return ((float) $row["numer"]) / $denom;
}


/**
 * Aggregates for predicting every item $userID has not rated yet.
 *
 * Returns a mysql result with one row per unrated item: 'item', 'denom'
 * (sum of co-rating counts) and 'numer' (sum of dev.sum + count * rating over
 * the user's rated items). The weighted Slope One prediction is numer / denom.
 *
 * @return resource|false the mysql_query() result
 */
function predict_all($userID) {
    // Without this, $connection is undefined inside the function and the
    // query cannot run.
    global $connection;
    $userID = (int) $userID; // keeps the id from injecting SQL
    $sql2 = "SELECT d.itemID1 as 'item', sum(d.count) as 'denom', 
    sum(d.sum + d.count*r.ratingValue) as 'numer' FROM rating r,
    dev d WHERE r.userID=$userID 
    AND d.itemID1 NOT IN 
    (SELECT itemID FROM rating WHERE userID=$userID)  
    AND d.itemID2=r.itemID GROUP BY d.itemID1";
    return mysql_query($sql2, $connection);
}


/**
 * The $n unrated items with the highest weighted Slope One prediction for
 * $userID.
 *
 * Returns a mysql result with columns 'item' and 'avgrat' (the prediction),
 * ordered by avgrat descending.
 *
 * @return resource|false the mysql_query() result
 */
function predict_best($userID, $n) {
    // Without this, $connection is undefined inside the function and the
    // query cannot run.
    global $connection;
    // Casts keep both values from injecting SQL; a negative LIMIT is invalid.
    $userID = (int) $userID;
    $n = max(0, (int) $n);
    $sql2 = "SELECT d.itemID1 as 'item', 
    sum(d.sum + d.count*r.ratingValue)/sum(d.count) as 'avgrat' 
    FROM  rating r, dev d 
    WHERE r.userID=$userID 
    AND d.itemID1 NOT IN 
    (SELECT itemID FROM rating WHERE userID=$userID)  
    AND d.itemID2=r.itemID 
    GROUP BY d.itemID1 ORDER BY avgrat DESC LIMIT $n";
    return mysql_query($sql2, $connection);
}

转自

http://lemire.me/fr/documents/publications/webpaper.txt

http://lemire.me/fr/documents/publications/SlopeOne.java

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值