GraphFrame BFS 实现最短路径,也可以附带路径:只要把顶点(vertices)的字段改成 long、boolean、value 和 List<Long> 即可。
代码如下
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.graphframes.GraphFrame;
import scala.Tuple2;
import scala.reflect.ClassManifestFactory;
import scala.runtime.AbstractFunction1;
/**
 * Single-source shortest paths over a weighted, directed {@link GraphFrame}, computed by a
 * breadth-first, round-by-round edge relaxation driven from the Spark driver.
 *
 * <p>Each vertex carries the state {@code (id, maked, value)}: {@code maked} is {@code true} once
 * the vertex has been expanded as part of a frontier, and {@code value} is the best distance from
 * the source found so far ({@link #NA} until the vertex is reached).
 *
 * <p>The Spark contexts are held in static fields that are assigned by {@link #main(String[])}.
 */
public class GraphFrameShortPaths
{
    /**
     * Distance assigned to vertices that have not been reached yet. Presumably halved so that
     * adding an edge weight to it cannot overflow to infinity.
     */
    private static final double NA = Double.MAX_VALUE / 2.0;

    /** Schema of the demo input vertices: {@code (id, name)}. */
    private static final StructType VERTEX_TYPE = DataTypes.createStructType( Arrays.asList(
            DataTypes.createStructField( "id", DataTypes.LongType, false ),
            DataTypes.createStructField( "name", DataTypes.StringType, true ) ) );

    /** Schema of the per-vertex search state: {@code (id, maked, value)}. */
    private static final StructType VERTEX_STATE_TYPE = DataTypes.createStructType( Arrays.asList(
            DataTypes.createStructField( "id", DataTypes.LongType, false ),
            DataTypes.createStructField( "maked", DataTypes.BooleanType, true ),
            DataTypes.createStructField( "value", DataTypes.DoubleType, true ) ) );

    /** Schema of the edges: {@code (src, dst, weight)}. */
    private static final StructType EDGE_TYPE = DataTypes.createStructType( Arrays.asList(
            DataTypes.createStructField( "src", DataTypes.LongType, false ),
            DataTypes.createStructField( "dst", DataTypes.LongType, false ),
            DataTypes.createStructField( "weight", DataTypes.DoubleType, false ) ) );

    private static SQLContext sqlCtx;
    private static JavaSparkContext ctx;

    /**
     * Demo entry point: builds a five-vertex graph, computes the shortest distances from
     * vertex 1 and prints the resulting vertex table.
     *
     * @param args unused
     */
    public static void main( String[] args )
    {
        SparkConf conf = new SparkConf( ).setAppName( "Message short paths" ).setMaster( "local" );
        ctx = new JavaSparkContext( conf );
        try
        {
            sqlCtx = SQLContext.getOrCreate( ctx.sc( ) );

            JavaRDD<Row> verticeRow = ctx.parallelize( Arrays.asList(
                    RowFactory.create( 1L, "a" ),
                    RowFactory.create( 2L, "b" ),
                    RowFactory.create( 3L, "c" ),
                    RowFactory.create( 4L, "d" ),
                    RowFactory.create( 5L, "e" ) ) );
            JavaRDD<Row> edgeRow = ctx.parallelize( Arrays.asList(
                    RowFactory.create( 1L, 2L, 10.0 ),
                    RowFactory.create( 2L, 3L, 30.0 ),
                    RowFactory.create( 2L, 4L, 20.0 ),
                    RowFactory.create( 4L, 5L, 80.0 ),
                    RowFactory.create( 1L, 4L, 5.0 ) ) );

            GraphFrame frame = new GraphFrame(
                    sqlCtx.createDataFrame( verticeRow, VERTEX_TYPE ),
                    sqlCtx.createDataFrame( edgeRow, EDGE_TYPE ) );

            List<Long> list = new ArrayList<Long>( );
            list.add( 1L );
            GraphFrame shortPathsFrame = caleBFSShortPaths( covertFrame( frame, list.get( 0 ) ), list );
            shortPathsFrame.vertices( ).show( );
        }
        finally
        {
            // Always release the Spark context, even when the job above fails.
            ctx.stop( );
        }
    }

    /**
     * Replaces the vertex table of {@code f} with the initial search state: every vertex is
     * unmarked, the source has distance {@code 0.0} and all other vertices have {@link #NA}.
     *
     * @param f  graph whose first vertex column is the {@code long} vertex id
     * @param id id of the source vertex
     * @return a graph with the same edges and {@code (id, maked, value)} vertices
     */
    private static GraphFrame covertFrame( GraphFrame f, final long id )
    {
        RDD<Row> newRow = f.vertices( ).rdd( ).map( new MyFunction1<Row, Row>( )
        {
            @Override
            public Row apply( Row row )
            {
                long vertexId = row.getLong( 0 );
                double initial = vertexId == id ? 0.0 : NA;
                return RowFactory.create( vertexId, false, initial );
            }
        }, ClassManifestFactory.classType( Row.class ) );
        return new GraphFrame( sqlCtx.createDataFrame( newRow, VERTEX_STATE_TYPE ), f.edges( ) );
    }

    /**
     * Relaxes edges outward from {@code nodes}, one frontier at a time, until no frontier is left.
     *
     * <p>In every round each frontier vertex is marked and its outgoing edges are examined. A
     * destination whose distance improves is stored unmarked with the new distance and queued for
     * the next round; a destination that does not improve is queued only if it is still unmarked.
     * All changes of a round are collected per vertex id, so the rebuilt vertex table holds exactly
     * one row per vertex, and later reads within the same round see the improved distance.
     *
     * <p>Vertices are read positionally as {@code (id, maked, value)} and edges as
     * {@code (src, dst, weight)}. A cycle with negative total weight keeps improving distances and
     * therefore never terminates.
     *
     * @param frame graph in the state produced by {@code covertFrame}
     * @param nodes ids of the vertices forming the first frontier; duplicates are ignored
     * @return {@code frame} itself when {@code nodes} is empty, otherwise a graph with the same
     *         edges and the final {@code (id, maked, value)} vertex table
     * @throws IllegalStateException    if the static Spark contexts have not been initialised
     * @throws IllegalArgumentException if a frontier id or an edge destination is not a vertex
     */
    public static GraphFrame caleBFSShortPaths( GraphFrame frame, List<Long> nodes )
    {
        if ( nodes.size( ) == 0 )
        {
            return frame;
        }
        if ( ctx == null || sqlCtx == null )
        {
            throw new IllegalStateException( "Spark contexts are not initialised" );
        }

        GraphFrame current = frame;
        Set<Long> frontier = new LinkedHashSet<Long>( nodes );
        while ( !frontier.isEmpty( ) )
        {
            // Latest state decided in this round, keyed by vertex id (last write wins).
            Map<Long, Row> pending = new LinkedHashMap<Long, Row>( );
            Set<Long> next = new LinkedHashSet<Long>( );

            for ( long id : frontier )
            {
                double value = vertexState( current, pending, id )._2( );
                pending.put( id, RowFactory.create( id, true, value ) );

                for ( Row edge : current.edges( ).filter( "src = " + id ).collectAsList( ) )
                {
                    long dst = edge.getLong( 1 );
                    double candidate = value + edge.getDouble( 2 );
                    Tuple2<Boolean, Double> dstState = vertexState( current, pending, dst );
                    if ( candidate < dstState._2( ) )
                    {
                        // Shorter path found: unmark so the vertex is expanded again.
                        pending.put( dst, RowFactory.create( dst, false, candidate ) );
                        next.add( dst );
                    }
                    else if ( !dstState._1( ) )
                    {
                        next.add( dst );
                    }
                }
            }

            current = applyUpdates( current, pending );
            frontier = next;
        }
        return current;
    }

    /**
     * Builds a new graph whose vertex table is that of {@code frame} with the rows of every id in
     * {@code pending} replaced by the pending row.
     */
    private static GraphFrame applyUpdates( GraphFrame frame, Map<Long, Row> pending )
    {
        DataFrame untouched = frame.vertices( );
        for ( Long id : pending.keySet( ) )
        {
            untouched = untouched.filter( "id != " + id );
        }
        JavaRDD<Row> replaced = ctx.parallelize( new ArrayList<Row>( pending.values( ) ) );
        DataFrame vertices = sqlCtx.createDataFrame( untouched.javaRDD( ).union( replaced ), VERTEX_STATE_TYPE );
        return new GraphFrame( vertices, frame.edges( ) );
    }

    /**
     * Returns the {@code (maked, value)} state of vertex {@code id}, preferring a row already
     * decided in the current round over the row stored in {@code frame}.
     */
    private static Tuple2<Boolean, Double> vertexState( GraphFrame frame, Map<Long, Row> pending, long id )
    {
        Row row = pending.get( id );
        if ( row == null )
        {
            row = findVertex( frame, id );
        }
        return new Tuple2<>( row.getBoolean( 1 ), row.getDouble( 2 ) );
    }

    /**
     * Collects the first vertex row of {@code frame} whose id equals {@code id}.
     *
     * @throws IllegalArgumentException if the graph has no such vertex
     */
    private static Row findVertex( GraphFrame frame, long id )
    {
        List<Row> matches = frame.vertices( ).filter( "id =" + id ).collectAsList( );
        if ( matches.isEmpty( ) )
        {
            throw new IllegalArgumentException( "No vertex with id " + id + " in the graph" );
        }
        return matches.get( 0 );
    }

    /**
     * Serializable base class for Scala {@code Function1} implementations written in Java, so
     * that they can be passed to {@code RDD.map}.
     */
    public static abstract class MyFunction1<T1, R> extends AbstractFunction1<T1, R> implements Serializable
    {
    }
}