package org.apache.giraph.examples.lp;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.giraph.comm.WorkerClientRequestProcessor;
import org.apache.giraph.graph.BasicComputation;
import org.apache.giraph.graph.GraphState;
import org.apache.giraph.graph.GraphTaskManager;
import org.apache.giraph.graph.Vertex;
import org.apache.giraph.worker.WorkerContext;
import org.apache.giraph.worker.WorkerGlobalCommUsage;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
/**
 * Label propagation community detection.
 *
 * <p>In superstep 0 every vertex sends its initial community id to all of its
 * neighbours. In every later superstep a vertex derives a candidate community
 * from the received messages (see {@link #getNewCommunity}). If the candidate
 * differs from its current community and the vertex's stabilization counter
 * has not exceeded the configured threshold ({@link #STABILISATION_ROUNDS}),
 * the vertex adopts the candidate and sends it to all neighbours.
 *
 * <p>Every vertex votes to halt at the end of each superstep, so it is only
 * woken up again by incoming messages. Once the superstep number exceeds
 * {@link #SUPERSTEP_COUNT} (default {@link #DEFAULT_NUMBER_OF_ITERATIONS}),
 * vertices halt without processing their messages.
 */
public class LPComputation extends
    BasicComputation<LongWritable, LPVertexValue, NullWritable, LongWritable> {
  /**
   * Configuration key for the last superstep in which vertices still update
   * their community. Defaults to {@link #DEFAULT_NUMBER_OF_ITERATIONS}.
   */
  public static final String SUPERSTEP_COUNT =
      "giraph.lpComputation.superstepCount";
  /**
   * Defines the maximum number of vertex migrations.
   * NOTE(review): this key is declared but not read by this class.
   */
  public static final String NUMBER_OF_ITERATIONS =
      "labelpropagation.numberofiterations";
  /**
   * Default number of vertex migrations if no value is given. Also used as
   * the default value of {@link #SUPERSTEP_COUNT}.
   */
  public static final int DEFAULT_NUMBER_OF_ITERATIONS = 50;
  /**
   * Number of iterations a vertex needs to be stable (i.e., did not migrate).
   */
  public static final String STABILISATION_ROUNDS =
      "labelpropagation.stabilizationrounds";
  /**
   * Default number of stabilization rounds: the default threshold above which
   * a vertex's stabilization counter stops it from changing community.
   */
  public static final long DEFAULT_NUMBER_OF_STABILIZATION_ROUNDS = 20;
  /**
   * Stabilization-round threshold read from the configuration.
   */
  private long stabilizationRounds;
  /**
   * Last superstep in which vertices still process messages; read from
   * {@link #SUPERSTEP_COUNT}.
   */
  private int maxSupersteps;

  /**
   * Returns the current new value. This value is based on all incoming
   * messages. Depending on the number of messages sent to the vertex, the
   * method returns:
   * <p/>
   * 0 messages: The current value
   * <p/>
   * 1 message: The minimum of the message and the current vertex value
   * <p/>
   * >1 messages: The most frequent of all message values
   *
   * @param vertex The current vertex
   * @param messages All incoming messages
   * @return the new Value the vertex will become
   */
  private long getNewCommunity(
      Vertex<LongWritable, LPVertexValue, NullWritable> vertex,
      Iterable<LongWritable> messages) {
    // Copy the values out: the message iterable may only be traversable once
    // and getMostFrequent sorts the list in place.
    List<Long> allMessages = new ArrayList<>();
    for (LongWritable message : messages) {
      allMessages.add(message.get());
    }
    long currentCommunity = vertex.getValue().getCurrentCommunity().get();
    if (allMessages.isEmpty()) {
      // 1. no messages received: keep the current community
      return currentCommunity;
    }
    if (allMessages.size() == 1) {
      // 2. exactly one message received: take the smaller id
      return Math.min(currentCommunity, allMessages.get(0));
    }
    // 3. multiple messages received
    return getMostFrequent(vertex, allMessages);
  }

  /**
   * Returns the most frequent value among all received messages. Each message
   * value is a community id.
   *
   * <p>Ties between equally frequent values are broken in favour of the
   * smallest id. The list is sorted ascending and only a strictly longer run
   * replaces the current best. If every value occurs exactly once, the
   * minimum of the vertex's current community and the smallest message is
   * returned, to avoid an oscillating state.
   *
   * @param vertex The current vertex
   * @param allMessages All messages the current vertex has received; must be
   *                    non-empty. The list is sorted in place.
   * @return the most frequent community id in all received messages
   */
  private long getMostFrequent(
      Vertex<LongWritable, LPVertexValue, NullWritable> vertex,
      List<Long> allMessages) {
    Collections.sort(allMessages);
    long runValue = allMessages.get(0);
    int runLength = 1;
    long maxValue = runValue;
    int maxCounter = 1;
    // Scan runs of equal values in the sorted list and remember the longest.
    for (int i = 1; i < allMessages.size(); i++) {
      long value = allMessages.get(i);
      if (value == runValue) {
        runLength++;
        if (runLength > maxCounter) {
          maxCounter = runLength;
          maxValue = runValue;
        }
      } else {
        runValue = value;
        runLength = 1;
      }
    }
    if (maxCounter == 1) {
      // All received values are distinct: to avoid an oscillating state we
      // use the smaller of the current community and the smallest message.
      return Math.min(vertex.getValue().getCurrentCommunity().get(),
          allMessages.get(0));
    }
    return maxValue;
  }

  /**
   * Reads the stabilization threshold and the superstep cap from the job
   * configuration before delegating to the default initialization.
   */
  @Override
  public void initialize(GraphState graphState,
      WorkerClientRequestProcessor<LongWritable, LPVertexValue, NullWritable>
          workerClientRequestProcessor,
      GraphTaskManager<LongWritable, LPVertexValue, NullWritable>
          graphTaskManager,
      WorkerGlobalCommUsage workerGlobalCommUsage,
      WorkerContext workerContext) {
    super.initialize(graphState, workerClientRequestProcessor, graphTaskManager,
        workerGlobalCommUsage, workerContext);
    this.stabilizationRounds = getConf()
        .getLong(STABILISATION_ROUNDS, DEFAULT_NUMBER_OF_STABILIZATION_ROUNDS);
    // Default to DEFAULT_NUMBER_OF_ITERATIONS (50) supersteps. A default of 0
    // would halt every vertex at superstep 1 before any label propagates.
    this.maxSupersteps =
        getConf().getInt(SUPERSTEP_COUNT, DEFAULT_NUMBER_OF_ITERATIONS);
  }

  /**
   * The actual LabelPropagation Computation
   *
   * @param vertex Vertex
   * @param messages Messages that were sent to this vertex in the previous
   *                 superstep.
   * @throws IOException
   */
  @Override
  public void compute(Vertex<LongWritable, LPVertexValue, NullWritable> vertex,
      Iterable<LongWritable> messages) throws IOException {
    if (getSuperstep() == 0) {
      // Seed the neighbourhood with this vertex's initial community.
      sendMessageToAllEdges(vertex, vertex.getValue().getCurrentCommunity());
    } else {
      if (getSuperstep() > maxSupersteps) {
        // Superstep budget exhausted: stop without processing messages.
        vertex.voteToHalt();
        return;
      }
      long currentCommunity = vertex.getValue().getCurrentCommunity().get();
      long lastCommunity = vertex.getValue().getLastCommunity().get();
      // Candidate community: the most frequent id among the messages.
      long newCommunity = getNewCommunity(vertex, messages);
      long currentStabilizationRound =
          vertex.getValue().getStabilizationRounds();
      // Increment this vertex's stabilization counter when the candidate is
      // the community it held before its most recent change.
      if (lastCommunity == newCommunity) {
        currentStabilizationRound++;
        vertex.getValue().setStabilizationRounds(currentStabilizationRound);
      }
      // isUnstable is false once the counter exceeds the threshold; the
      // community id is then considered stable and is no longer changed.
      boolean isUnstable = currentStabilizationRound <= stabilizationRounds;
      boolean mayChange = currentCommunity != newCommunity;
      if (mayChange && isUnstable) {
        vertex.getValue().setLastCommunity(new LongWritable(currentCommunity));
        vertex.getValue().setCurrentCommunity(new LongWritable(newCommunity));
        // Reset the stabilization counter.
        // NOTE(review): the increment above only happens when
        // newCommunity == lastCommunity. As long as lastCommunity differs from
        // currentCommunity, which every change guarantees, that also makes
        // mayChange true. Because every change resets the counter to 0, with
        // a threshold >= 1 the counter never exceeds 1. Confirm the intended
        // stabilization rule.
        vertex.getValue().setStabilizationRounds(0);
        sendMessageToAllEdges(vertex, vertex.getValue().getCurrentCommunity());
      }
    }
    vertex.voteToHalt();
  }
}
package org.apache.giraph.examples.lp;
import com.google.common.collect.Lists;
import org.apache.giraph.edge.Edge;
import org.apache.giraph.edge.EdgeFactory;
import org.apache.giraph.io.formats.TextVertexInputFormat;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import java.io.IOException;
import java.util.List;
import java.util.regex.Pattern;
/**
 * Reads label propagation vertices from text lines of the form
 * {@code <vertexId> <neighbourId> <neighbourId> ...}, where tokens are
 * separated by one or more tabs or spaces. Vertex and neighbour ids are
 * parsed as {@code long}.
 */
public class LPTextVertexInputFormat extends TextVertexInputFormat<LongWritable,
    LPVertexValue, NullWritable> {
  /**
   * Separator of the vertex and neighbors: any run of tabs and/or spaces.
   */
  private static final Pattern SEPARATOR = Pattern.compile("[\t ]+");

  /**
   * {@inheritDoc}
   */
  @Override
  public TextVertexReader createVertexReader(InputSplit split,
      TaskAttemptContext context) throws IOException {
    return new VertexReader();
  }

  /**
   * Reads one vertex per input line. The vertex value starts with its own id
   * as current community, {@link Long#MAX_VALUE} as last community and a
   * stabilization round of 0.
   */
  public class VertexReader extends
      TextVertexReaderFromEachLineProcessed<String[]> {
    /**
     * Vertex id for the current line. It is a long so that it matches the
     * LongWritable vertex ids and the long-parsed neighbour ids.
     */
    private long id;
    /**
     * Initial vertex last community.
     */
    private long lastCommunity = Long.MAX_VALUE;
    /**
     * Initial vertex current community. This will be set to the vertex id.
     */
    private long currentCommunity;
    /**
     * Vertex stabilization round.
     */
    private long stabilizationRound = 0;

    /**
     * {@inheritDoc}
     */
    @Override
    protected String[] preprocessLine(Text line) throws IOException {
      // Trim first so that leading/trailing whitespace (including a trailing
      // '\r' from CRLF files) does not yield empty or unparsable tokens.
      String[] tokens = SEPARATOR.split(line.toString().trim());
      id = Long.parseLong(tokens[0]);
      currentCommunity = id;
      return tokens;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    protected LongWritable getId(String[] tokens) throws IOException {
      return new LongWritable(id);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    protected LPVertexValue getValue(String[] tokens) throws IOException {
      return new LPVertexValue(currentCommunity, lastCommunity,
          stabilizationRound);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    protected Iterable<Edge<LongWritable, NullWritable>> getEdges(
        String[] tokens) throws IOException {
      List<Edge<LongWritable, NullWritable>> edges = Lists.newArrayList();
      for (int n = 1; n < tokens.length; n++) {
        if (tokens[n].isEmpty()) {
          continue;
        }
        edges.add(
            EdgeFactory.create(new LongWritable(Long.parseLong(tokens[n]))));
      }
      if (edges.isEmpty()) {
        // A vertex without neighbours gets a self-loop, presumably so that it
        // still has an outgoing edge (to itself) to message.
        edges.add(EdgeFactory.create(new LongWritable(id)));
      }
      return edges;
    }
  }
}
package org.apache.giraph.examples.lp;
import org.apache.giraph.graph.Vertex;
import org.apache.giraph.io.formats.TextVertexOutputFormat;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import java.io.IOException;
/**
 * Writes each vertex as one text line {@code <vertexId> <currentCommunity>}.
 */
public class LPTextVertexOutputFormat extends
    TextVertexOutputFormat<LongWritable, LPVertexValue, NullWritable> {
  /**
   * Separator written between the vertex id and its community.
   */
  private static final String VALUE_TOKEN_SEPARATOR = " ";

  /**
   * {@inheritDoc}
   */
  @Override
  public TextVertexWriter createVertexWriter(TaskAttemptContext context)
      throws IOException, InterruptedException {
    return new LabelPropagationTextVertexLineWriter();
  }

  /**
   * Used to convert a {@link LPVertexValue} to a line in the output file.
   */
  private class LabelPropagationTextVertexLineWriter extends
      TextVertexWriterToEachLine {
    /**
     * {@inheritDoc}
     */
    @Override
    protected Text convertVertexToLine(
        Vertex<LongWritable, LPVertexValue, NullWritable> vertex)
        throws IOException {
      // Only the id and the community are written, with no trailing separator
      // (edges are intentionally not part of the output).
      StringBuilder sb = new StringBuilder(vertex.getId().toString());
      sb.append(VALUE_TOKEN_SEPARATOR);
      sb.append(vertex.getValue().getCurrentCommunity().get());
      return new Text(sb.toString());
    }
  }
}
package org.apache.giraph.examples.lp;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
/**
 * Vertex value for label propagation: the vertex's current community, the
 * community it held before its most recent change, and a stabilization
 * counter. All three fields are serialized by {@link #write} and restored by
 * {@link #readFields}.
 */
public class LPVertexValue implements Writable {
  /**
   * The community the vertex currently belongs to.
   */
  private long currentCommunity;
  /**
   * The community the vertex belonged to before its most recent change.
   */
  private long lastCommunity;
  /**
   * Stabilization counter maintained by the computation.
   */
  private long stabilizationRounds;

  /**
   * Creates a value with all fields set to 0, e.g. to be filled in by
   * {@link #readFields}.
   */
  public LPVertexValue() {
  }

  /**
   * Creates a value with the given fields.
   *
   * @param currentCommunity    the current community
   * @param lastCommunity       the previous community
   * @param stabilizationRounds the initial stabilization counter
   */
  public LPVertexValue(long currentCommunity, long lastCommunity,
      long stabilizationRounds) {
    this.currentCommunity = currentCommunity;
    this.lastCommunity = lastCommunity;
    this.stabilizationRounds = stabilizationRounds;
  }

  /**
   * Sets the previous community.
   *
   * @param lastCommunity the previous community
   */
  public void setLastCommunity(LongWritable lastCommunity) {
    this.lastCommunity = lastCommunity.get();
  }

  /**
   * Sets the current community.
   *
   * @param currentCommunity the new current community
   */
  public void setCurrentCommunity(LongWritable currentCommunity) {
    this.currentCommunity = currentCommunity.get();
  }

  /**
   * Sets the stabilization counter.
   *
   * @param stabilizationRounds the new counter value
   */
  public void setStabilizationRounds(long stabilizationRounds) {
    this.stabilizationRounds = stabilizationRounds;
  }

  /**
   * Returns the current community as a new writable.
   *
   * @return the current community
   */
  public LongWritable getCurrentCommunity() {
    return new LongWritable(this.currentCommunity);
  }

  /**
   * Returns the previous community as a new writable.
   *
   * @return the previous community
   */
  public LongWritable getLastCommunity() {
    return new LongWritable(this.lastCommunity);
  }

  /**
   * Method to get the stabilization round counter
   *
   * @return the actual counter
   */
  public long getStabilizationRounds() {
    return stabilizationRounds;
  }

  /**
   * Serializes all three fields. The stabilization counter must be included,
   * otherwise it would be lost, or left stale in a reused instance, whenever
   * the value is serialized and read back.
   */
  @Override
  public void write(DataOutput dataOutput) throws IOException {
    dataOutput.writeLong(this.currentCommunity);
    dataOutput.writeLong(this.lastCommunity);
    dataOutput.writeLong(this.stabilizationRounds);
  }

  /**
   * Restores all three fields in the order written by {@link #write}.
   */
  @Override
  public void readFields(DataInput dataInput) throws IOException {
    this.currentCommunity = dataInput.readLong();
    this.lastCommunity = dataInput.readLong();
    this.stabilizationRounds = dataInput.readLong();
  }
}
Test run: giraph ../giraph-examples-1.1.0.jar org.apache.giraph.examples.lp.LPComputation -vif org.apache.giraph.examples.lp.LPTextVertexInputFormat -vip /test/web-BerkStan_final.txt -vof org.apache.giraph.examples.lp.LPTextVertexOutputFormat -op /output -w 6