Spark MLlib之协同过滤实例:
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import scala.Tuple2;
public class SparkMLlibColbFilter {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example");
JavaSparkContext sc = new JavaSparkContext(conf);
// Load and parse the data
String path = "file:///data/hadoop/spark-2.0.0-bin-hadoop2.7/data/mllib/als/test.data";
JavaRDD data = sc.textFile(path);
JavaRDD ratings = data.map(new Function() {
@Override
public Rating call(String s) throws Exception {
String[] sarray = s.split(",");
return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), Double.parseDouble(sarray[2]));
}
});
// Build the recommendation model using ALS
int rank = 10;
int numIterations = 10;
MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01);
JavaRDD> userProducts = ratings.map(new Function>() {
@Override
public Tuple2 call(Rating r) throws Exception {
return new Tuple2(r.user(), r.product());
}
});
JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD(
model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(
new Function, Double>>() {
@Override
public Tuple2, Double> call(
Rating r) throws Exception {
return new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating());
}
}));
JavaRDD> ratesAndPreds = JavaPairRDD.fromJavaRDD(ratings.map(
new Function, Double>>() {
@Override
public Tuple2, Double> call(
Rating r) throws Exception {
return new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating());
}
})).join(predictions).values();
double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map(new Function, Object>() {
@Override
public Object call(Tuple2 pair) throws Exception {
return Math.pow((pair._1() - pair._2()),2);
}
}).rdd()).mean();
System.out.println("Mean Squared Error = " + MSE);
// Save and load model
model.save(sc.sc(), "target/tmp/myCollaborativeFilter");
MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(),
"target/tmp/myCollaborativeFilter");
//为每个用户进行推荐,推荐的结果可以以用户id为key,结果为value存入redis或者hbase中
List users = data.map(new Function() {
@Override
public String call(String s) throws Exception {
String[] sarray = s.split(",");
return sarray[0];
}
}).distinct().collect();
for (String user : users) {
Rating[] rs = model.recommendProducts(Integer.parseInt(user), numIterations);
String value = "";
int key = 0;
for (Rating r : rs) {
key = r.user();
value = value + r.product() + ":" + r.rating() + "," ;
}
System.out.println(key + " " + value);
}
}
}
协同过滤ALS算法推荐过程如下:
加载数据到 ratings RDD,每行记录包括:user, product, rate
从 ratings 得到用户商品的数据集:(user, product)
使用ALS对 ratings 进行训练
通过 model 对用户商品进行预测评分:((user, product), rate)
从 ratings 得到用户商品的实际评分:((user, product), rate)
合并预测评分和实际评分的两个数据集,并求均方误差(MSE)