GraphX SVDPlusPlus Java源码

用Java写了SVDPlusPlus


1:User 只有第一项和第三项是有用的为Pu,Bu,第二项放的是pu + |N(u)|^(-0.5)*sum(y), 第四项放的是|N(u)|^(-0.5),方便计算,。Item有三项有用的,分别为Qi,Yi,bi.计算公式为u + user._3( ) + item._3( ) + blas.ddot( rank, convert2double( item._1( ) ), 1,), item第四项放的是|N(I)|^(-0.5), 没用用处,最后结果中放的是每条边打分偏差的平方的和。

2:原可以看http://www.farseer.cn/2015/08/16/svd-implementation-in-graphx/,梯度下降的求导也简单,五个参数梯度下降,就是1中提到的5个有用的参数。gamma1, gamma2是梯度现将的参数,gamma6,gamma7正规化因子的参数.


个人感觉最后,循环结束后,需要重新计算一次user的第二项pu + |N(u)|^(-0.5)*sum(y)才是最严谨的。

代码如下

import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;


import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.StorageLevels;
import org.apache.spark.graphx.Edge;
import org.apache.spark.graphx.EdgeContext;
import org.apache.spark.graphx.Graph;
import org.apache.spark.graphx.TripletFields;
import org.apache.spark.graphx.VertexRDD;
import org.apache.spark.graphx.lib.SVDPlusPlus;
import org.apache.spark.rdd.RDD;
import org.apache.spark.storage.StorageLevel;


import com.github.fommil.netlib.BLAS;


import scala.Option;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.reflect.ClassManifestFactory;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction1;
import scala.runtime.AbstractFunction2;
import scala.runtime.AbstractFunction3;
import scala.runtime.BoxedUnit;


/**
 * 
 */


public class SVDPlusPlusTest
{
private static final ClassTag<String> tagString = ClassManifestFactory.classType( String.class );
private static final ClassTag<Object> tagObject = ClassManifestFactory.classType( Object.class );
private static final ClassTag<Double> tagDouble = ClassManifestFactory.classType( Double.class );
private static final ClassTag<Double[]> tagDoubleArray = ClassManifestFactory.classType( Double[].class );
private static final ClassTag<Tuple4<Double[], Double[], Double, Double>> tagTuple4 = ClassManifestFactory.classType( Tuple4.class );

private static final BLAS blas = BLAS.getInstance( );
public static void main( String[] args )
{
SparkConf conf = new SparkConf().setAppName( "SVD ++" ).setMaster( "local" );
JavaSparkContext ctx = new JavaSparkContext( conf );

JavaRDD<Tuple2<Object, String>> vertices = ctx.parallelize( Arrays.asList( 
new Tuple2<Object, String>(1L, "a"),
new Tuple2<Object, String>(2L, "b"),
new Tuple2<Object, String>(3L, "c"),
new Tuple2<Object, String>(4L, "d")
) );

JavaRDD<Edge<Double>> edges = ctx.parallelize( Arrays.asList(
new Edge<Double>(1L, 10L, 3.0),
new Edge<Double>(2L, 11L, 4.0)
) );

Graph<String,Double> g = Graph.apply( vertices.rdd( ), edges.rdd( ), "", StorageLevel.MEMORY_ONLY( ), StorageLevel.MEMORY_ONLY( ), tagString, tagDouble );

SVDPlusPlus.Conf svdConf = new SVDPlusPlus.Conf( 2, 20, 0, 5, 0.007, 0.007, 0.005, 0.015 );

Tuple2<Graph<Tuple4<Double[], Double[],Double, Double>, Double>, Double> result = run( g.edges( ), svdConf );


}

private static Tuple2<Graph<Tuple4<Double[], Double[],Double, Double>, Double>, Double> run(RDD<Edge<Double>> edges, SVDPlusPlus.Conf conf)
{
//计算平均数
Tuple2<Long, Double> mean = edges.toJavaRDD( ).map( s-> 
{
return new Tuple2<Long, Double>(1L, s.attr( ));
}).reduce( (t1, t2) -> new Tuple2<Long, Double>(t1._1 + t2._1, t1._2 + t2._2) );

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值