Label Propagation算法Giraph实现

package org.apache.giraph.examples.lp;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.giraph.comm.WorkerClientRequestProcessor;
import org.apache.giraph.graph.BasicComputation;
import org.apache.giraph.graph.GraphState;
import org.apache.giraph.graph.GraphTaskManager;
import org.apache.giraph.graph.Vertex;
import org.apache.giraph.worker.WorkerContext;
import org.apache.giraph.worker.WorkerGlobalCommUsage;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;

/**
 * Label propagation computation for community detection.
 *
 * <p>In superstep 0 every vertex sends its current community to all of its
 * edges. In each later superstep a vertex derives a candidate community from
 * the received messages, adopts it if it differs from its current community
 * and the vertex is not yet considered stable, and only then notifies its
 * neighbours again. Every vertex votes to halt at the end of each superstep.
 */
public class LPComputation extends BasicComputation<LongWritable, LPVertexValue, NullWritable, LongWritable> {

  /**
   * Configuration key for the last superstep in which labels are still
   * updated. Takes precedence over {@link #NUMBER_OF_ITERATIONS}.
   */
  public static final String SUPERSTEP_COUNT = "giraph.lpComputation.superstepCount";
  /**
   * Alternative configuration key for the iteration limit; consulted only
   * when {@link #SUPERSTEP_COUNT} is not set.
   */
  public static final String NUMBER_OF_ITERATIONS =
    "labelpropagation.numberofiterations";
  /**
   * Iteration limit used when neither configuration key is set.
   */
  public static final int DEFAULT_NUMBER_OF_ITERATIONS = 50;
  /**
   * Configuration key for the stabilization threshold: once a vertex's
   * stabilization counter exceeds this value it no longer changes community.
   */
  public static final String STABILISATION_ROUNDS =
    "labelpropagation.stabilizationrounds";
  /**
   * Default stabilization threshold.
   */
  public static final long DEFAULT_NUMBER_OF_STABILIZATION_ROUNDS = 20;

  /**
   * Stabilization threshold read from the configuration in
   * {@link #initialize}.
   */
  private long stabilizationRounds;
  /**
   * Highest superstep in which labels are still updated, read from the
   * configuration in {@link #initialize}.
   */
  private int maxSupersteps;

  /**
   * Derives the candidate community for a vertex from its incoming messages.
   *
   * <ul>
   * <li>no message: the vertex's current community</li>
   * <li>one message: the minimum of the message and the current community</li>
   * <li>more than one message: the result of {@link #getMostFrequent}</li>
   * </ul>
   *
   * @param vertex   The current vertex
   * @param messages All incoming messages
   * @return the candidate community for the vertex
   */
  private long getNewCommunity(
    Vertex<LongWritable, LPVertexValue, NullWritable> vertex,
    Iterable<LongWritable> messages) {
    // Copy the primitive values out of the iterable so they can be sorted.
    List<Long> allMessages = new ArrayList<>();
    for (LongWritable message : messages) {
      allMessages.add(message.get());
    }
    long currentCommunity = vertex.getValue().getCurrentCommunity().get();
    if (allMessages.isEmpty()) {
      return currentCommunity;
    }
    if (allMessages.size() == 1) {
      return Math.min(currentCommunity, allMessages.get(0));
    }
    return getMostFrequent(vertex, allMessages);
  }

  /**
   * Returns the most frequent value among the received messages. Each message
   * value is a community id. Among equally frequent values the smallest one
   * wins, because the list is sorted ascending and only a strictly larger
   * count replaces the current best.
   *
   * <p>If every value occurs exactly once, the minimum of the vertex's current
   * community and the smallest message is returned instead.
   *
   * @param vertex      The current vertex
   * @param allMessages All messages the current vertex has received; must be
   *                    non-empty, and is sorted in place
   * @return the most frequent community id, or the fallback minimum
   */
  private long getMostFrequent(
    Vertex<LongWritable, LPVertexValue, NullWritable> vertex,
    List<Long> allMessages) {
    Collections.sort(allMessages);
    long currentValue = allMessages.get(0);
    int currentCounter = 1;
    // maxValue is only read when maxCounter ends up above 1.
    long maxValue = currentValue;
    int maxCounter = 1;
    // Scan runs of equal values in the sorted list and keep the longest run.
    for (int i = 1; i < allMessages.size(); i++) {
      long message = allMessages.get(i);
      if (message == currentValue) {
        currentCounter++;
        if (maxCounter < currentCounter) {
          maxCounter = currentCounter;
          maxValue = currentValue;
        }
      } else {
        currentCounter = 1;
        currentValue = message;
      }
    }
    if (maxCounter == 1) {
      // All messages are distinct: to avoid an oscillating state use the
      // smaller of the current community and the smallest message.
      return Math.min(vertex.getValue().getCurrentCommunity().get(),
        allMessages.get(0));
    }
    return maxValue;
  }

  /**
   * Delegates to the parent initialization and then caches the stabilization
   * threshold and the superstep limit from the configuration.
   *
   * <p>The superstep limit is resolved from {@link #SUPERSTEP_COUNT}, then
   * {@link #NUMBER_OF_ITERATIONS}, then {@link #DEFAULT_NUMBER_OF_ITERATIONS}.
   * Without this fallback an unset {@link #SUPERSTEP_COUNT} would resolve to 0
   * and every vertex would stop before processing its first messages.
   */
  @Override
  public void initialize(GraphState graphState,
    WorkerClientRequestProcessor<LongWritable, LPVertexValue, NullWritable>
      workerClientRequestProcessor,
    GraphTaskManager<LongWritable, LPVertexValue, NullWritable>
      graphTaskManager,
    WorkerGlobalCommUsage workerGlobalCommUsage, WorkerContext workerContext) {
    super.initialize(graphState, workerClientRequestProcessor, graphTaskManager,
      workerGlobalCommUsage, workerContext);
    this.stabilizationRounds = getConf()
      .getLong(STABILISATION_ROUNDS, DEFAULT_NUMBER_OF_STABILIZATION_ROUNDS);
    int fallbackLimit = getConf()
      .getInt(NUMBER_OF_ITERATIONS, DEFAULT_NUMBER_OF_ITERATIONS);
    this.maxSupersteps = getConf().getInt(SUPERSTEP_COUNT, fallbackLimit);
  }

  /**
   * The actual label propagation step for one vertex.
   *
   * @param vertex   Vertex
   * @param messages Messages that were sent to this vertex in the previous
   *                 superstep.
   * @throws IOException declared by the overridden method
   */
  @Override
  public void compute(Vertex<LongWritable, LPVertexValue, NullWritable> vertex,
    Iterable<LongWritable> messages) throws IOException {
    if (getSuperstep() == 0) {
      // Announce the initial community to every neighbour.
      sendMessageToAllEdges(vertex, vertex.getValue().getCurrentCommunity());
    } else if (getSuperstep() <= maxSupersteps) {
      long currentCommunity = vertex.getValue().getCurrentCommunity().get();
      long lastCommunity = vertex.getValue().getLastCommunity().get();
      long newCommunity = getNewCommunity(vertex, messages);
      long currentStabilizationRound =
        vertex.getValue().getStabilizationRounds();

      // Count the rounds in which the candidate equals the community the
      // vertex held before its last change.
      if (lastCommunity == newCommunity) {
        currentStabilizationRound++;
        vertex.getValue().setStabilizationRounds(currentStabilizationRound);
      }

      // Once the counter exceeds the threshold the vertex is treated as
      // stable and keeps its community.
      boolean isUnstable = currentStabilizationRound <= stabilizationRounds;
      boolean mayChange = currentCommunity != newCommunity;
      if (mayChange && isUnstable) {
        vertex.getValue().setLastCommunity(new LongWritable(currentCommunity));
        vertex.getValue().setCurrentCommunity(new LongWritable(newCommunity));
        // reset stabilization counter
        vertex.getValue().setStabilizationRounds(0);
        sendMessageToAllEdges(vertex, vertex.getValue().getCurrentCommunity());
      }
    }
    // Past the superstep limit nothing is updated; the vertex only halts.
    vertex.voteToHalt();
  }

}

package org.apache.giraph.examples.lp;

import com.google.common.collect.Lists;
import org.apache.giraph.edge.Edge;
import org.apache.giraph.edge.EdgeFactory;
import org.apache.giraph.io.formats.TextVertexInputFormat;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import java.io.IOException;
import java.util.List;
import java.util.regex.Pattern;

/**
 * Text input format for label propagation. Each input line holds a vertex id
 * followed by the ids of its neighbours, separated by single tabs or spaces:
 * {@code <vertexId> <neighbourId> <neighbourId> ...}.
 */
public class LPTextVertexInputFormat extends TextVertexInputFormat<LongWritable,
  LPVertexValue, NullWritable> {
  /**
   * Separator of the vertex and neighbors
   */
  private static final Pattern SEPARATOR = Pattern.compile("[\t ]");

  /**
   * {@inheritDoc}
   */
  @Override
  public TextVertexReader createVertexReader(InputSplit split,
    TaskAttemptContext context) throws IOException {
    return new VertexReader();
  }

  /**
   * Reads one vertex, its initial value and its edges from an input line.
   */
  public class VertexReader extends
    TextVertexReaderFromEachLineProcessed<String[]> {
    /**
     * Vertex id for the current line. Held as {@code long} because vertex ids
     * are {@link LongWritable} and neighbour ids are parsed as {@code long}.
     */
    private long id;
    /**
     * Initial vertex last community.
     */
    private final long lastCommunity = Long.MAX_VALUE;
    /**
     * Initial vertex current community. This will be set to the vertex id.
     */
    private long currentCommunity;
    /**
     * Initial vertex stabilization round counter.
     */
    private final long stabilizationRound = 0;

    /**
     * Splits the line into tokens and records the vertex id (first token),
     * which also becomes the vertex's initial community.
     *
     * @param line one line of the input
     * @return the tokens of the line; index 0 is the vertex id
     * @throws NumberFormatException if the first token is not a valid long
     */
    @Override
    protected String[] preprocessLine(Text line) throws IOException {
      // Trim first: leading whitespace would otherwise produce an empty
      // first token that cannot be parsed as an id.
      String[] tokens = SEPARATOR.split(line.toString().trim());
      id = Long.parseLong(tokens[0]);
      currentCommunity = id;
      return tokens;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    protected LongWritable getId(String[] tokens) throws IOException {
      return new LongWritable(id);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    protected LPVertexValue getValue(String[] tokens) throws IOException {
      return new LPVertexValue(currentCommunity, lastCommunity,
        stabilizationRound);
    }

    /**
     * Builds one edge per non-empty neighbour token. A vertex without any
     * neighbour gets a single edge to itself.
     *
     * @param tokens the tokens of the line; neighbours start at index 1
     * @return the edges of the vertex, never empty
     * @throws NumberFormatException if a neighbour token is not a valid long
     */
    @Override
    protected Iterable<Edge<LongWritable, NullWritable>> getEdges(
      String[] tokens) throws IOException {
      List<Edge<LongWritable, NullWritable>> edges = Lists.newArrayList();
      for (int n = 1; n < tokens.length; n++) {
        // Consecutive separators produce empty tokens; skip them.
        if (tokens[n].isEmpty()) {
          continue;
        }
        edges
          .add(EdgeFactory.create(new LongWritable(Long.parseLong(tokens[n]))));
      }
      if (edges.isEmpty()) {
        edges.add(EdgeFactory.create(new LongWritable(id)));
      }
      return edges;
    }
  }
}

package org.apache.giraph.examples.lp;

import org.apache.giraph.graph.Vertex;
import org.apache.giraph.io.formats.TextVertexOutputFormat;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import java.io.IOException;

/**
 * Text output format for label propagation results. Each vertex becomes one
 * line holding its id and its current community.
 */
public class LPTextVertexOutputFormat extends
		TextVertexOutputFormat<LongWritable, LPVertexValue, NullWritable> {
	/** Placed after every value on an output line, including the last one. */
	private static final String VALUE_TOKEN_SEPARATOR = " ";

	/**
	 * {@inheritDoc}
	 */
	@Override
	public TextVertexWriter createVertexWriter(TaskAttemptContext context)
			throws IOException, InterruptedException {
		return new LabelPropagationTextVertexLineWriter();
	}

	/**
	 * Turns a vertex and its {@link LPVertexValue} into one output line.
	 */
	private class LabelPropagationTextVertexLineWriter extends
			TextVertexWriterToEachLine {
		/**
		 * Formats the vertex as its id, a separator, its current community
		 * and a trailing separator. Edges are not written.
		 *
		 * @param vertex the vertex to format
		 * @return the output line for the vertex
		 */
		@Override
		protected Text convertVertexToLine(
				Vertex<LongWritable, LPVertexValue, NullWritable> vertex)
				throws IOException {
			String vertexId = vertex.getId().toString();
			long community = vertex.getValue().getCurrentCommunity().get();
			String line = vertexId + VALUE_TOKEN_SEPARATOR
					+ community + VALUE_TOKEN_SEPARATOR;
			return new Text(line);
		}
	}
}

package org.apache.giraph.examples.lp;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;

/**
 * Vertex value for label propagation: the community a vertex currently
 * belongs to, the community it held before its last change, and a counter of
 * stabilization rounds.
 */
public class LPVertexValue implements Writable {
	/**
	 * The community the vertex currently belongs to.
	 */
	private long currentCommunity;
	/**
	 * The community the vertex held before its last change.
	 */
	private long lastCommunity;
	/**
	 * Stabilization round counter.
	 */
	private long stabilizationRounds;

	/**
	 * No-argument constructor; all fields start at 0 until
	 * {@link #readFields(DataInput)} or the setters are called.
	 */
	public LPVertexValue() {
	}

	/**
	 * Creates a fully initialized value.
	 *
	 * @param currentCommunity    the current community
	 * @param lastCommunity       the previous community
	 * @param stabilizationRounds the initial stabilization round counter
	 */
	public LPVertexValue(long currentCommunity, long lastCommunity,
			long stabilizationRounds) {
		this.currentCommunity = currentCommunity;
		this.lastCommunity = lastCommunity;
		this.stabilizationRounds = stabilizationRounds;
	}

	/**
	 * @param lastCommunity the previous community; its value is copied
	 */
	public void setLastCommunity(LongWritable lastCommunity) {
		this.lastCommunity = lastCommunity.get();
	}

	/**
	 * @param currentCommunity the current community; its value is copied
	 */
	public void setCurrentCommunity(LongWritable currentCommunity) {
		this.currentCommunity = currentCommunity.get();
	}

	/**
	 * @param stabilizationRounds the new stabilization round counter
	 */
	public void setStabilizationRounds(long stabilizationRounds) {
		this.stabilizationRounds = stabilizationRounds;
	}

	/**
	 * @return a new {@link LongWritable} holding the current community
	 */
	public LongWritable getCurrentCommunity() {
		return new LongWritable(this.currentCommunity);
	}

	/**
	 * @return a new {@link LongWritable} holding the previous community
	 */
	public LongWritable getLastCommunity() {
		return new LongWritable(this.lastCommunity);
	}

	/**
	 * Method to get the stabilization round counter
	 * 
	 * @return the actual counter
	 */
	public long getStabilizationRounds() {
		return stabilizationRounds;
	}

	/**
	 * Serializes all three fields as longs, in the order current community,
	 * last community, stabilization rounds. The counter is included so that
	 * it is not reset to 0 when the value is serialized and read back.
	 *
	 * @param dataOutput the sink to write to
	 * @throws IOException if writing fails
	 */
	@Override
	public void write(DataOutput dataOutput) throws IOException {
		dataOutput.writeLong(this.currentCommunity);
		dataOutput.writeLong(this.lastCommunity);
		dataOutput.writeLong(this.stabilizationRounds);
	}

	/**
	 * Restores all three fields in the order written by
	 * {@link #write(DataOutput)}.
	 *
	 * @param dataInput the source to read from
	 * @throws IOException if reading fails
	 */
	@Override
	public void readFields(DataInput dataInput) throws IOException {
		this.currentCommunity = dataInput.readLong();
		this.lastCommunity = dataInput.readLong();
		this.stabilizationRounds = dataInput.readLong();
	}

}

测试: giraph ../giraph-examples-1.1.0.jar org.apache.giraph.examples.lp.LPComputation -vif org.apache.giraph.examples.lp.LPTextVertexInputFormat -vip /test/web-BerkStan_final.txt -vof org.apache.giraph.examples.lp.LPTextVertexOutputFormat -op /output -w 6

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值