(一)基于 Mahout 实现 User CF
1、相似度的计算
Similarity 用于计算两个用户或两个物品之间的相似度,归结到数学上就是衡量两个向量之间的距离(或相关程度)。Mahout 提供了一些基本的相似度计算实现,它们都实现了 UserSimilarity 接口,用于计算用户之间的相似度,常用的包括:
- PearsonCorrelationSimilarity:基于皮尔逊相关系数计算相似度
- EuclideanDistanceSimilarity:基于欧几里德距离计算相似度
- TanimotoCoefficientSimilarity:基于 Tanimoto 系数计算相似度
- UncenteredCosineSimilarity:计算(未中心化的)余弦(Cosine)相似度
下面以 PearsonCorrelationSimilarity 为例进行讲解:
首先看继承关系:
public final class PearsonCorrelationSimilarity extends AbstractSimilarity
abstract class AbstractSimilarity extends AbstractItemSimilarity implements UserSimilarity
public abstract class AbstractItemSimilarity implements ItemSimilarity
//实际上 AbstractSimilarity 这个抽象类同时实现了 ItemSimilarity 和 UserSimilarity 两个接口(ItemSimilarity 是经由其父类 AbstractItemSimilarity 实现的)
//因此作为它的子类,PearsonCorrelationSimilarity 既可以计算物品之间的相似度,也可以计算用户之间的相似度
下边分析一下相似度具体是怎么实现的:
/**
 * Computes the similarity between two users from their preference (rating) lists.
 *
 * <p>Both users' preference arrays are walked together in a single merge-style pass keyed on
 * item ID: at each step the side(s) holding the smaller item ID are advanced. NOTE(review): this
 * relies on each user's preferences being ordered by ascending item ID — confirm against the
 * DataModel / PreferenceArray contract.
 *
 * <p>Items rated by both users always contribute a value pair. When an {@code inferrer} is
 * configured, items rated by only one of the two users also contribute, with the missing value
 * supplied by {@code inferrer.inferPreference(...)}. If {@code prefTransform} is set, real
 * (non-inferred) preference values are passed through it first. The accumulated sums are handed
 * to {@code computeResult(...)} (mean-centered first when {@code centerData} is true), then
 * optionally to {@code similarityTransform}, and, when the result is not NaN, to
 * {@code normalizeWeightResult(...)}.
 *
 * @param userID1 ID of the first user
 * @param userID2 ID of the second user
 * @return the similarity, or {@code Double.NaN} if either user has no preferences; NaN may also
 *     be produced by {@code computeResult(...)}, whose behavior is defined elsewhere
 * @throws TasteException if looking up preferences or any of the delegated calls fails
 */
@Override
public double userSimilarity(long userID1, long userID2) throws TasteException {
  DataModel dataModel = getDataModel();
  PreferenceArray xPrefs = dataModel.getPreferencesFromUser(userID1); // user1's preference list
  PreferenceArray yPrefs = dataModel.getPreferencesFromUser(userID2); // user2's preference list
  int xLength = xPrefs.length();
  int yLength = yPrefs.length();
  // Nothing to compare if either user has no preferences at all.
  if (xLength == 0 || yLength == 0) {
    return Double.NaN;
  }
  long xIndex = xPrefs.getItemID(0); // item ID at user1's current position
  long yIndex = yPrefs.getItemID(0); // item ID at user2's current position
  int xPrefIndex = 0; // current position in user1's preference list
  int yPrefIndex = 0; // current position in user2's preference list
  // Running sums over every (x, y) value pair that gets scored below.
  double sumX = 0.0;
  double sumX2 = 0.0;
  double sumY = 0.0;
  double sumY2 = 0.0;
  double sumXY = 0.0;
  double sumXYdiff2 = 0.0;
  int count = 0; // number of (x, y) pairs accumulated
  boolean hasInferrer = inferrer != null;
  boolean hasPrefTransform = prefTransform != null;
  // Walk user1's and user2's preference lists together, in item-ID order.
  while (true) {
    // -1: only user1 is positioned on this item ID; 1: only user2; 0: both are on the same item.
    int compare = xIndex < yIndex ? -1 : xIndex > yIndex ? 1 : 0;
    // Score this position if both users rated the item, or if the missing value can be inferred.
    if (hasInferrer || compare == 0) {
      double x;
      double y;
      if (xIndex == yIndex) {
        // Both users expressed a preference for the item
        if (hasPrefTransform) {
          x = prefTransform.getTransformedValue(xPrefs.get(xPrefIndex));
          y = prefTransform.getTransformedValue(yPrefs.get(yPrefIndex));
        } else {
          x = xPrefs.getValue(xPrefIndex);
          y = yPrefs.getValue(yPrefIndex);
        }
      } else {
        // Only one user expressed a preference, but infer the other one's preference and tally
        // as if the other user expressed that preference
        if (compare < 0) {
          // X has a value; infer Y's
          x = hasPrefTransform
              ? prefTransform.getTransformedValue(xPrefs.get(xPrefIndex))
              : xPrefs.getValue(xPrefIndex);
          y = inferrer.inferPreference(userID2, xIndex);
        } else {
          // compare > 0
          // Y has a value; infer X's
          x = inferrer.inferPreference(userID1, yIndex);
          y = hasPrefTransform
              ? prefTransform.getTransformedValue(yPrefs.get(yPrefIndex))
              : yPrefs.getValue(yPrefIndex);
        }
      }
      // Accumulate the sums that computeResult(...) turns into a similarity value.
      sumXY += x * y;
      sumX += x;
      sumX2 += x * x;
      sumY += y;
      sumY2 += y * y;
      double diff = x - y;
      sumXYdiff2 += diff * diff;
      count++;
    }
    // user1 was on the smaller (or the same) item ID: advance user1's position.
    if (compare <= 0) {
      if (++xPrefIndex >= xLength) { // user1's list is exhausted
        if (hasInferrer) {
          // Must count other Ys; pretend next X is far away
          // (Long.MAX_VALUE serves as the "exhausted" sentinel for either side.)
          if (yIndex == Long.MAX_VALUE) {
            // ... but stop if both are done!
            break;
          }
          xIndex = Long.MAX_VALUE;
        } else {
          // Without an inferrer, no further pair can have compare == 0.
          break;
        }
      } else {
        xIndex = xPrefs.getItemID(xPrefIndex); // item ID at user1's new position
      }
    }
    // user2 was on the smaller (or the same) item ID: advance user2's position.
    if (compare >= 0) {
      if (++yPrefIndex >= yLength) { // user2's list is exhausted
        if (hasInferrer) {
          // Must count other Xs; pretend next Y is far away
          if (xIndex == Long.MAX_VALUE) {
            // ... but stop if both are done!
            break;
          }
          yIndex = Long.MAX_VALUE;
        } else {
          break;
        }
      } else {
        yIndex = yPrefs.getItemID(yPrefIndex); // item ID at user2's new position
      }
    }
  }
  // "Center" the data: derive sums of deviations from the means without a second pass.
  // With n = count, meanX = sumX / n and meanY = sumY / n:
  //   sum((x - meanX) * (y - meanY)) = sumXY - meanY*sumX - meanX*sumY + n*meanX*meanY
  //                                  = sumXY - meanY*sumX        (since meanX*sumY == n*meanX*meanY)
  //   sum((x - meanX)^2)             = sumX2 - 2*meanX*sumX + n*meanX^2
  //                                  = sumX2 - meanX*sumX        (since n*meanX^2 == meanX*sumX)
  // and likewise for Y. sumXYdiff2 is passed as-is (uncentered) in both branches.
  // NOTE(review): if count == 0 the means below are 0.0/0 = NaN; presumably computeResult(...)
  // handles n == 0 — confirm.
  double result;
  if (centerData) {
    double meanX = sumX / count;
    double meanY = sumY / count;
    // double centeredSumXY = sumXY - meanY * sumX - meanX * sumY + n * meanX * meanY;
    double centeredSumXY = sumXY - meanY * sumX;
    // double centeredSumX2 = sumX2 - 2.0 * meanX * sumX + n * meanX * meanX;
    double centeredSumX2 = sumX2 - meanX * sumX;
    // double centeredSumY2 = sumY2 - 2.0 * meanY * sumY + n * meanY * meanY;
    double centeredSumY2 = sumY2 - meanY * sumY;
    result = computeResult(count, centeredSumXY, centeredSumX2, centeredSumY2, sumXYdiff2);
  } else {
    // computeResult(...) turns the sums into the actual similarity value. It is presumably the
    // abstract hook implemented by each AbstractSimilarity subclass (e.g. Pearson) — not visible
    // here, confirm.
    result = computeResult(count, sumXY, sumX2, sumY2, sumXYdiff2);
  }
  if (similarityTransform != null) {
    result = similarityTransform.transformSimilarity(userID1, userID2, result);
  }
  if (!Double.isNaN(result)) {
    // Adjust the result using the number of pairs; exact weighting is defined in
    // normalizeWeightResult (not shown).
    result = normalizeWeightResult(result, count, cachedNumItems);
  }
  return result;
}
物品之间的相似度计算与用户之间的大同小异,这里就不展开了。
2、邻居用户
根据选定的相似度计算方法,即可找到邻居用户。寻找邻居用户的方法有两种:“固定数量的邻居”和“相似度门槛邻居”,Mahout 分别提供了对应的实现:
- NearestNUserNeighborhood:对每个用户取固定数量 N 的最近邻居
- ThresholdUserNeighborhood:对每个用户,取相似度达到给定门限(阈值)的所有用户作为邻居
下边以NearestNUserNeighborhood为例,看一下固定数量的最近邻居是怎么获取的:
NearestNUserNeighborhood类中的函数:
/**
 * Returns the neighborhood of {@code userID}: the IDs of at most {@code n} users chosen by
 * {@code TopItems.getTopUsers}, with each candidate scored by an {@code Estimator} built from the
 * configured user similarity and {@code minSimilarity}.
 *
 * <p>The candidate IDs come from the data model and may be sampled, depending on
 * {@code getSamplingRate()} (sampling behavior is defined in SamplingLongPrimitiveIterator).
 */
@Override
public long[] getUserNeighborhood(long userID) throws TasteException {
  DataModel model = getDataModel();
  UserSimilarity similarityMeasure = getUserSimilarity();
  // Scores every candidate against userID.
  TopItems.Estimator<Long> scorer = new Estimator(similarityMeasure, userID, minSimilarity);
  // Optionally wrap the full user-ID iterator so that only a sample of users is considered.
  LongPrimitiveIterator candidates =
      SamplingLongPrimitiveIterator.maybeWrapIterator(model.getUserIDs(), getSamplingRate());
  // No rescorer: rank purely by the estimator's score.
  return TopItems.getTopUsers(n, candidates, null, scorer);
}
//这里也实现了一个评估器Estimator类:
/**
 * Scores candidate users by their similarity to a fixed target user ({@code theUserID}).
 *
 * <p>This excerpt omits the class's fields ({@code theUserID}, {@code userSimilarityImpl},
 * {@code minSim}) and its constructor.
 */
private static final class Estimator implements TopItems.Estimator<Long> {

  /**
   * Returns the similarity between the target user and {@code userID}.
   *
   * @return the similarity, or {@code Double.NaN} when {@code userID} is the target user itself
   *     or when the similarity is not at least {@code minSim} (NaN included)
   */
  @Override
  public double estimate(Long userID) throws TasteException {
    boolean isTargetUser = userID == theUserID;
    if (isTargetUser) {
      // A user is never its own neighbor.
      return Double.NaN;
    }
    double similarity = userSimilarityImpl.userSimilarity(theUserID, userID);
    if (similarity >= minSim) {
      return similarity;
    }
    // Below the threshold (or NaN): reject the candidate.
    return Double.NaN;
  }
}
//最后,获取邻居最重要的实现部分,即上面调用的:
//TopItems.getTopUsers(n, userIDs, null, estimator)
/**
 * Returns the IDs of at most {@code howMany} users with the highest (rescored) similarity scores.
 *
 * <p>Candidates from {@code allUserIDs} are skipped when {@code rescorer} filters them, when
 * {@code estimator} throws {@link NoSuchUserException} for them, or when their (rescored) score is
 * NaN. The rest go into a priority queue ordered by {@code Collections.reverseOrder()}. Once the
 * queue has held more than {@code howMany} entries, every later insertion is followed by a
 * {@code poll()}, so at most {@code howMany} entries remain. NOTE(review): the variable names
 * imply the queue head is the least-similar retained user. That depends on {@code SimilarUser}'s
 * natural ordering, which is not shown here — confirm.
 *
 * @param howMany maximum number of IDs to return; {@code 0} yields {@code NO_IDS}
 * @param allUserIDs candidate user IDs to score
 * @param rescorer optional filter/rescorer; {@code null} means the raw estimate is used
 * @param estimator computes each candidate's similarity score
 * @return the retained IDs ordered by {@code SimilarUser}'s natural ordering, or {@code NO_IDS}
 *     if no candidate qualifies
 * @throws IllegalArgumentException if {@code howMany} is negative
 * @throws TasteException propagated from the estimator (NoSuchUserException is caught and the
 *     candidate skipped)
 */
public static long[] getTopUsers(int howMany,
                                 LongPrimitiveIterator allUserIDs,
                                 IDRescorer rescorer,
                                 Estimator<Long> estimator) throws TasteException {
  if (howMany < 0) {
    // Previously this surfaced as an IllegalArgumentException from the PriorityQueue
    // constructor (capacity < 1); same exception type, clearer message.
    throw new IllegalArgumentException("howMany must be >= 0, got " + howMany);
  }
  if (howMany == 0) {
    // Nothing can be kept. Without this guard, the first accepted user would be added and
    // immediately polled, leaving the queue empty so that topUsers.peek() returned null and
    // getSimilarity() threw a NullPointerException.
    return NO_IDS;
  }
  // Capacity howMany + 1: an insertion briefly exceeds howMany before the poll that follows it.
  // NOTE(review): howMany == Integer.MAX_VALUE overflows this capacity hint.
  Queue<SimilarUser> topUsers = new PriorityQueue<SimilarUser>(howMany + 1, Collections.reverseOrder());
  boolean full = false;
  // Similarity of the queue head; only consulted once the queue is full.
  double lowestTopValue = Double.NEGATIVE_INFINITY;
  while (allUserIDs.hasNext()) {
    long userID = allUserIDs.next();
    if (rescorer != null && rescorer.isFiltered(userID)) {
      continue;
    }
    double similarity;
    try {
      // Similarity of this candidate according to the estimator.
      similarity = estimator.estimate(userID);
    } catch (NoSuchUserException nsue) {
      // Unknown user: skip it rather than failing the whole search.
      continue;
    }
    double rescoredSimilarity = rescorer == null ? similarity : rescorer.rescore(userID, similarity);
    if (!Double.isNaN(rescoredSimilarity) && (!full || rescoredSimilarity > lowestTopValue)) {
      topUsers.add(new SimilarUser(userID, rescoredSimilarity));
      if (full) {
        topUsers.poll();
      } else if (topUsers.size() > howMany) {
        full = true;
        topUsers.poll();
      }
      // Safe: howMany >= 1, so at least one entry remains after any poll above.
      lowestTopValue = topUsers.peek().getSimilarity();
    }
  }
  int size = topUsers.size();
  if (size == 0) {
    return NO_IDS;
  }
  List<SimilarUser> sorted = Lists.newArrayListWithCapacity(size);
  sorted.addAll(topUsers);
  Collections.sort(sorted);
  long[] result = new long[size];
  int i = 0;
  for (SimilarUser similarUser : sorted) {
    result[i++] = similarUser.getUserID();
  }
  return result;
}