ALS Factorization Implementation

Project GitHub repo: bitcarmanlee easy-algorithm-interview-and-practice
Stars and comments are welcome; let's learn and improve together.

1. A brief introduction to ALS

Earlier articles in this series covered the theory behind ALS in detail. Summed up in one sentence: ALS represents each user and each item as a low-dimensional dense vector, and these vectors are what downstream tasks consume.
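
As a refresher (this is the standard explicit-feedback ALS objective, not something specific to the code below): given the observed user-item interactions, ALS looks for user vectors $\mathbf{u}_u \in \mathbb{R}^k$ and item vectors $\mathbf{v}_i \in \mathbb{R}^k$ that minimize

$$
\min_{U,V}\ \sum_{(u,i)\in\Omega}\left(r_{ui} - \mathbf{u}_u^\top \mathbf{v}_i\right)^2
+ \lambda\left(\sum_u \lVert\mathbf{u}_u\rVert^2 + \sum_i \lVert\mathbf{v}_i\rVert^2\right)
$$

where $\Omega$ is the set of observed (user, item) pairs, $r_{ui}$ is the interaction strength (here, a click count), and $\lambda$ is the regularization weight (regParam in Spark). With $V$ fixed, each $\mathbf{u}_u$ has a closed-form ridge-regression solution, and vice versa; alternating these two solves is what gives ALS its name.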

2. ALS implementation on Spark

First, a look at some of the helper code:

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Set;
import java.util.TreeSet;

/**
 * Created by WangLei on 20-1-10.
 */
public class TimeUtils {

    private static final Logger LOGGER = LoggerFactory.getLogger(TimeUtils.class);

    public static final String DATE_FORMAT = "yyyyMMdd";
    public static final String TIME_FORMAT = "yyyyMMdd HH:mm:ss";
    public static final String HOUR_TIME_FORMAT = "yyyyMMdd HH";

    public static final long TIME_DAY_MILLISECOND = 86400000;

    /**
     * timestamp -> ymd
     * @param timestamp
     * @return
     */
    public static String timestamp2Ymd(long timestamp) {
        String format = "yyyyMMdd";
        return timestamp2Ymd(timestamp, format);
    }

    public static String timestamp2Ymd(long timestamp, String format) {
        SimpleDateFormat sdf;
        try {
            // also accept 10-digit (second-resolution) timestamps
            if(String.valueOf(timestamp).length() == 10) {
                timestamp *= 1000;
            }
            sdf = new SimpleDateFormat(format);
            return sdf.format(new Date(timestamp));
        } catch(Exception ex) {
            sdf = new SimpleDateFormat(DATE_FORMAT);
            try {
                return sdf.format(new Date(timestamp));
            } catch (Exception e){}
        }
        return null;
    }

    public static String timestamp2Hour(long timestamp) {
        String time = timestamp2Ymd(timestamp, TIME_FORMAT);
        return time.substring(9, 11);
    }

    /**
     * ymd -> Date
     * @param ymd
     * @return
     */
    public static Date ymd2Date(String ymd) {
        return ymd2Date(ymd, "yyyyMMdd");
    }

    public static Date ymd2Date(String ymd, String format) {
        try {
            SimpleDateFormat sdf = new SimpleDateFormat(format);
            return sdf.parse(ymd);
        } catch(ParseException ex) {
            LOGGER.error("parse ymd to timestamp error!", ex);
        } catch (Exception ex) {
            LOGGER.error("there is some problem when transfer ymd2Date!", ex);
        }
        return null;
    }

    /**
     * ymd -> timestamp
     * @param ymd
     * @return
     */
    public static long ymd2timestamp(String ymd) {
        return ymd2Date(ymd).getTime();
    }


    public static String genLastDayStr() {
        return timestamp2Ymd(System.currentTimeMillis() - TIME_DAY_MILLISECOND);
    }


    /**
     * get the date string num days before (num < 0) or after (num > 0) the given date string
     * note: num is widened to long before the multiplication
     * @param ymd
     * @param num
     * @return
     */
    public static String genDateAfterInterval(String ymd, int num) {
        long timestamp = ymd2timestamp(ymd);
        long resTimeStamp = timestamp + TIME_DAY_MILLISECOND * Long.valueOf(num);
        return timestamp2Ymd(resTimeStamp);
    }

    public static String genLastDayStr(String ymd) {
        return genDateAfterInterval(ymd, -1);
    }


    public static Set<String> genYmdSet(long beginTs, long endTs) {
        TreeSet<String> ymdSet = new TreeSet<>();
        for(long ts = beginTs; ts <= endTs; ts += TIME_DAY_MILLISECOND) {
            ymdSet.add(timestamp2Ymd(ts));
        }
        return ymdSet;
    }

    public static Set<String> genYmdSet(String beginYmd, String endYmd) {
        long beginTs = ymd2timestamp(beginYmd);
        long endTs = ymd2timestamp(endYmd);
        return genYmdSet(beginTs, endTs);
    }

    /**
     * number of days from begin to end
     * if begin or end is not in number format, or end < begin, return Integer.MIN_VALUE
     * @param begin
     * @param end
     * @return
     */
    public static int getIntervalBetweenTwoDays(String begin, String end) {
        try {
            int begintmp = Integer.valueOf(begin), endtmp = Integer.valueOf(end);
            if(begintmp > endtmp) {
                LOGGER.error("we need end no smaller than end!");
                return Integer.MIN_VALUE;
            }
            Date d1 = ymd2Date(begin);
            Date d2 = ymd2Date(end);
            Long mils = (d2.getTime() - d1.getTime()) / TIME_DAY_MILLISECOND;
            return mils.intValue();
        } catch (NumberFormatException numformatex) {
            numformatex.printStackTrace();
            return Integer.MIN_VALUE;
        }
    }
}

The class above bundles a number of date/time helpers and can be dropped into a code base as-is.
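
A quick usage sketch of a few of these helpers, written in Scala since the rest of the project is Scala; the printed values are illustrative and depend on the local time zone:

object TimeUtilsDemo {
    def main(args: Array[String]): Unit = {
        // 10-digit (second) timestamps are scaled to milliseconds internally
        println(TimeUtils.timestamp2Ymd(1578614400L))            // e.g. 20200110
        println(TimeUtils.genDateAfterInterval("20200110", -29)) // e.g. 20191212
        println(TimeUtils.genYmdSet("20200101", "20200105"))     // [20200101, ..., 20200105]
    }
}

Next, a small enumeration describing the three date-partition naming schemes used when probing HDFS paths: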

/**
  * Created by WangLei on 20-1-13.
  */
object DateSpec extends Enumeration {
    type DateSpec = Value

    val YMD, Y_M_D, YMD2 = Value
}

The HDFS-related utility class:

import java.util.Date

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.joda.time.DateTime

import DateSpec.DateSpec

/**
  * Created by WangLei on 20-1-10.
  */
object HDFSUtils {

    val conf = new Configuration()

    def delete(sc: SparkContext, path: String) = {
        FileSystem.get(sc.hadoopConfiguration).delete(new Path(path), true)
    }

    def isExist(sc: SparkContext, path: String) = {
        FileSystem.get(sc.hadoopConfiguration).exists(new Path(path))
    }

    def checkFileExist(conf: Configuration = conf, fileName: String): Boolean = {
        var isExist = false

        try {
            val hdfs = FileSystem.get(conf)
            val path = new Path(fileName)
            isExist = hdfs.exists(path)
        } catch {
            case e: Exception => e.printStackTrace()
        }

        isExist
    }

    def latestMidPath(conf: Configuration, basePath: String): Option[String] = {
        val today = new Date
        latestMidPath(conf, basePath, new DateTime(today.getTime), 7)
    }

    def latestMidPath(conf: Configuration, basePath: String, ymd: String) : Option[String] = {
        val timestamp = TimeUtils.ymd2timestamp(ymd)
        latestMidPath(conf, basePath, new DateTime(timestamp), 7, false, DateSpec.YMD2)
    }

    def latestMidPath(conf: Configuration = conf, basePath: String, date: DateTime, limit: Int, with_success_file: Boolean = true, dateSpec: DateSpec = DateSpec.YMD): Option[String] = {
        for (i <- 0 to limit) {
            val day = date.minusDays(i)
            val path = dateSpec match {
                case DateSpec.YMD => basePath + "/date=%04d%02d%02d".format(day.getYear, day.getMonthOfYear, day.getDayOfMonth)
                case DateSpec.Y_M_D => basePath + "/year=%04d/month=%02d/day=%02d".format(day.getYear, day.getMonthOfYear, day.getDayOfMonth)
                case DateSpec.YMD2 => basePath + "/%04d%02d%02d".format(day.getYear, day.getMonthOfYear, day.getDayOfMonth)
            }

            if (checkFileExist(conf, if(with_success_file) path + "/_SUCCESS" else path))
                return Some(path)
        }
        None
    }

}
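
A minimal sketch of how latestMidPath is typically used (the base path here is a hypothetical placeholder; with the default DateSpec.YMD layout it probes partitions named date=yyyyMMdd containing a _SUCCESS marker):

import org.apache.hadoop.conf.Configuration
import org.joda.time.DateTime

object HDFSUtilsDemo {
    def main(args: Array[String]): Unit = {
        val conf = new Configuration()
        // look for the newest ready partition within the last 7 days
        HDFSUtils.latestMidPath(conf, "/data/warehouse/user_item_click", new DateTime(), 7) match {
            case Some(path) => println(s"use input path: $path")
            case None       => println("no ready partition found in the last 7 days")
        }
    }
}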

Now the ALS training code itself:

import org.apache.spark.SparkConf
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession
import org.slf4j.LoggerFactory

import scala.collection.JavaConversions._

object AlsTraining {
	
	val logger = LoggerFactory.getLogger(this.getClass)
	val separator = "\t"
	
	def genUserItemRdd(spark: SparkSession, ymd: String) = {
		// read click logs for ymd plus the previous 29 days (30 days of data in total)
		val baseinput = PathUtils.user_item_click_path
		val (yesterday, daybegin) = (TimeUtils.genLastDayStr(ymd), TimeUtils.genDateAfterInterval(ymd, -29))
		val days = TimeUtils.genYmdSet(daybegin, yesterday)
		
		// userid itemid clicknum
		var rdd = spark.sparkContext.textFile(baseinput + ymd)
			.map(x => {
				val l = x.split("\t")
				(l(0), l(1), l(2))
			})
		
		for (day <- days) {
			val path = baseinput + day
			if (HDFSUtils.isExist(spark.sparkContext, path)) {
				val tmp = spark.sparkContext.textFile(path)
					.map(x => {
						val l = x.split("\t")
						(l(0), l(1), l(2))
					})
				rdd = rdd.union(tmp)
			}
		}
		rdd.cache
	}
	
	def genUserItemIndex(spark: SparkSession, ymd: String) = {
		val rdd = genUserItemRdd(spark, ymd)
		val userindex = rdd.map(x => x._1).distinct().sortBy(x => x).zipWithIndex().map(x => (x._1, x._2 + 1))
		val itemindex = rdd.map(x => x._2).distinct().sortBy(x => x).zipWithIndex().map(x => (x._1, x._2 + 1))
		
		(userindex, itemindex)
	}
	
	case class Rating(userid: Int, itemid: Int, rating: Float)
	
	def trainmodel(spark: SparkSession, ymd: String) = {
		import spark.implicits._
		
		val rdd = genUserItemRdd(spark, ymd)
		
		val userindexrdd = rdd.map(x => x._1).distinct().sortBy(x => x).zipWithIndex().map(x => (x._1, x._2 + 1))
		val itemindexrdd = rdd.map(x => x._2).distinct().sortBy(x => x).zipWithIndex().map(x => (x._1, x._2 + 1))
		
		// aggregate clicks per (user, item) pair, then map the string ids to the
		// integer indexes via two joins, producing Rating(userIndex, itemIndex, clicks)
		val data = rdd.map(x => {
			val (userid, itemid, count) = (x._1, x._2, x._3.toInt)
			(userid + separator + itemid, count)
		})
			.reduceByKey(_ + _)
			.map(x => {
				val (userid, itemid, count) = (x._1.split(separator)(0), x._1.split(separator)(1), x._2)
				(userid, itemid + separator + count)
			})
			.join(userindexrdd)
			.map(x => {
				val (itemandcount, userindex) = (x._2._1, x._2._2)
				val (itemid, count) = (itemandcount.split(separator)(0), itemandcount.split(separator)(1))
				(itemid, userindex + separator + count)
			})
			.join(itemindexrdd)
			.map(x => {
				val (userandcount, itemindex) = (x._2._1, x._2._2)
				val (userindex, count) = (userandcount.split(separator)(0), userandcount.split(separator)(1))
				Rating(userindex.toInt, itemindex.toInt, count.toFloat)
			}).toDF()
		
		// 80/20 train/test split; the ALS below uses rank 128, 8 iterations, regParam 0.01
		val Array(training, test) = data.randomSplit(Array(0.8, 0.2))
		val als = new ALS().setRank(128).setMaxIter(8).setRegParam(0.01).
			setUserCol("userid").setItemCol("itemid").setRatingCol("rating")
		val model = als.fit(training)
		
		model.setColdStartStrategy("drop")
		
		val predictions = model.transform(test)
		val evaluator = new RegressionEvaluator()
			.setMetricName("rmse")
			.setLabelCol("rating")
			.setPredictionCol("prediction")
		val rmse = evaluator.evaluate(predictions)
		
		logger.error("root-mean-square error is: {}", rmse)
		
		val userindex2userid = userindexrdd.map(x => (x._2, x._1))
		val userfactors = model.userFactors.rdd.map(x => {
			val (userid, userfactor) = (x.getInt(0).toLong, x.getList(1).toArray().mkString(","))
			(userid, userfactor)
		}).join(userindex2userid)
    		.map(x => {
				val (userindex, userfactor, userid) = (x._1, x._2._1, x._2._2)
				(userindex, userid, userfactor)
			})
    		.repartition(1)
    		.sortBy(x => x._1)
    		.map(x => "%s\t%s\t%s".format(x._1, x._2, x._3))
		
		val itemindex2itemid = itemindexrdd.map(x => (x._2, x._1))
		val itemfactors = model.itemFactors.rdd.map(x => {
			val (itemid, itemfactor) = (x.getInt(0).toLong, x.getList(1).toArray().mkString(","))
			(itemid, itemfactor)
		}).join(itemindex2itemid)
    		.map(x => {
				val (itemindex, itemfactor, itemid) = (x._1, x._2._1, x._2._2)
				(itemindex, itemid, itemfactor)
			})
    		.repartition(1)
    		.sortBy(x => x._1)
			.map(x => "%s\t%s\t%s".format(x._1, x._2, x._3))
		
		(userfactors, itemfactors)
	}
	
	def main(args: Array[String]): Unit = {
		val (ymd, operation) = (args(0), args(1))
		val sparkConf = new SparkConf()
		sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
		sparkConf.setAppName("user-item-als-training" + ymd)
		
		val spark = SparkSession.builder().config(sparkConf).getOrCreate()
		
		operation match {
			case "index" => {
				val useroutput = PathUtils.user_index_path + ymd
				val itemoutput = PathUtils.item_index_path + ymd
				val (userindex, itemindex) = genUserItemIndex(spark, ymd)
				// write "index \t original id" so the indexes can be mapped back later
				userindex.repartition(1).sortBy(_._2).map(x => "%s\t%s".format(x._2, x._1)).saveAsTextFile(useroutput)
				itemindex.repartition(1).sortBy(_._2).map(x => "%s\t%s".format(x._2, x._1)).saveAsTextFile(itemoutput)
			}
			case "model" => {
				val (userfactors, itemfactors) = trainmodel(spark, ymd)
				val user_embedding_path = PathUtils.user_factor_path + ymd
				val item_embedding_path = PathUtils.item_factor_path + ymd
				HDFSUtils.delete(spark.sparkContext, user_embedding_path)
				HDFSUtils.delete(spark.sparkContext, item_embedding_path)
				
				userfactors.saveAsTextFile(user_embedding_path)
				itemfactors.saveAsTextFile(item_embedding_path)
			}
		}
		spark.stop()
	}
}

3. Code walkthrough

PathUtils.user_item_click_path

This is the input data set. Each line contains three tab-separated fields: userid, itemid, and click count.
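
For illustration, a few hypothetical input lines (tab-separated, which is what genUserItemRdd splits on):

u_10001	item_2001	3
u_10001	item_2002	1
u_10002	item_2001	7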

genUserItemIndex

This method assigns an integer index to each userid and each itemid. Note that users and items are indexed separately, each in its own 1-based index space; spark.ml's ALS requires integer-valued user and item columns, which is why the raw string ids are encoded at all.
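
A minimal sketch of the indexing trick in isolation (hypothetical ids; sc is assumed to be a SparkContext):

// distinct ids are sorted, zipWithIndex assigns 0-based positions,
// and the +1 shifts them into a 1-based index space
val users = sc.parallelize(Seq("u_10002", "u_10001", "u_10001"))
val userindex = users.distinct().sortBy(x => x).zipWithIndex().map(x => (x._1, x._2 + 1))
userindex.collect().foreach(println)
// prints (u_10001,1) and (u_10002,2)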

trainmodel

This method proceeds as follows (a small follow-up sketch comes after the list):

1. Build the training set (aggregate clicks and map string ids to integer indexes).
2. Construct the ALS estimator.
3. Train the model with als.fit.
4. Score the held-out test set with the trained model and report RMSE.
5. Extract the user factor vectors and item factor vectors and map them back to the original ids.
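
Beyond dumping the factor matrices, the trained ALSModel can also produce top-N recommendations directly. A minimal follow-up sketch, assuming the model value from trainmodel above (recommendForAllUsers is part of the spark.ml ALS API):

// top-10 item indexes per user index; the integer indexes still need to be
// joined back to the original userid/itemid strings, as done for the factors
val topN = model.recommendForAllUsers(10)
topN.show(5, truncate = false)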
