ALS 的 Java 实现

用 Java 实现了一个简单的 ALS（交替最小二乘）算法，计算逻辑参照 Spark ALS；区别在于 Spark 基于 RDD 分块计算，而这里所有数据都在单机内存中处理。代码如下：

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;

import org.apache.spark.util.random.XORShiftRandom;
import org.netlib.util.intW;

import com.github.fommil.netlib.BLAS;
import com.github.fommil.netlib.LAPACK;

import scala.Tuple2;
import scala.util.hashing.package$;

public class ALSCaleTest
{
private static final Long SEED = 10L;
private static final BLAS blas = BLAS.getInstance( );
private static final LAPACK lapack = LAPACK.getInstance( );
private static final int MAXITER = 10;
private static final int RANK = 2;
private static final double LAMBDA = 0.01;
public static <E> void main( String[] args )
{
List<RatingCale> list =Arrays.asList( 
new RatingCale( 1, 11, 3.0f ),
new RatingCale( 1, 12, 4.0f ),
new RatingCale( 2, 12, 3.0f ),
new RatingCale( 2, 13, 4.5f ),
new RatingCale( 3, 11, 3.0f ),
new RatingCale( 3, 12, 2.0f ));

List<Integer> users = userNumber( list );
Random seedGen = new XORShiftRandom( SEED );
List<Tuple2<Integer, float[]>> userFactors = initFactors( users, RANK,  seedGen.nextLong( ));
List<Tuple2<Integer, float[]>> itemFactors = null;
Map<Integer, List<Tuple2<Integer, Float>>> userBlocks = userBlocks( list );
Map<Integer, List<Tuple2<Integer, Float>>> itemBlocks = itemBlocks( list );


for (int i=0; i<MAXITER; i++)
{
itemFactors = computeFactors( itemBlocks, userFactors );
userFactors = computeFactors( userBlocks, itemFactors );
}

List<Tuple2<Integer, Float>> result = recommendProducts(1,13, userFactors, itemFactors );
System.out.println( result );
System.out.println( "Done" );
}

private static List<Tuple2<Integer, float[]>> computeFactors(Map<Integer, List<Tuple2<Integer, Float>>> map, List<Tuple2<Integer, float[]>> factors)
{
List<Tuple2<Integer, float[]>> retValue = new ArrayList<Tuple2<Integer, float[]>>();
Iterator<Entry<Integer, List<Tuple2<Integer, Float>>>> iter = map.entrySet( ).iterator( );
NormalEquation ne = new NormalEquation( RANK );
while(iter.hasNext( ))
{
Entry<Integer, List<Tuple2<Integer, Float>>> value = iter.next( );
int srcId = value.getKey( );
List<Tuple2<Integer, Float>> list = value.getValue( );
ne.reset( );

for (Tuple2<Integer, Float> t:list)
{
ne.add( getFactorFromList( factors, t._1 ), t._2, 1.0 );
}

float[] newFactors = choleskySolver( ne, LAMBDA*list.size( ) );
retValue.add( new Tuple2<Integer, float[]>(srcId, newFactors) );
}
return retValue;
}

private static float[] getFactorFromList(List<Tuple2<Integer, float[]>> list, int id)
{
for (Tuple2<Integer, float[]> t:list)
{
if (t._1 == id)
{
return t._2;
}
}
throw new RuntimeException( "Error" );
}

private static Map<Integer, List<Tuple2<Integer, Float>>> userBlocks(List<RatingCale> list)
{
Map<Integer, List<Tuple2<Integer, Float>>> retValue = new HashMap<Integer, List<Tuple2<Integer, Float>>>( );
for (RatingCale c:list)
{
List<Tuple2<Integer, Float>> tuples = null;
if (retValue.containsKey( c.user ))
{
tuples = retValue.get( c.user );
tuples.add( new Tuple2<Integer, Float> (c.item, c.rating));
}
else
{
tuples = new ArrayList<Tuple2<Integer, Float>>( );
tuples.add( new Tuple2<Integer, Float> (c.item, c.rating));
retValue.put( c.user, tuples );
}
}
return retValue;
}

private static Map<Integer, List<Tuple2<Integer, Float>>> itemBlocks(List<RatingCale> list)
{
Map<Integer, List<Tuple2<Integer, Float>>> retValue = new HashMap<Integer, List<Tuple2<Integer, Float>>>( );
for (RatingCale c:list)
{
List<Tuple2<Integer, Float>> tuples = null;
if (retValue.containsKey( c.item ))
{
tuples = retValue.get( c.item );
tuples.add( new Tuple2<Integer, Float> (c.user, c.rating));
}
else
{
tuples = new ArrayList<Tuple2<Integer, Float>>( );
tuples.add( new Tuple2<Integer, Float> (c.user, c.rating));
retValue.put( c.item, tuples );
}
}
return retValue;
}

private static List<Integer> userNumber(List<RatingCale> list)
{
List<Integer> retValue = new ArrayList<Integer>();
for (RatingCale c:list)
{
if (!retValue.contains( c.user ))
{
retValue.add( c.user );
}
}
return retValue;
}

private static List<Tuple2<Integer, float[]>> initFactors(List<Integer> list, int rank, long seed)
{
List<Tuple2<Integer, float[]>> retValue = new ArrayList<Tuple2<Integer, float[]>>();
Random random = new XORShiftRandom( package$.MODULE$.byteswap64( seed ) );
for (int i=0; i<list.size( ); i++)
{
float[] factor = new float[rank];
for (int j=0; j<rank; j++)
{
factor[j] = ((Double)random.nextGaussian( )).floatValue( ); 
}
float nrm = blas.snrm2( rank, factor, 1 );
blas.sscal( rank, 1.0f / nrm, factor, 1 );
retValue.add( new Tuple2<Integer, float[]>(list.get( i ), factor) );
}
return retValue;
}
private static float[] choleskySolver( NormalEquation ne, double lambda )
{
int k = ne.k;
int i = 0;
int j = 2;
while (i < ne.trik)
{
ne.ata[i]  = ne.ata[i] + lambda;
i = i + j;
j = j + 1;
}
solve( ne.ata, ne.atb );
float[] x = new float[k];
i=0;
while( i< k)
{
x[i] = ((Double)ne.atb[i]).floatValue( );
i = i + 1;
}
ne.reset( );
return x;
}

private static double[] solve(double[] A, double[] bx)
{
int k = bx.length;
intW info = new intW( 0 );
lapack.dppsv("U", k, 1, A, bx, k, info);

if (info.val != 0)
{
throw new RuntimeException( "LAPACK run error" );
}
return bx;
}
private static class RatingCale
{
int user;
int item;
float rating;
public RatingCale( int user, int item, float rating )
{
super( );
this.user = user;
this.item = item;
this.rating = rating;
}
}

private static List<Tuple2<Integer, Float>> recommendProducts( int user, int num, List<Tuple2<Integer, float[]>> userFactors, List<Tuple2<Integer, float[]>> itemFactors)
{
float[] userFactor = null;
for (Tuple2<Integer, float[]> t: userFactors)
{
if (t._1 == user)
{
userFactor = t._2;
break;
}
}
if (userFactor == null)
{
throw new RuntimeException( "Error" );
}
return recommend( userFactor, itemFactors, num );
}
private static List<Tuple2<Integer, Float>> recommend(
float[] recommendToFeatures,
List<Tuple2<Integer, float[]>> recommendableFeatures, int num )
{
List<Tuple2<Integer, Float>> retValue = new ArrayList<Tuple2<Integer, Float>>( );
for (Tuple2<Integer, float[]> t: recommendableFeatures)
{
float value = blas.sdot( t._2.length, recommendToFeatures, 1, t._2, 1 );
retValue.add( new Tuple2<Integer, Float>(t._1, value) );
}
retValue.sort( new Comparator<Tuple2<Integer, Float>>()
{


@Override
public int compare( Tuple2<Integer, Float> o1,
Tuple2<Integer, Float> o2 )
{
return o1._2  < o2._2( ) ? 1:-1;
}

});
if (retValue.size( ) > num)
{
return retValue.subList( 0, num );
}
else
{
return retValue;
}


}

private static double[] convert2double(float[] fs)
{
double[] ds = new double[fs.length];
for (int i=0; i<fs.length; i++)
{
ds[i] = fs[i];
}

return ds;
}

private static class NormalEquation
{


private static final String upper = "U";
private int k;
private int trik;
private double[] ata;
private double[] atb;
private double[] da;


public NormalEquation( int k )
{
super( );
this.k = k;
trik = k * ( k + 1 ) / 2;
ata = new double[trik];
atb = new double[k];
da = new double[k];
}


private void copyToDouble( float[] a )
{
int i = 0;
while ( i < k )
{
da[i] = a[i];
i = i + 1;
}
}


NormalEquation add( float[] a, double b, double c )
{
copyToDouble( a );
blas.dspr( upper, k, c, da, 1, ata );
if ( b != 0.0 )
{
blas.daxpy( k, c * b, da, 1, atb, 1 );
}
return this;
}


NormalEquation merge( NormalEquation other )
{
blas.daxpy( ata.length, 1.0, other.ata, 1, ata, 1 );
blas.daxpy( atb.length, 1.0, other.atb, 1, atb, 1 );
return this;
}


void reset( )
{
Arrays.fill( ata, 0.0 );
Arrays.fill( atb, 0.0 );
}
}


}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值