// Shared imports for the two public classes below; each class belongs in its own .java file.
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.FileSplit;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.LineRecordReader;
import org.apache.hadoop.mapred.Partitioner;
import org.apache.hadoop.mapred.RecordReader;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapred.TextOutputFormat;
import org.apache.hadoop.util.IndexedSortable;
import org.apache.hadoop.util.QuickSort;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

public class SamplerInputFormat extends FileInputFormat<Text, Text> {
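// PARTITION_FILENAME: split-point file shared with every task via the DistributedCache.
// SAMPLE_SIZE: config key for the total number of keys to sample.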
static final String PARTITION_FILENAME = "_partition.lst";
static final String SAMPLE_SIZE = "terasort.partitions.sample";
private static JobConf lastConf = null;
private static InputSplit[] lastResult = null;
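// Collects sampled keys in memory, sorts them, and derives evenly spaced split points.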
static class TextSampler implements IndexedSortable {
public ArrayList<Text> records = new ArrayList<Text>();
@Override
public int compare(int arg0, int arg1) {
Text left = records.get(arg0);
Text right = records.get(arg1);
return left.compareTo(right);
}
@Override
public void swap(int arg0, int arg1) {
Text tmp = records.get(arg0);
records.set(arg0, records.get(arg1));
records.set(arg1, tmp);
}
public void addKey(Text key) {
records.add(new Text(key));
}
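// Sort the sampled keys and pick numPartitions - 1 evenly spaced keys as split points.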
public Text[] createPartitions(int numPartitions) {
int numRecords = records.size();
if (numPartitions > numRecords) {
throw new IllegalArgumentException("Requested more partitions than input keys (" + numPartitions +
" > " + numRecords + ")");
}
new QuickSort().sort(this, 0, records.size());
float stepSize = numRecords / (float) numPartitions;
Text[] result = new Text[numPartitions - 1];
for (int i = 1; i < numPartitions; ++i) {
result[i - 1] = records.get(Math.round(stepSize * i));
}
return result;
}
}
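// Samples roughly sampleSize keys from at most 10 of the input splits and writes the
// resulting split points to partFile as a SequenceFile of (Text, NullWritable).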
public static void writePartitionFile(JobConf conf, Path partFile) throws IOException {
SamplerInputFormat inputFormat = new SamplerInputFormat();
TextSampler sampler = new TextSampler();
Text key = new Text();
Text value = new Text();
int partitions = conf.getNumReduceTasks(); // number of reduce tasks = number of partitions
long sampleSize = conf.getLong(SAMPLE_SIZE, 100); // total number of key/value pairs to sample
InputSplit[] splits = inputFormat.getSplits(conf, conf.getNumMapTasks()); // compute the input splits
int samples = Math.min(10, splits.length); // number of splits to sample from
long recordsPerSample = sampleSize / samples; // keys to sample from each chosen split
int sampleStep = splits.length / samples; // step between sampled splits
long records = 0;
for (int i = 0; i < samples; i++) {
RecordReader<Text, Text> reader = inputFormat.getRecordReader(splits[sampleStep * i], conf, null);
while (reader.next(key, value)) {
sampler.addKey(key);
records += 1;
// stop once the cumulative sample reaches the quota for splits 0..i
if ((i + 1) * recordsPerSample <= records) {
break;
}
}
reader.close();
}
FileSystem outFs = partFile.getFileSystem(conf);
if (outFs.exists(partFile)) {
outFs.delete(partFile, false);
}
SequenceFile.Writer writer = SequenceFile.createWriter(outFs, conf, partFile, Text.class, NullWritable.class);
NullWritable nullValue = NullWritable.get();
for (Text splitPoint : sampler.createPartitions(partitions)) {
writer.append(splitPoint, nullValue);
}
writer.close();
}
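// RecordReader that wraps LineRecordReader: each input line becomes the key and the
// value is left empty (see next() for a fixed-width key/value alternative).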
static class TeraRecordReader implements RecordReader<Text, Text> {
private LineRecordReader in;
private LongWritable junk = new LongWritable();
private Text line = new Text();
private static int KEY_LENGTH = 10;
public TeraRecordReader(Configuration job, FileSplit split) throws IOException {
in = new LineRecordReader(job, split);
}
@Override
public void close() throws IOException {
in.close();
}
@Override
public Text createKey() {
return new Text();
}
@Override
public Text createValue() {
return new Text();
}
@Override
public long getPos() throws IOException {
return in.getPos();
}
@Override
public float getProgress() throws IOException {
return in.getProgress();
}
@Override
public boolean next(Text arg0, Text arg1) throws IOException {
if (in.next(junk, line)) {
// Use the whole line as the key and leave the value empty.
arg0.set(line);
arg1.clear();
// Alternative for fixed-width records: treat the first KEY_LENGTH bytes
// as the key and the remaining bytes as the value.
// if (line.getLength() < KEY_LENGTH) {
//     arg0.set(line);
//     arg1.clear();
// } else {
//     byte[] bytes = line.getBytes();
//     arg0.set(bytes, 0, KEY_LENGTH);
//     arg1.set(bytes, KEY_LENGTH, line.getLength() - KEY_LENGTH);
// }
return true;
} else {
return false;
}
}
}
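// Cache the computed splits so that writePartitionFile() and the subsequent job
// submission reuse the same split list for the same JobConf.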
@Override
public InputSplit[] getSplits(JobConf conf, int splits) throws IOException {
if (conf == lastConf) {
return lastResult;
}
lastConf = conf;
lastResult = super.getSplits(lastConf, splits);
return lastResult;
}
@Override
public RecordReader<Text, Text> getRecordReader(InputSplit arg0, JobConf arg1,
Reporter arg2) throws IOException {
return new TeraRecordReader(arg1, (FileSplit) arg0);
}
}
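// Job driver: samples the input, writes and distributes the partition file, and
// submits a totally ordered sort that uses the custom partitioner below.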
public class SamplerSort extends Configured implements Tool {
// Custom partitioner: routes each key to a reducer using the sampled split points
public static class TotalOrderPartitioner implements Partitioner<Text, Text> {
private Text[] splitPoints;
public TotalOrderPartitioner() {
}
@Override
public int getPartition(Text arg0, Text arg1, int arg2) {
return findPartition(arg0);
}
@Override
public void configure(JobConf arg0) {
try {
FileSystem fs = FileSystem.getLocal(arg0);
Path partFile = new Path(SamplerInputFormat.PARTITION_FILENAME);
splitPoints = readPartitions(fs, partFile, arg0); // read the sampled split points
} catch (IOException ie) {
throw new IllegalArgumentException("can't read partitions file", ie);
}
}
public int findPartition(Text key) { // assign the key to one of the reduce partitions
// Keys <= splitPoints[i] go to partition i; keys greater than every
// split point go to the last partition.
for (int i = 0; i < splitPoints.length; i++) {
if (key.compareTo(splitPoints[i]) <= 0) {
return i;
}
}
return splitPoints.length;
}
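// Load the split points written by SamplerInputFormat.writePartitionFile().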
private static Text[] readPartitions(FileSystem fs, Path p, JobConf job) throws IOException {
SequenceFile.Reader reader = new SequenceFile.Reader(fs, p, job);
List<Text> parts = new ArrayList<Text>();
Text key = new Text();
NullWritable value = NullWritable.get();
while (reader.next(key, value)) {
// copy the key: reader.next() reuses the same Text instance on every call
parts.add(new Text(key));
}
reader.close();
return parts.toArray(new Text[parts.size()]);
}
}
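// Configures and submits the job. Sampling runs on the client before submission,
// so the partition file already exists when the tasks start.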
@Override
public int run(String[] args) throws Exception {
JobConf job = (JobConf) getConf();
Path inputDir = new Path(args[0]);
inputDir = inputDir.makeQualified(inputDir.getFileSystem(job));
Path partitionFile = new Path(inputDir, SamplerInputFormat.PARTITION_FILENAME);
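// The "#" fragment is the symlink name that DistributedCache.createSymlink()
// exposes in each task's working directory; configure() opens the file by that name.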
URI partitionUri = new URI(partitionFile.toString() +
"#" + SamplerInputFormat.PARTITION_FILENAME);
SamplerInputFormat.setInputPaths(job, new Path(args[0]));
FileOutputFormat.setOutputPath(job, new Path(args[1]));
job.setJobName("SamplerTotalSort");
job.setJarByClass(SamplerSort.class);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(Text.class);
job.setInputFormat(SamplerInputFormat.class);
job.setOutputFormat(TextOutputFormat.class);
job.setPartitionerClass(TotalOrderPartitioner.class);
job.setNumReduceTasks(4);
SamplerInputFormat.writePartitionFile(job, partitionFile); // sample the input and write the partition file
DistributedCache.addCacheFile(partitionUri, job); // ship the partition file to every task
DistributedCache.createSymlink(job);
JobClient.runJob(job);
return 0;
}
public static void main(String[] args) throws Exception {
int res = ToolRunner.run(new JobConf(), new SamplerSort(), args);
System.exit(res);
}
}