/* Start with currentIteration = 0 */
for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
/* broadcast M, read A row-wise, recompute U row-wise */
log.info("Recomputing U (iteration {}/{})", currentIteration, numIterations);
/* 通过M--1计算出U-0
* 输出为[user1, {feature0:0.8817690514198447,feature1:-0.21707282987696083,...,feature19:0.23423786158394766}, ...]
*/
runSolver(pathToUserRatings(), pathToU(currentIteration), pathToM(currentIteration - 1), currentIteration, "U",
numItems);
/* broadcast U, read A' row-wise, recompute M row-wise */
log.info("Recomputing M (iteration {}/{})", currentIteration, numIterations);
/* 通过U-0计算出M-0*/
runSolver(pathToItemRatings(), pathToM(currentIteration), pathToU(currentIteration), currentIteration, "M",
numUsers);
}
/* Example: computing U-0 from M--1. The walkthrough below is based on this case; the other cases follow by analogy. */
/**
 * Configures and runs one solver job that recomputes a single factor matrix (U or M) from the
 * ratings and the other, fixed factor matrix.
 *
 * <p>Every part file found under {@code pathToUorM} is added to the distributed cache so that the
 * mappers can load the fixed matrix. The method blocks until the job finishes.
 *
 * @param ratings          input path of the ratings read by the job
 * @param output           output path for the recomputed matrix
 * @param pathToUorM       path of the fixed matrix whose part files are put into the distributed cache
 * @param currentIteration zero-based iteration index; shown one-based in the job name
 * @param matrixName       label of the matrix being recomputed (the caller passes "U" or "M"), used in the job name
 * @param numEntities      value stored under {@code NUM_ENTITIES} in the job configuration
 * @throws IllegalStateException if the job does not complete successfully
 */
private void runSolver(Path ratings, Path output, Path pathToUorM, int currentIteration, String matrixName,
    int numEntities) throws ClassNotFoundException, IOException, InterruptedException {
  // necessary for local execution in the same JVM only
  SharingMapper.reset();

  // currentIteration is zero-based; the job name reports a one-based counter (1..numIterations).
  int iterationNumber = currentIteration + 1;

  Class<? extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable>> solverMapperClassInternal;
  String feedbackKind;
  if (implicitFeedback) {
    solverMapperClassInternal = SolveImplicitFeedbackMapper.class;
    feedbackKind = "implicit feedback";
  } else {
    solverMapperClassInternal = SolveExplicitFeedbackMapper.class;
    feedbackKind = "explicit feedback";
  }
  // Previously the name printed (iterationNumber + 1), i.e. 2/N .. (N+1)/N; the counter is already one-based.
  String name = "Recompute " + matrixName + ", iteration (" + iterationNumber + '/' + numIterations + "), "
      + '(' + numThreadsPerSolver + " threads, " + numFeatures + " features, " + feedbackKind + ")";

  // MultithreadedSharingMapper is the outer mapper; the actual solver mapper class and its thread
  // count are registered on the job further below.
  Job solverForUorI = prepareJob(ratings, output, SequenceFileInputFormat.class, MultithreadedSharingMapper.class,
      IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, name);

  Configuration solverConf = solverForUorI.getConfiguration();
  solverConf.set(LAMBDA, String.valueOf(lambda));
  solverConf.set(ALPHA, String.valueOf(alpha));
  solverConf.setInt(NUM_FEATURES, numFeatures);
  solverConf.set(NUM_ENTITIES, String.valueOf(numEntities));

  // Ship every part file of the fixed matrix to the mappers via the distributed cache,
  // e.g. mahout-workdir-factorize-movielens/als/tmp/M--1/part-m-00000 in the annotated run.
  FileSystem fs = FileSystem.get(pathToUorM.toUri(), solverConf);
  FileStatus[] parts = fs.listStatus(pathToUorM, PathFilters.partFilter());
  for (FileStatus part : parts) {
    // Parameterized logging defers formatting, so no explicit isDebugEnabled() guard is needed.
    log.debug("Adding {} to distributed cache", part.getPath());
    DistributedCache.addCacheFile(part.getPath().toUri(), solverConf);
  }

  MultithreadedMapper.setMapperClass(solverForUorI, solverMapperClassInternal);
  MultithreadedMapper.setNumberOfThreads(solverForUorI, numThreadsPerSolver);

  boolean succeeded = solverForUorI.waitForCompletion(true);
  if (!succeeded) {
    throw new IllegalStateException("Job failed!");
  }
}
/* This is the map method of the SolveExplicitFeedbackMapper configured above.
 * userOrItemID is a user id in this example.
 * ratingsWritable looks like {item1:5.0,item48:5.0,item150:5.0,item260:4.0,item527:5.0,...}
 */
/**
 * Solves for the feature vector of one user (or item) and emits it under the same id.
 *
 * @param userOrItemID    id of the row being recomputed (a user id when U is recomputed)
 * @param ratingsWritable ratings of that row, e.g. {item1:5.0,item48:5.0,...}
 * @param ctx             context the (id, feature vector) pair is written to
 */
@Override
protected void map(IntWritable userOrItemID, VectorWritable ratingsWritable, Context ctx)
    throws IOException, InterruptedException {
  /*
   * The shared instance is the fixed factor matrix as an id -> feature-vector map; per the
   * original annotation it is loaded from the cache files, e.g. when recomputing U:
   * [item1->{feature0:4.158833063209069,feature1:0.8746388048206997,...,feature19:0.8728730662850301},
   *  item2->{feature0:3.2028985507246315,feature1:0.289635256964619,...,feature19:0.6068967014493216},
   *  ..., item3695->{...}]
   */
  OpenIntObjectHashMap fixedFactors = getSharedInstance();

  // Compute this row's feature vector (row i of U in the example) into the output holder, then emit it.
  uiOrmj.set(ALS.solveExplicit(ratingsWritable, fixedFactors, lambda, numFeatures));
  ctx.write(userOrItemID, uiOrmj);
}
/**
 * Looks up the feature vector of every entity that has a non-zero rating and delegates to
 * {@code AlternatingLeastSquaresSolver.solve} to compute the feature vector of the rating row.
 *
 * <p>Example from the annotated run: for ratings {item1:5.0,item48:5.0,...,item3408:4.0} the
 * collected feature vectors are
 * [item1->{0:4.158833063209069,...,19:0.8728730662850301},
 *  item48->{0:3.0174418604651128,...,19:0.9637219102684911},
 *  ...,
 *  item3408->{0:3.8755221386800365,...,19:0.11514498636620973}],
 * with lambda = 0.065 and numFeatures = 20.
 *
 * @param ratingsWritable ratings of one user (or item)
 * @param uOrM            map from entity index to that entity's feature vector
 * @param lambda          regularization parameter passed through to the solver
 * @param numFeatures     number of features passed through to the solver
 * @return the vector returned by the solver
 */
public static Vector solveExplicit(VectorWritable ratingsWritable, OpenIntObjectHashMap uOrM,
    double lambda, int numFeatures) {
  Vector observedRatings = ratingsWritable.get();

  // One slot per non-default rating; filled in the iteration order of nonZeroes().
  List rowsOfRatedEntities = Lists.newArrayListWithCapacity(observedRatings.getNumNondefaultElements());
  for (Vector.Element rating : observedRatings.nonZeroes()) {
    rowsOfRatedEntities.add(uOrM.get(rating.index()));
  }

  return AlternatingLeastSquaresSolver.solve(rowsOfRatedEntities, observedRatings, lambda, numFeatures);
}
/* This is the solve method of AlternatingLeastSquaresSolver. */
/**
 * Computes the feature vector ui for one rating row by building and solving Ai * ui = Vi.
 *
 * <p>Shapes and values quoted in the comments below come from the annotated example run
 * (one user with 48 ratings, 20 features); they are illustrative only.
 *
 * @param featureVectors feature vectors of the rated entities; must be non-null, non-empty and
 *                       have as many elements as the rating vector has non-default entries
 * @param ratingVector   ratings of the row being solved; must be non-null and non-empty
 * @param lambda         regularization parameter
 * @param numFeatures    number of features
 * @return column 0 of the solution of Ai * ui = Vi
 */
public static Vector solve(Iterable featureVectors, Vector ratingVector, double lambda, int numFeatures) {
  Preconditions.checkNotNull(featureVectors, "Feature vectors cannot be null");
  Preconditions.checkArgument(!Iterables.isEmpty(featureVectors));
  Preconditions.checkNotNull(ratingVector, "rating vector cannot be null");

  // Number of ratings in this row (nui = 48 in the example); read once and reused by the checks below.
  int nui = ratingVector.getNumNondefaultElements();
  Preconditions.checkArgument(nui > 0, "Rating vector cannot be empty");
  Preconditions.checkArgument(Iterables.size(featureVectors) == nui);

  /*
   * Feature matrix of the rated entities, features x ratings (20x48 in the example):
   *   feature0  => {item0:4.158833063209069,item1:3.0174418604651128,...,item47:3.8755221386800365}
   *   feature1  => {item0:0.8746388048206997,item1:0.05897693253591574,...,item47:0.9981447681344258}
   *   ...
   *   feature19 => {item0:0.8728730662850301,item1:0.9637219102684911,...,item47:0.11514498636620973}
   */
  Matrix ratedFeatures = createMiIi(featureVectors, numFeatures);

  /*
   * Actual ratings as a single-column matrix, ratings x 1 (48x1 in the example):
   *   item0 => {5.0}, item1 => {5.0}, ..., item47 => {4.0}
   */
  Matrix observedRatings = createRiIiMaybeTransposed(ratingVector);

  /*
   * compute Ai = MiIi * t(MiIi) + lambda * nui * E
   * features x features (20x20 in the example), e.g.
   *   0  => {0:735.9857104327824, 1:102.81718116978466,...,18:95.69797654994501, 19:89.241608292493}
   *   1  => {0:102.81718116978466,1:21.631573826528825,...,18:12.748809494518715,19:11.550196616709504}
   *   ...
   *   19 => {0:89.241608292493,   1:11.550196616709504,...,18:11.978329842950425,19:19.040173127545934}
   */
  Matrix normalMatrix = miTimesMiTransposePlusLambdaTimesNuiTimesE(ratedFeatures, lambda, nui);

  /*
   * compute Vi = MiIi * t(R(i,Ii))
   * features x 1 (20x1 in the example):
   *   0 => {778.765601112045}, 1 => {106.72208299946884}, ..., 19 => {97.00844504677671}
   */
  Matrix rightHandSide = ratedFeatures.times(observedRatings);

  /*
   * compute Ai * ui = Vi
   * Example result: {0:0.8817690514198447,1:-0.21707282987696083,...,19:0.23423786158394766}
   */
  return solve(normalMatrix, rightHandSide);
}
private static Vector solve(Matrix Ai, Matrix Vi) {
return new QRDecomposition(Ai).solve(Vi).viewColumn(0);
} |