目录
2.4.1 com.qf.bigata.transformer.ItemBaseFeatureModelData
2.4.2 com/qf/bigata/transformer/ItemCFModelData.scala
2.4.3 com/qf/bigata/transformer/LRModelData.scala
2.4.4 com/qf/bigata/transformer/ModelData.scala
2.4.5 com/qf/bigata/transformer/UnionFeatureModelData.scala
2.4.6 com/qf/bigata/transformer/UserBaseFeatureModelData.scala
2.5.1 com.qf.bigata.udfs.FeatureUDF
2.6.2 com.qf.bigata.utils.HBaseUtils
2.7.2 com/qf/bigata/AlsModelData.scala
2.7.3 com.qf.bigata.ArticleEmbedding
2.7.4 com.qf.bigata.transformer.ArticleEmbeddingModelData
2.7.5 com.qf.bigata.ItemBaseFeature
2.7.9 com.qf.bigata.UserBaseFeature
2.8.2 org.jpmml.sparkml.feature.StringVectorConverter
3.2.1 com.qf.bigdata.dao.impl.HBaseDaoImpl
3.2.2 com.qf.bigdata.dao.impl.MilvusDaoImpl
3.2.3 com.qf.bigdata.dao.impl.PrestoDaoImpl
3.2.4 com.qf.bigdata.dao.DataSourceConfig
3.2.5 com.qf.bigdata.dao.HBaseConfig
3.2.6 com.qf.bigdata.dao.HBaseDao
3.2.7 com.qf.bigdata.dao.MilvusConfig
3.2.8 com.qf.bigdata.dao.MilvusDao
3.2.9 com.qf.bigdata.dao.PrestoDao
3.3.1 com.qf.bigdata.pojo.DauPredictInfo
3.3.2 com.qf.bigdata.pojo.HBaseProperties
3.3.3 com.qf.bigdata.pojo.MilvusProperties
3.3.5 com.qf.bigdata.pojo.RecommendInfo
3.3.6 com.qf.bigdata.pojo.RecommendResult
3.3.7 com.qf.bigdata.pojo.RetentionCurvelInfo
3.3.8 com.qf.bigdata.pojo.Sample
3.3.9 com.qf.bigdata.pojo.UserEmbeddingInfo
3.3.10 com.qf.bigdata.pojo.UserEmbeddingResult
3.4.1 com.qf.bigdata.service.impl.RecommendServiceImpl
3.4.2 com.qf.bigdata.service.impl.RetentionServiceImpl
3.4.3 com.qf.bigdata.service.impl.UserEmbeddingServiceImpl
3.4.4 com.qf.bigdata.service.RecommendService
3.4.5 com.qf.bigdata.service.RetentionService
3.4.6 com.qf.bigdata.service.UserEmbeddingService
3.5.1 com.qf.bigdata.utils.HBaseUtils
3.5.2 com.qf.bigdata.utils.Leastsq
3.5.3 com.qf.bigdata.utils.MilvusUtils
3.5.5 com.qf.bigdata.utils.TimeUtils
3.6.1 com.qf.bigdata.web.controller.DauController
3.6.2 com.qf.bigdata.web.controller.RecommendController
3.6.3 com.qf.bigdata.web.controller.UserEmbeddingController
3.6.4 com.qf.bigdata.Application
3.6.6 com.qf.bigdata.TomcatConfig
其实这个项目字数太多了,博客都已经上升到十三万字数。类的话也有几十个类。
背景指路
项目四:使用SparkSQL开发的简易推荐系统_林柚晞的博客-CSDN博客_spark推荐系统开发案例
我摊牌了我只想躺平去多刷题了。现在我就把之前的做推荐系统的代码发一下以供参考
这里搞了两个召回策略,我不太熟悉ALS。
0 pom.xml
<?xml version="1.0" encoding="UTF-8"?> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> <groupId>com.qf.bigdata</groupId> <artifactId>recommend-test</artifactId> <version>1.0-SNAPSHOT</version> <properties> <scala.version>2.11.12</scala.version> <play-json.version>2.3.9</play-json.version> <maven-scala-plugin.version>2.10.1</maven-scala-plugin.version> <scala-maven-plugin.version>3.2.0</scala-maven-plugin.version> <maven-assembly-plugin.version>2.6</maven-assembly-plugin.version> <spark.version>2.4.5</spark.version> <scope.type>compile</scope.type> <json.version>1.2.3</json.version> <hbase.version>1.3.6</hbase.version> <hadoop.version>2.8.1</hadoop.version> <!--compile provided--> </properties> <dependencies> <!--json 包--> <dependency> <groupId>com.alibaba</groupId> <artifactId>fastjson</artifactId> <version>${json.version}</version> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_2.11</artifactId> <version>${spark.version}</version> <scope>${scope.type}</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-sql_2.11</artifactId> <version>${spark.version}</version> <scope>${scope.type}</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-hive_2.11</artifactId> <version>${spark.version}</version> <scope>${scope.type}</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-mllib_2.11</artifactId> <version>${spark.version}</version> <scope>${scope.type}</scope> </dependency> <dependency> <groupId>mysql</groupId> <artifactId>mysql-connector-java</artifactId> <version>5.1.28</version> </dependency> <dependency> <groupId>log4j</groupId> <artifactId>log4j</artifactId> <version>1.2.17</version> <scope>${scope.type}</scope> </dependency> 
<dependency> <groupId>commons-codec</groupId> <artifactId>commons-codec</artifactId> <version>1.6</version> </dependency> <dependency> <groupId>org.scala-lang</groupId> <artifactId>scala-library</artifactId> <version>${scala.version}</version> <scope>${scope.type}</scope> </dependency> <dependency> <groupId>org.scala-lang</groupId> <artifactId>scala-reflect</artifactId> <version>${scala.version}</version> <scope>${scope.type}</scope> </dependency> <dependency> <groupId>com.github.scopt</groupId> <artifactId>scopt_2.11</artifactId> <version>4.0.0-RC2</version> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-avro_2.11</artifactId> <version>${spark.version}</version> </dependency> <dependency> <groupId>org.apache.hive</groupId> <artifactId>hive-jdbc</artifactId> <version>2.3.7</version> <scope>${scope.type}</scope> <exclusions> <exclusion> <groupId>javax.mail</groupId> <artifactId>mail</artifactId> </exclusion> <exclusion> <groupId>org.eclipse.jetty.aggregate</groupId> <artifactId>*</artifactId> </exclusion> </exclusions> </dependency> <dependency> <groupId>org.apache.hadoop</groupId> <artifactId>hadoop-client</artifactId> <version>${hadoop.version}</version> <scope>${scope.type}</scope> </dependency> <dependency> <groupId>org.apache.hbase</groupId> <artifactId>hbase-server</artifactId> <version>${hbase.version}</version> <scope>${scope.type}</scope> </dependency> <dependency> <groupId>org.apache.hbase</groupId> <artifactId>hbase-client</artifactId> <version>${hbase.version}</version> <scope>${scope.type}</scope> </dependency> <dependency> <groupId>org.apache.hbase</groupId> <artifactId>hbase-hadoop2-compat</artifactId> <version>${hbase.version}</version> <scope>${scope.type}</scope> </dependency> <dependency> <groupId>org.jpmml</groupId> <artifactId>jpmml-sparkml</artifactId> <version>1.5.9</version> </dependency> </dependencies> <repositories> <repository> <id>alimaven</id> 
<url>http://maven.aliyun.com/nexus/content/groups/public/</url> <releases> <updatePolicy>never</updatePolicy> </releases> <snapshots> <updatePolicy>never</updatePolicy> </snapshots> </repository> </repositories> <build> <sourceDirectory>src/main/scala</sourceDirectory> <testSourceDirectory>src/test/</testSourceDirectory> <plugins> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-shade-plugin</artifactId> <version>3.2.4</version> <executions> <execution> <phase>package</phase> <goals> <goal>shade</goal> </goals> <configuration> <shadedArtifactAttached>true</shadedArtifactAttached> <shadedClassifierName>jar-with-dependencies</shadedClassifierName> <filters> <filter> <artifact>org.jpmml:jpmml-sparkml</artifact> <excludes> <exclude>META-INF/sparkml2pmml.properties</exclude> </excludes> </filter> <filter> <artifact>*:*</artifact> <excludes> <exclude>META-INF/*.SF</exclude> <exclude>META-INF/*.DSA</exclude> <exclude>META-INF/*.RSA</exclude> </excludes> </filter> </filters> </configuration> </execution> </executions> </plugin> <!--<plugin>--> <!--<groupId>org.apache.maven.plugins</groupId>--> <!--<artifactId>maven-assembly-plugin</artifactId>--> <!--<version>${maven-assembly-plugin.version}</version>--> <!--<configuration>--> <!--<descriptorRefs>--> <!--<descriptorRef>jar-with-dependencies</descriptorRef>--> <!--</descriptorRefs>--> <!--<filters>--> <!--<filter>--> <!--</filter>--> <!--</filters>--> <!--</configuration>--> <!--<executions>--> <!--<execution>--> <!--<id>make-assembly</id>--> <!--<phase>package</phase>--> <!--<goals>--> <!--<goal>single</goal>--> <!--</goals>--> <!--</execution>--> <!--</executions>--> <!--</plugin>--> <plugin> <groupId>net.alchim31.maven</groupId> <artifactId>scala-maven-plugin</artifactId> <version>${scala-maven-plugin.version}</version> <executions> <!-- 先编译scala,防止 cannot find symbol --> <execution> <id>scala-compile-first</id> <phase>process-resources</phase> <goals> <goal>add-source</goal> <goal>compile</goal> 
</goals> </execution> <execution> <goals> <goal>compile</goal> <goal>testCompile</goal> </goals> <configuration> <args> <arg>-dependencyfile</arg> <arg>${project.build.directory}/.scala_dependencies</arg> </args> </configuration> </execution> </executions> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-archetype-plugin</artifactId> <version>2.2</version> </plugin> <plugin> <groupId>org.codehaus.mojo</groupId> <artifactId>build-helper-maven-plugin</artifactId> <version>1.8</version> <executions> <!-- Add src/main/scala to eclipse build path --> <execution> <id>add-source</id> <phase>generate-sources</phase> <goals> <goal>add-source</goal> </goals> <configuration> <sources> <source>src/main/java</source> </sources> </configuration> </execution> <!-- Add src/test/scala to eclipse build path --> <execution> <id>add-test-source</id> <phase>generate-test-sources</phase> <goals> <goal>add-test-source</goal> </goals> <configuration> <sources> <source>src/test/java</source> </sources> </configuration> </execution> </executions> </plugin> </plugins> </build> </project>
大概的项目框架
架构长这样
1.0 资源
1.1 sparkml2pmml.properties
# Features org.apache.spark.ml.feature.Binarizer = org.jpmml.sparkml.feature.BinarizerConverter org.apache.spark.ml.feature.Bucketizer = org.jpmml.sparkml.feature.BucketizerConverter org.apache.spark.ml.feature.ChiSqSelectorModel = org.jpmml.sparkml.feature.ChiSqSelectorModelConverter org.apache.spark.ml.feature.ColumnPruner = org.jpmml.sparkml.feature.ColumnPrunerConverter org.apache.spark.ml.feature.CountVectorizerModel = org.jpmml.sparkml.feature.CountVectorizerModelConverter org.apache.spark.ml.feature.IDFModel = org.jpmml.sparkml.feature.IDFModelConverter org.apache.spark.ml.feature.ImputerModel = org.jpmml.sparkml.feature.ImputerModelConverter org.apache.spark.ml.feature.IndexToString = org.jpmml.sparkml.feature.IndexToStringConverter org.apache.spark.ml.feature.Interaction = org.jpmml.sparkml.feature.InteractionConverter org.apache.spark.ml.feature.MaxAbsScalerModel = org.jpmml.sparkml.feature.MaxAbsScalerModelConverter org.apache.spark.ml.feature.MinMaxScalerModel = org.jpmml.sparkml.feature.MinMaxScalerModelConverter org.apache.spark.ml.feature.NGram = org.jpmml.sparkml.feature.NGramConverter org.apache.spark.ml.feature.OneHotEncoderModel = org.jpmml.sparkml.feature.OneHotEncoderModelConverter org.apache.spark.ml.feature.PCAModel = org.jpmml.sparkml.feature.PCAModelConverter org.apache.spark.ml.feature.RegexTokenizer = org.jpmml.sparkml.feature.RegexTokenizerConverter org.apache.spark.ml.feature.RFormulaModel = org.jpmml.sparkml.feature.RFormulaModelConverter org.apache.spark.ml.feature.SQLTransformer = org.jpmml.sparkml.feature.SQLTransformerConverter org.apache.spark.ml.feature.StandardScalerModel = org.jpmml.sparkml.feature.StandardScalerModelConverter org.apache.spark.ml.feature.StringIndexerModel = org.jpmml.sparkml.feature.StringIndexerModelConverter org.apache.spark.ml.feature.StopWordsRemover = org.jpmml.sparkml.feature.StopWordsRemoverConverter org.apache.spark.ml.feature.Tokenizer = org.jpmml.sparkml.feature.TokenizerConverter 
org.apache.spark.ml.feature.VectorAssembler = org.jpmml.sparkml.feature.VectorAssemblerConverter org.apache.spark.ml.feature.VectorAttributeRewriter = org.jpmml.sparkml.feature.VectorAttributeRewriterConverter org.apache.spark.ml.feature.VectorIndexerModel = org.jpmml.sparkml.feature.VectorIndexerModelConverter org.apache.spark.ml.feature.VectorSizeHint = org.jpmml.sparkml.feature.VectorSizeHintConverter org.apache.spark.ml.feature.VectorSlicer = org.jpmml.sparkml.feature.VectorSlicerConverter org.apache.spark.ml.feature.StringVector = org.jpmml.sparkml.feature.StringVectorConverter # Prediction models org.apache.spark.ml.classification.DecisionTreeClassificationModel = org.jpmml.sparkml.model.DecisionTreeClassificationModelConverter org.apache.spark.ml.classification.GBTClassificationModel = org.jpmml.sparkml.model.GBTClassificationModelConverter org.apache.spark.ml.classification.LinearSVCModel = org.jpmml.sparkml.model.LinearSVCModelConverter org.apache.spark.ml.classification.LogisticRegressionModel = org.jpmml.sparkml.model.LogisticRegressionModelConverter org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel = org.jpmml.sparkml.model.MultilayerPerceptronClassificationModelConverter org.apache.spark.ml.classification.NaiveBayesModel = org.jpmml.sparkml.model.NaiveBayesModelConverter org.apache.spark.ml.classification.RandomForestClassificationModel = org.jpmml.sparkml.model.RandomForestClassificationModelConverter org.apache.spark.ml.clustering.KMeansModel = org.jpmml.sparkml.model.KMeansModelConverter org.apache.spark.ml.regression.DecisionTreeRegressionModel = org.jpmml.sparkml.model.DecisionTreeRegressionModelConverter org.apache.spark.ml.regression.GBTRegressionModel = org.jpmml.sparkml.model.GBTRegressionModelConverter org.apache.spark.ml.regression.GeneralizedLinearRegressionModel = org.jpmml.sparkml.model.GeneralizedLinearRegressionModelConverter org.apache.spark.ml.regression.LinearRegressionModel = 
org.jpmml.sparkml.model.LinearRegressionModelConverter org.apache.spark.ml.regression.RandomForestRegressionModel = org.jpmml.sparkml.model.RandomForestRegressionModelConverter
1.2 core-site.xml
<?xml version="1.0" encoding="UTF-8"?> <?xml-stylesheet type="text/xsl" href="configuration.xsl"?> <!-- Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. See accompanying LICENSE file. --> <!-- Put site-specific property overrides in this file. --> <configuration> <property> <!-- hdfs系统的唯一标识,scheme,ip,port ,内部守护进程的通信地址--> <name>fs.defaultFS</name> <value>hdfs://qianfeng01:8020</value> </property> <property> <name>hadoop.tmp.dir</name> <value>/usr/local/hadoop/tmp</value> </property> <property> <name>hadoop.proxyuser.root.hosts</name> <value>*</value> </property> <property> <name>hadoop.proxyuser.root.groups</name> <value>*</value> </property> </configuration>
1.3 hdfs-site.xml
<?xml version="1.0" encoding="UTF-8"?> <?xml-stylesheet type="text/xsl" href="configuration.xsl"?> <!-- Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. See accompanying LICENSE file. --> <!-- Put site-specific property overrides in this file. --> <configuration> <!-- namenode守护进程管理的元数据文件fsimage存储的位置--> <property> <name>dfs.namenode.name.dir</name> <value>file:///usr/local/hadoop/hdpdata/dfs/name</value> </property> <!-- 确定DFS数据节点应该将其块存储在本地文件系统的何处--> <property> <name>dfs.datanode.data.dir</name> <value>file:///usr/local/hadoop/hdpdata/dfs/data</value> </property> <!-- 块的副本数--> <property> <name>dfs.replication</name> <value>1</value> </property> <!-- 块的大小(128M),下面的单位是字节--> <property> <name>dfs.blocksize</name> <value>134217728</value> </property> <!-- secondarynamenode守护进程的http地址:主机名和端口号。参考守护进程布局--> <property> <name>dfs.namenode.secondary.http-address</name> <value>qianfeng01:50090</value> </property> <!-- namenode守护进程的http地址:主机名和端口号。参考守护进程布局--> <property> <name>dfs.namenode.http-address</name> <value>qianfeng01:50070</value> </property> <property> <name>dfs.namenode.name.dir</name> <value>file:///usr/local/hadoop/hdpdata/dfs/name</value> </property> <property> <name>dfs.namenode.checkpoint.dir</name> <value>file:///usr/local/hadoop/hdpdata/dfs/cname</value> </property> <property> <name>dfs.namenode.checkpoint.edits.dir</name> <value>file:///usr/local/hadoop/hdpdata/dfs/cname</value> </property> </configuration>
1.4 hive-site.xml
<?xml version="1.0" encoding="UTF-8" standalone="no"?> <?xml-stylesheet type="text/xsl" href="configuration.xsl"?><!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> <configuration> <property> <name>javax.jdo.option.ConnectionUserName</name> <value>root</value> </property> <property> <name>javax.jdo.option.ConnectionPassword</name> <value>@Mmforu45</value> </property> <property> <name>javax.jdo.option.ConnectionURL</name> <value>jdbc:mysql://qianfeng01:3306/hive?createDatabaseIfNotExist=true</value> </property> <property> <name>javax.jdo.option.ConnectionDriverName</name> <value>com.mysql.jdbc.Driver</value> </property> <property> <name>hive.exec.scratchdir</name> <value>/tmp/hive</value> </property> <property> <name>hive.metastore.warehouse.dir</name> <value>/user/hive/warehouse</value> </property> <property> <name>hive.querylog.location</name> <value>/usr/local/hive/iotmp/root</value> </property> <property> <name>hive.downloaded.resources.dir</name> <value>/usr/local/hive/iotmp/${hive.session.id}_resources</value> </property> <property> <name>hive.server2.thrift.port</name> <value>10000</value> </property> <property> <name>hive.server2.thrift.bind.host</name> <value>192.168.10.101</value> </property> <property> <name>hive.server2.logging.operation.log.location</name> 
<value>/usr/local/hive/iotmp/root/operation_logs</value> </property> <property> <name>hive.metastore.uris</name> <value>thrift://192.168.10.101:9083</value> </property> <property> <name>hive.cli.print.current.db</name> <value>true</value> </property> <property> <name>hive.exec.mode.local.auto</name> <value>true</value> </property> </configuration>
1.5 yarn-site.xml
<?xml version="1.0"?> <!-- Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. See accompanying LICENSE file. --> <configuration> <!-- Site specific YARN configuration properties --> <!-- 指定yarn的shuffle技术--> <property> <name>yarn.resourcemanager.hostname</name> <value>qianfeng01</value> </property> <property> <name>yarn.nodemanager.aux-services</name> <value>mapreduce_shuffle</value> </property> <property> <name>yarn.nodemanager.vmem-check-enabled</name> <value>false</value> </property> <property> <name>yarn.resourcemanager.scheduler.class</name> <value>org.apache.hadoop.yarn.server.resourcemanager.scheduler.fair.FairScheduler</value> </property> <property> <name>yarn.scheduler.fair.preemption</name> <value>true</value> </property> <property> <name>yarn.scheduler.fair.preemption.cluster-utilization-threshold</name> <value>1.0</value> </property> </configuration>
2 scala部分的架构
2.1 conf
package com.qf.bigata.conf

import org.slf4j.LoggerFactory

/**
 * Command-line configuration shared by all Spark jobs in this project.
 * Defaults target the dev environment; override per job via CLI options.
 */
case class Config(
                   env: String = "",
                   hBaseZK: String = "192.168.10.101",
                   hBasePort: String = "2181",
                   hFileTmpPath: String = "/tmp/hFile",
                   tableName: String = "",
                   irisPath: String = "",
                   proxyUser: String = "root",
                   topK: Int = 10
                 )

object Config {

  private val logger = LoggerFactory.getLogger(Config.getClass.getSimpleName)

  /**
   * Parse command-line arguments into a [[Config]].
   *
   * @param obj  the job object invoking the parser; its class name selects the
   *             job whose startup is logged
   * @param args raw command-line arguments
   * @return the parsed Config; on a parse failure the error is logged and the
   *         JVM exits with status -1
   */
  def parseConfig(obj: Object, args: Array[String]): Config = {
    //1. Program name: strip the trailing '$' that Scala appends to object class names
    val programName = obj.getClass.getSimpleName.replace("$", "")
    //2. Build the option parser (works like getopts)
    // NOTE(review): the `irisPath` field has no corresponding option here —
    // confirm whether an '-i'/"irisPath" option was intended.
    val parser = new scopt.OptionParser[Config](s"ItemCF ${programName}") {
      head(programName, "v1.0")
      opt[String]('e', "env").required().action((x, config) => config.copy(env = x)).text("dev or prod")
      opt[String]('x', "proxyUser").required().action((x, config) => config.copy(proxyUser = x)).text("proxy username")
      opt[String]('z', "hBaseZK").optional().action((x, config) => config.copy(hBaseZK = x)).text("hBaseZK")
      opt[String]('p', "hBasePort").optional().action((x, config) => config.copy(hBasePort = x)).text("hBasePort")
      opt[String]('f', "hFileTmpPath").optional().action((x, config) => config.copy(hFileTmpPath = x)).text("hFileTmpPath")
      opt[String]('t', "tableName").optional().action((x, config) => config.copy(tableName = x)).text("tableName")
      opt[Int]('k', "topK").optional().action((x, config) => config.copy(topK = x)).text("topK")
      // Log which known job is starting (fix: original messages said "staring")
      programName match {
        case "ItemCF" | "AlsCF" | "ItemBaseFeature" | "UserBaseFeature" |
             "ArticleEmbedding" | "LRClass" | "UnionFeature" =>
          logger.info(s"${programName} is starting ---------------------------->")
        case _ =>
      }
    }
    //2.2 Parse; on failure log and terminate the process
    parser.parse(args, Config()) match {
      case Some(conf) => conf
      case None =>
        logger.error(s"cannot parse args")
        System.exit(-1)
        null
    }
  }
}
2.2 Action
package com.qf.bigata.constant

/**
 * Enumeration of the five user actions that can be performed on an article.
 */
object Action extends Enumeration {
  type Action = Value

  // Declaration order fixes the underlying enumeration ids — do not reorder.
  val CLICK = Value("点击")
  val SHARE = Value("分享")
  val COMMENT = Value("评论")
  val COLLECT = Value("收藏")
  val LIKE = Value("点赞")

  /** Print every enumeration constant, one per line. */
  def showAll = for (v <- values) println(v)

  /** Look up an enumeration value by its display name, if any. */
  def withNameOpt(name: String): Option[Value] = values.find(v => v.toString == name)
}
2.3 Constant
package com.qf.bigata.constant

/**
 * Project-wide constants; add future shared constants here.
 */
object Constant {
  // A news article carries its full value for its first 100 days.
  // Past that point its time value decays, and the further past the
  // threshold, the faster the decay.
  val ARTICLE_AGING_TIME: Int = 100
}
2.4 transformer
2.4.1 com.qf.bigata.transformer.ItemBaseFeatureModelData
package com.qf.bigata.transformer

import org.apache.hadoop.hbase.KeyValue
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection.mutable.ListBuffer

/**
 * Model data for the item base-feature job: turns article feature rows into
 * HFile-ready cells for bulk loading into HBase.
 */
class ItemBaseFeatureModelData(spark: SparkSession, env: String) extends ModelData(spark, env) {

  /**
   * Convert the feature DataFrame into (ImmutableBytesWritable, KeyValue) pairs.
   *
   * HBase layout:
   *   row key  : article_id
   *   family   : f1
   *   qualifier: itemBaseFeatures
   *   value    : string form of the feature vector
   *              [word count, image count, type, days since publication]
   */
  def itemBaseFeatureDF2RDD(baseFeatureDF: DataFrame) = {
    baseFeatureDF.rdd.sortBy(row => row.get(0).toString).flatMap { row =>
      // raw values from the row: (article_id, feature string)
      val articleId = row.getString(0)
      val features = row.getString(1)
      // build the single cell for this article
      val cell = new KeyValue(
        Bytes.toBytes(articleId),
        Bytes.toBytes("f1"),
        Bytes.toBytes("itemBaseFeatures"),
        Bytes.toBytes(features))
      // NOTE(review): the ImmutableBytesWritable is emitted empty rather than
      // wrapping the row key — confirm the downstream HFile writer tolerates this.
      val cells = new ListBuffer[(ImmutableBytesWritable, KeyValue)]
      cells += ((new ImmutableBytesWritable(), cell))
      cells
    }
  }
}

object ItemBaseFeatureModelData {
  def apply(spark: SparkSession, env: String): ItemBaseFeatureModelData = new ItemBaseFeatureModelData(spark, env)
}
2.4.2 com/qf/bigata/transformer/ItemCFModelData.scala
package com.qf.bigata.transformer

import com.qf.bigata.utils.HBaseUtils
import org.apache.hadoop.hbase.KeyValue
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.slf4j.LoggerFactory

import scala.collection.mutable.ListBuffer

/**
 * Model data for the item-based collaborative filtering recall strategy.
 */
class ItemCFModelData(spark: SparkSession, env: String) extends ModelData(spark: SparkSession, env: String) {

  private val logger = LoggerFactory.getLogger(ItemCFModelData.getClass.getSimpleName)

  /**
   * Convert the recommendation result into HFile-ready (key, KeyValue) pairs.
   * The target HBase table must already exist.
   *
   * HBase layout:
   *   row key  : uid
   *   family   : f1
   *   qualifier: itemcf
   *   value    : "sim_aid:pre_rate,sim_aid:pre_rate,..."
   */
  def itemcf2RDD(convertDF: DataFrame) = {
    convertDF.rdd.sortBy(x => x.get(0).toString).flatMap(row => {
      //1. raw values
      val uid: String = row.get(0).toString
      /*
       * [(sim_aid, pre_rate), (sim_aid, pre_rate), ...]
       *   -> "sim_aid:pre_rate,sim_aid:pre_rate,..."
       */
      val items: String = row.getAs[Seq[Row]](1).map(item => {
        item.getInt(0).toString + ":" + item.getDouble(1).formatted("%.4f")
      }).mkString(",")
      //2. buffer for the emitted cells
      val listBuffer = new ListBuffer[(ImmutableBytesWritable, KeyValue)]
      //3. build the cell
      val kv = new KeyValue(Bytes.toBytes(uid), Bytes.toBytes("f1"), Bytes.toBytes("itemcf"), Bytes.toBytes(items))
      //4. append and return
      listBuffer.append((new ImmutableBytesWritable(), kv))
      listBuffer
    })
  }

  /**
   * Group per-user predictions into one row per user:
   * (uid, sim_aid, pre_rate) rows -> (uid, [(sim_aid, pre_rate), ...])
   * e.g. (1, 2, 11), (1, 3, 12) -> (1, [(2,11), (3,12)])
   *
   * NOTE(review): the list is sorted by pre_rate ascending; confirm whether
   * descending order was intended for a top-K recommendation list.
   */
  def recommendDataConvert(recommendDF: DataFrame) = {
    import spark.implicits._
    recommendDF.rdd.map(row => (row.getInt(0), (row.getInt(1), row.getDouble(2))))
      .groupByKey().mapValues(sp => {
        var seq: Seq[(Int, Double)] = Seq[(Int, Double)]()
        sp.foreach(tp => {
          seq :+= (tp._1, tp._2)
        })
        seq.sortBy(_._2)
      }).toDF("uid", "recommendactions")
  }

  /**
   * Predict ratings for the test set.
   * rsp : the user's rating weighted against articles similar to the original
   * rate: the user's rating on the original article
   */
  def predictTestData(joinDF: DataFrame, test: Dataset[Row]) = {
    //1. register temp views
    joinDF.createOrReplaceTempView("rate_sim")
    test.createOrReplaceTempView("test_data")
    //2. predicted value (t1) joined to the real value (t2) per (uid, aid)
    spark.sql(
      s"""
         |with t1 as( -- 用户对于相似文章的预测评分:预测值
         |select uid, sim_aid, sum(rsp) / sum(rate) as pre_rate
         |from rate_sim group by uid, sim_aid
         |),
         |t2 as ( -- 用户对于原文中的评分:真实值
         |select uid, aid, rate from test_data
         |)
         |select t2.*, t1.pre_rate from t2 inner join t1 on t2.aid = t1.sim_aid and t1.uid = t2.uid
         |where t1.pre_rate is not null
         |""".stripMargin)
  }

  /**
   * Convert the similarity matrix into a DataFrame of (aid, sim_aid, sim).
   *
   * Bug fix: the similarity matrix stores each article pair only once (upper
   * triangle). The original code unioned the frame with an identical copy of
   * itself (`select("aid", "sim_aid", "sim")`), which merely duplicated every
   * row. We now union with the column-swapped copy so each pair is present in
   * both (aid, sim_aid) directions.
   */
  def simMatrix2DF(simMatrix: CoordinateMatrix) = {
    //1. pull the entries out of the distributed matrix
    val transformerRDD: RDD[(String, String, Double)] = simMatrix.entries.map {
      case MatrixEntry(row: Long, column: Long, sim: Double) => (row.toString, column.toString, sim)
    }
    //2. rdd --> dataframe
    val simDF: DataFrame = spark.createDataFrame(transformerRDD).toDF("aid", "sim_aid", "sim")
    //3. mirror the triangular matrix: swap aid/sim_aid and union
    simDF.union(simDF.select("sim_aid", "aid", "sim").toDF("aid", "sim_aid", "sim"))
  }

  /**
   * Convert the rating table into a distributed rating matrix.
   *
   * uid aid rate        uid/aid   1     2    3
   * 1   1   0.8     ->  1        0.8   0.1
   * 1   2   0.1         2        0.6
   * 2   1   0.6         3        0.8
   * ...
   */
  def rateDF2Matrix(df: DataFrame) = {
    //1. Row --> MatrixEntry
    val matrixRDD: RDD[MatrixEntry] = df.rdd.map {
      case Row(uid: Long, aid: Long, rate: Double) => MatrixEntry(uid, aid, rate)
    }
    //2. wrap as a distributed coordinate matrix
    new CoordinateMatrix(matrixRDD)
  }
}

object ItemCFModelData {
  def apply(spark: SparkSession, env: String): ItemCFModelData = new ItemCFModelData(spark, env)
}
2.4.3 com/qf/bigata/transformer/LRModelData.scala
package com.qf.bigata.transformer

import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegressionTrainingSummary}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Model data for the logistic regression ranking model.
 */
class LRModelData(spark: SparkSession, env: String) extends ModelData(spark, env) {

  /**
   * Print the training summary of a fitted logistic regression model:
   * the objective (loss) of every iteration followed by the overall accuracy.
   *
   * @param lrModel the fitted model
   */
  def printlnSummary(lrModel: LogisticRegressionModel): Unit = {
    val summary: LogisticRegressionTrainingSummary = lrModel.summary
    //1. loss of each training iteration
    println("history----------->")
    for ((loss, iter) <- summary.objectiveHistory.zipWithIndex) {
      println(s"iterator: ${iter}, loss:${loss}")
    }
    //2. hit rate on the training data
    println(s"accuracy : ${summary.accuracy}")
  }

  /**
   * Load the raw training data: the click label table joined with the user
   * feature vectors, article feature vectors and article embeddings.
   */
  def getVectorTrainingData(): DataFrame = {
    spark.sql(
      s"""
         |with t1 as ( -- 查询到用户对哪些文章进行了点击
         |select uid, aid, label from dwb_news.user_item_training
         |),
         |t2 as ( -- 将用户的向量关联
         |select
         |t1.*,
         |ubv.features as user_features
         |from t1 left join dwb_news.user_base_vector as ubv
         |on t1.uid = ubv.uid where ubv.uid is not null and ubv.features <> ''
         |),
         |t3 as ( -- 将文章的向量关联
         |select
         |t2.*,
         |abv.features as article_features
         |from t2 left join dwb_news.article_base_vector as abv
         |on t2.aid = abv.article_id where abv.article_id is not null and abv.features <> ''
         |),
         |t4 as ( -- 将文章的embedding关联
         |select
         |t3.*,
         |ae.article_vector as article_embedding
         |from t3 left join dwb_news.article_embedding as ae
         |on t3.aid = ae.article_id
         |where ae.article_id is not null and ae.article_vector <> ''
         |)
         |select
         |uid,
         |aid,
         |user_features,
         |article_features,
         |article_embedding,
         |cast(label as int) as label from t4
         |""".stripMargin)
  }
}

object LRModelData {
  def apply(spark: SparkSession, env: String): LRModelData = new LRModelData(spark, env)
}
2.4.4 com/qf/bigata/transformer/ModelData.scala
package com.qf.bigata.transformer import com.qf.bigata.udfs.RateUDF import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.slf4j.LoggerFactory /** * 所有的协同过滤模型的父类,提供通用的工具函数类 */ class ModelData(spark:SparkSession, env:String) { private val logger = LoggerFactory.getLogger(ModelData.getClass.getSimpleName) def loadSourceDataUserBaseInfos(): DataFrame = { spark.sql( s""