import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import java.io.Serializable;
/**
 * Collaborative-filtering movie recommendation demo using Spark ML's ALS
 * (alternating least squares). Reads a CSV ratings file, trains an ALS model
 * on an 80/20 random split, reports RMSE on the held-out split, and writes
 * top-10 recommendations per user and per movie.
 *
 * Usage: RecommendMovie [inputPath] [userRecsOutputPath] [itemRecsOutputPath]
 * (all arguments optional; defaults preserve the original hard-coded paths).
 */
public class RecommendMovie {
    /**
     * One rating record (user, movie, rating, timestamp).
     * Implements {@link Serializable} so Spark can ship instances between
     * executors; the JavaBean getters let Spark infer a DataFrame schema
     * in {@code createDataFrame(rdd, Rating.class)}.
     */
    public static class Rating implements Serializable {
        private int userid;
        private int movieid;
        private float rating;
        private long timestamp;

        /** No-arg constructor required by Spark's JavaBean encoder. */
        public Rating() {
        }

        public Rating(int userid, int movieid, float rating, long timestamp) {
            this.userid = userid;
            this.movieid = movieid;
            this.rating = rating;
            this.timestamp = timestamp;
        }

        public int getUserid() {
            return userid;
        }

        public int getMovieid() {
            return movieid;
        }

        public float getRating() {
            return rating;
        }

        public long getTimestamp() {
            return timestamp;
        }

        /**
         * Parses one CSV line of the form {@code "userid,movieid,rating,timestamp"}
         * into a {@code Rating}. Surrounding whitespace in each field is tolerated.
         *
         * @param str one input line
         * @return the parsed rating
         * @throws IllegalArgumentException if the line does not have exactly 4 fields
         * @throws NumberFormatException    if a field is not numeric
         */
        public static Rating parseRating(String str) {
            String[] fields = str.split(",");
            if (fields.length != 4) {
                throw new IllegalArgumentException("Each line must contain 4 fields");
            }
            int userid = Integer.parseInt(fields[0].trim());
            int movieid = Integer.parseInt(fields[1].trim());
            float rating = Float.parseFloat(fields[2].trim());
            long timestamp = Long.parseLong(fields[3].trim());
            return new Rating(userid, movieid, rating, timestamp);
        }
    }

    public static void main(String[] args) {
        // Paths are overridable from the command line; defaults keep the
        // original local test locations.
        String inputPath = args.length > 0 ? args[0] : "C:\\Users\\13373\\Desktop\\test.data";
        String userRecsOut = args.length > 1 ? args[1] : "C:\\Users\\13373\\Desktop\\userCF.txt";
        String itemRecsOut = args.length > 2 ? args[2] : "C:\\Users\\13373\\Desktop\\itemCF.txt";

        SparkSession spark = SparkSession.builder()
                .master("local[*]")
                .appName("RecommendMovie")
                .getOrCreate();

        // Read the raw text file, parse each line into a Rating, and convert
        // to a DataFrame so the ML (DataFrame-based) ALS API can consume it.
        JavaRDD<Rating> ratingsRDD = spark.read()
                .textFile(inputPath)
                .javaRDD()
                .map(Rating::parseRating);
        Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);

        // Random 80/20 split into training and test data.
        Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
        Dataset<Row> training = splits[0];
        Dataset<Row> test = splits[1];

        // Build the ALS estimator. Note: setRegParam is the regularization
        // strength (lambda), not a "minimum squared error" as the original
        // comment claimed.
        ALS als = new ALS()
                .setMaxIter(5)        // maximum number of ALS iterations
                .setRegParam(0.01)    // regularization parameter (lambda)
                .setUserCol("userid")
                .setItemCol("movieid")
                .setRatingCol("rating");
        ALSModel model = als.fit(training);

        // "drop" discards test rows whose user/movie was unseen during
        // training, so the RMSE below does not become NaN.
        model.setColdStartStrategy("drop");
        Dataset<Row> predictions = model.transform(test);

        // Evaluate prediction quality with root-mean-square error.
        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setMetricName("rmse")
                .setLabelCol("rating")
                .setPredictionCol("prediction");
        double rmse = evaluator.evaluate(predictions);
        System.out.println("Root-mean-square error = " + rmse);

        // Top-10 movie recommendations for every user. (Bug fix: the original
        // saved these under the item-recommendations path "itemCF.txt".)
        Dataset<Row> userRecs = model.recommendForAllUsers(10);
        userRecs.toJavaRDD().coalesce(1).saveAsTextFile(userRecsOut);

        // Top-10 user recommendations for every movie. (Bug fix: the original
        // computed this dataset but never saved it anywhere.)
        Dataset<Row> itemRecs = model.recommendForAllItems(10);
        itemRecs.toJavaRDD().coalesce(1).saveAsTextFile(itemRecsOut);

        spark.stop();
    }
}
Maven POM file (pom.xml) used to build the project:
<?xml version="1.0" encoding="UTF-8"?> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> <groupId>Aiads</groupId> <artifactId>morgan13</artifactId> <version>1.0-SNAPSHOT</version> <properties> <java.version>1.8</java.version> <junit.version>4.12</junit.version> <mysql.driver.version>5.1.38</mysql.driver.version> <slf4j.version>1.7.21</slf4j.version> <fastjson.version>1.2.11</fastjson.version> <scala.version>2.11.11</scala.version> <spark.version>2.2.0</spark.version> </properties> <dependencies> <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-mllib_2.11</artifactId> <version>2.2.0</version> <!--<scope>runtime</scope>--> </dependency> <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_2.11</artifactId> <version>2.2.0</version> </dependency> <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-sql_2.11</artifactId> <version>2.2.0</version> </dependency> <!-- https://mvnrepository.com/artifact/org.scala-lang/scala-library --> <dependency> <groupId>org.scala-lang</groupId> <artifactId>scala-library</artifactId> <version>2.11.11</version> </dependency> </dependencies> <!--maven中pom文件的java设置--> <build> <plugins> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> <version>3.7.0</version> <configuration> <source>1.8</source> <target>1.8</target> </configuration> </plugin> </plugins> </build> <profiles> <profile> <id>aiads</id> <properties> <maven.compiler.source>1.8</maven.compiler.source> <maven.compiler.target>1.8</maven.compiler.target> 
<maven.compiler.compilerVersion>1.8</maven.compiler.compilerVersion> <!-- <sonar.host.url>http://sonar.aiads.com</sonar.host.url> <sonar.login>5e7a06adc9654b9ee9c4a114ed8b73e2f2da6489</sonar.login> --> </properties> <repositories> <repository> <id>nexus</id> <name>local private nexus</name> <url>http://nexus.aiads.com/repository/maven-public</url> <releases> <enabled>true</enabled> </releases> <snapshots> <enabled>true</enabled> </snapshots> </repository> </repositories> <pluginRepositories> <pluginRepository> <id>nexus</id> <name>local private nexus</name> <url>http://nexus.aiads.com/repository/maven-public</url> <releases> <enabled>true</enabled> </releases> <snapshots> <enabled>true</enabled> </snapshots> </pluginRepository> </pluginRepositories> </profile> </profiles> </project>