When writing Hadoop MapReduce jobs, you often need to implement a custom WritableComparable with multiple fields. The following example shows a complete implementation, including a raw-bytes comparator, and will hopefully be a useful starting point for other developers.
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.io.WritableUtils;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

public class StatWritable implements WritableComparable<StatWritable> {

    private long timestamp;
    private int systemId;
    private String group;
    private String item;

    public String getGroup() {
        return group;
    }

    public void setGroup(final String group) {
        this.group = group;
    }

    public String getItem() {
        return item;
    }

    public void setItem(final String item) {
        this.item = item;
    }

    public int getSystemId() {
        return systemId;
    }

    public void setSystemId(final int systemId) {
        this.systemId = systemId;
    }

    public long getTimestamp() {
        return timestamp;
    }

    public void setTimestamp(final long timestamp) {
        this.timestamp = timestamp;
    }
    @Override
    public int compareTo(final StatWritable o) {
        int cmp = Long.compare(timestamp, o.getTimestamp());
        if (cmp != 0) {
            return cmp;
        }
        // Integer.compare avoids the overflow risk of subtracting two ints.
        cmp = Integer.compare(systemId, o.getSystemId());
        if (cmp != 0) {
            return cmp;
        }
        cmp = group.compareTo(o.getGroup());
        if (cmp != 0) {
            return cmp;
        }
        return item.compareTo(o.getItem());
    }
    /**
     * The order in which fields are written here must match the order in which
     * {@link StatWritable#readFields(java.io.DataInput)} reads them back.
     * Call the write method on {@link java.io.DataOutput} that matches each field's type.
     * For variable-length fields such as strings, write the byte length first,
     * followed by the actual data.
     *
     * @param out destination for the serialized fields
     * @throws IOException if the underlying stream fails
     */
    @Override
    public void write(final DataOutput out) throws IOException {
        out.writeLong(timestamp);
        out.writeInt(systemId);
        // An explicit charset keeps the serialized form stable across JVMs
        // whose platform defaults differ.
        final byte[] groupBytes = group.getBytes(StandardCharsets.UTF_8);
        WritableUtils.writeVInt(out, groupBytes.length);
        out.write(groupBytes, 0, groupBytes.length);
        final byte[] itemBytes = item.getBytes(StandardCharsets.UTF_8);
        WritableUtils.writeVInt(out, itemBytes.length);
        out.write(itemBytes, 0, itemBytes.length);
    }
    /**
     * The order in which fields are read here must match the order in which
     * {@link StatWritable#write(java.io.DataOutput)} wrote them.
     * Call the read method on {@link java.io.DataInput} that matches each field's type.
     * For variable-length fields such as strings, read the byte length first,
     * followed by the actual data.
     *
     * @param in source of the serialized fields
     * @throws IOException if the underlying stream fails
     */
    @Override
    public void readFields(final DataInput in) throws IOException {
        timestamp = in.readLong();
        systemId = in.readInt();
        final int groupLength = WritableUtils.readVInt(in);
        final byte[] groupBytes = new byte[groupLength];
        in.readFully(groupBytes, 0, groupLength);
        group = new String(groupBytes, StandardCharsets.UTF_8);
        final int itemLength = WritableUtils.readVInt(in);
        final byte[] itemBytes = new byte[itemLength];
        in.readFully(itemBytes, 0, itemLength);
        item = new String(itemBytes, StandardCharsets.UTF_8);
    }
    /**
     * Override toString so the key is written in readable form to map or
     * reduce output files.
     *
     * @return a space-separated rendering of the four fields
     */
    @Override
    public String toString() {
        return systemId + " " + timestamp + " " + group + " " + item;
    }
    /**
     * A raw comparator that lets Hadoop compare serialized keys quickly, straight
     * from the bytes, without deserializing them. When overriding
     * {@link com.unionpay.stat.hadoop.domain.StatWritable.Comparator#compare(byte[], int, int, byte[], int, int)},
     * the fields must be compared in the same order in which
     * {@link org.apache.hadoop.io.Writable#readFields(java.io.DataInput)} and
     * {@link org.apache.hadoop.io.Writable#write(java.io.DataOutput)} read and write them.
     */
    public static class Comparator extends WritableComparator {

        protected Comparator() {
            super(StatWritable.class);
        }

        @Override
        public int compare(final byte[] b1, final int s1, final int l1, final byte[] b2, final int s2, final int l2) {
            try {
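                // timestamp: a long in the first 8 bytes of each serialized key.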
                final long timestampL1 = readLong(b1, s1);
                final long timestampL2 = readLong(b2, s2);
                final int cmp1 = Long.compare(timestampL1, timestampL2);
                if (cmp1 != 0) {
                    return cmp1;
                }
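                // systemId: an int in the 4 bytes that follow the timestamp.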
                final int startIndex1_1 = s1 + 8;
                final int startIndex1_2 = s2 + 8;
                final int systemId1 = readInt(b1, startIndex1_1);
                final int systemId2 = readInt(b2, startIndex1_2);
                final int cmp2 = Integer.compare(systemId1, systemId2);
                if (cmp2 != 0) {
                    return cmp2;
                }
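                // group: a VInt length prefix followed by the string bytes. groupLength
                // covers the prefix plus the content, so compareBytes spans both. Note
                // this orders keys by their encoded bytes, which can differ from
                // String.compareTo when strings of different lengths share a prefix.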
                final int startIndex2_1 = startIndex1_1 + 4;
                final int startIndex2_2 = startIndex1_2 + 4;
                final int groupLength1 = WritableUtils.decodeVIntSize(b1[startIndex2_1]) + readVInt(b1, startIndex2_1);
                final int groupLength2 = WritableUtils.decodeVIntSize(b2[startIndex2_2]) + readVInt(b2, startIndex2_2);
                final int cmp3 = compareBytes(b1, startIndex2_1, groupLength1, b2, startIndex2_2, groupLength2);
                if (cmp3 != 0) {
                    return cmp3;
                }
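                // item: the remaining bytes, again a VInt length prefix plus content.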
                final int startIndex3_1 = startIndex2_1 + groupLength1;
                final int startIndex3_2 = startIndex2_2 + groupLength2;
                final int itemLength1 = WritableUtils.decodeVIntSize(b1[startIndex3_1]) + readVInt(b1, startIndex3_1);
                final int itemLength2 = WritableUtils.decodeVIntSize(b2[startIndex3_2]) + readVInt(b2, startIndex3_2);
                return compareBytes(b1, startIndex3_1, itemLength1, b2, startIndex3_2, itemLength2);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    }
    /**
     * Register the raw comparator with Hadoop so it is picked up automatically
     * whenever StatWritable keys are compared.
     */
    static {
        WritableComparator.define(StatWritable.class, new Comparator());
    }
}
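To see the pieces working together, here is a standalone sanity check that round-trips a key through write() and readFields() and compares two serialized keys with the registered raw comparator, the same way Hadoop's sort phase would. It is a minimal sketch, not part of the class above: the class name RoundTripCheck and the sample field values are invented for illustration, and it assumes StatWritable and the Hadoop client libraries are on the classpath.

import org.apache.hadoop.io.WritableComparator;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class RoundTripCheck {

    // Serialize a StatWritable into a byte array via its write() method.
    private static byte[] serialize(final StatWritable w) throws IOException {
        final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        w.write(new DataOutputStream(buffer));
        return buffer.toByteArray();
    }

    public static void main(final String[] args) throws IOException {
        final StatWritable first = new StatWritable();
        first.setTimestamp(1400000000000L);
        first.setSystemId(1);
        first.setGroup("payments");
        first.setItem("latency");

        final StatWritable second = new StatWritable();
        second.setTimestamp(1400000000000L);
        second.setSystemId(2);
        second.setGroup("payments");
        second.setItem("latency");

        // Round trip: readFields() must reconstruct exactly what write() produced.
        final byte[] firstBytes = serialize(first);
        final StatWritable copy = new StatWritable();
        copy.readFields(new DataInputStream(new ByteArrayInputStream(firstBytes)));
        System.out.println("round trip: " + copy);

        // Compare the serialized forms with the comparator registered in the
        // static block of StatWritable.
        final byte[] secondBytes = serialize(second);
        final int rawCmp = WritableComparator.get(StatWritable.class)
                .compare(firstBytes, 0, firstBytes.length, secondBytes, 0, secondBytes.length);
        System.out.println("compareTo: " + first.compareTo(second) + ", raw bytes: " + rawCmp);
    }
}

For a pair like this one, where the timestamps are equal and the systemIds differ, compareTo and the raw byte comparison produce the same sign, which is what a correct raw comparator should guarantee.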