I recently ran into an interview question: given a 100 GB file whose content is nothing but words, how would you use the MapReduce idea to compute a WordCount on a single laptop?
Essentially, the interviewer wanted an approach that simulates MapReduce with multiple threads. I had read some of the MapReduce source code before, so I answered along the lines of that source, and it seemed to go over well. Afterwards I went home and tried to write the code.
1) First, we need to generate some test data. I generated about 1 GB for testing; the code is as follows:
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
/**
 * @author xu
 * @desc Generates mock word data, about 1 GB for testing. Note that the resulting
 *       file size is approximate, not exact.
 */
public class GenWordData {
private static final int TARGET_FILE_SIZE = 1024 * 1024 * 1024; // 1G
private static final int BUFFER_SIZE = 1024 * 1024 * 10; // 10MB
private static final File file = new File("data/words.txt");
private static final BlockingQueue<String> QUEUE = new ArrayBlockingQueue<>(10);
private static final String[] words = new String[]{"hive", "spark", "flink", "clickhouse", "doris", "hadoop", "redis", "kafka"};
public static void main(String[] args) throws IOException {
long begin = System.currentTimeMillis();
if (!file.exists()) {
file.getParentFile().mkdirs();
file.createNewFile();
}
// Producer: builds chunks of random words and hands them to the queue.
CompletableFuture<Void> producerTask = CompletableFuture.runAsync(() -> {
try {
generateData();
} catch (InterruptedException e) {
e.printStackTrace();
}
});
// Consumer: takes chunks from the queue and appends them to the file.
CompletableFuture<Void> consumerTask = CompletableFuture.runAsync(() -> {
try {
writeToFile();
} catch (IOException | InterruptedException e) {
e.printStackTrace();
}
});
CompletableFuture.allOf(producerTask, consumerTask).join();
System.out.println(System.currentTimeMillis() - begin);
}
private static void generateData() throws InterruptedException {
Random r = new Random();
// Keep producing until the file on disk reaches the target size.
while (file.length() < TARGET_FILE_SIZE) {
StringBuilder builder = new StringBuilder();
// Build a ~10 MB chunk of random, space-separated words, ending with a newline.
while (builder.length() < BUFFER_SIZE) {
builder.append(words[r.nextInt(words.length)]).append(" ");
}
builder.append("\n");
QUEUE.put(builder.toString());
}
}
private static void writeToFile() throws IOException, InterruptedException {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(file, true))) {
// Drain chunks from the queue and append them until the file reaches the target size.
while (file.length() < TARGET_FILE_SIZE) {
String content = QUEUE.take();
writer.write(content);
}
}
}
}
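One subtle point about this producer/consumer handshake: both threads decide when to stop by polling file.length(). If the file reaches the target size while the producer is still blocked on QUEUE.put() against a full queue, the program can hang. A common fix is to have the producer count the bytes it has handed off and finish with a poison-pill sentinel. Here is a minimal sketch of that variation (the POISON constant and the produced counter are my additions, not part of the original answer):

private static final String POISON = "";

private static void generateData() throws InterruptedException {
    Random r = new Random();
    long produced = 0; // bytes handed to the queue (chars == bytes for this ASCII data)
    while (produced < TARGET_FILE_SIZE) {
        StringBuilder builder = new StringBuilder();
        while (builder.length() < BUFFER_SIZE) {
            builder.append(words[r.nextInt(words.length)]).append(" ");
        }
        builder.append("\n");
        produced += builder.length();
        QUEUE.put(builder.toString());
    }
    QUEUE.put(POISON); // tell the writer that nothing more is coming
}

private static void writeToFile() throws IOException, InterruptedException {
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(file, true))) {
        String content;
        while ((content = QUEUE.take()) != POISON) { // reference comparison: POISON is a unique sentinel
            writer.write(content);
        }
    }
}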
2) Next, the multi-threaded processing code:
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class DiyMapReduce {
public static final File file = new File("data/words.txt");
public static final long segment = 1024 * 1024 * 10; // 10M
public static void main(String[] args) throws IOException {
long begin = System.currentTimeMillis();
// Record the split boundaries; the first offset is zero.
final List<Long> pos = new ArrayList<>();
pos.add(0L);
// Compute each worker's byte range: move a cursor forward one segment at a time, never past the end of the file.
try (BufferedReader br = new BufferedReader(new FileReader(file))) {
    long currPos = 0;
    while ((currPos + segment) < file.length()) {
        // Move the cursor to the nominal boundary by skipping one segment.
        currPos += segment;
        br.skip(segment);
        // Then consume characters until just past a space (32), LF (10), or CR (13), stopping at EOF (-1),
        // so that every boundary lands between words. Note: skip() and read() count characters,
        // which equals bytes here because the data is pure ASCII.
        int chr = br.read();
        while (chr != -1) {
            currPos++; // count the character we just consumed, including the delimiter
            if (chr == 32 || chr == 10 || chr == 13) {
                break;
            }
            chr = br.read();
        }
        pos.add(currPos);
    }
}
// The final boundary is the total file size in bytes (skipped if a scan already ended exactly there).
if (pos.get(pos.size() - 1) < file.length()) {
    pos.add(file.length());
}
// Use a fixed-size pool. WordCount here is CPU-bound, so the thread count equals the number of CPU cores.
ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
ArrayList<CompletableFuture<HashMap<String, Integer>>> futureList = new ArrayList<>();
// Each task handles one slice [pos(i), pos(i+1)): skip pos(i) bytes, then read the difference.
for (int i = 0; i < pos.size() - 1; i++) {
Compute compute = new Compute((pos.get(i + 1) - pos.get(i)), pos.get(i));
CompletableFuture<HashMap<String, Integer>> future = CompletableFuture.supplyAsync(compute, executor);
futureList.add(future);
}
CompletableFuture<Void> allOf = CompletableFuture.allOf(futureList.toArray(new CompletableFuture[0]));
allOf.join();
executor.shutdown();
// Finally, merge the per-split maps, summing the counts of duplicate keys.
Map<String, Integer> res = futureList.stream()
.map(CompletableFuture::join)
.flatMap(map -> map.entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, Integer::sum));
System.out.println(res);
System.out.println(System.currentTimeMillis() - begin);
}
public static class Compute implements Supplier<HashMap<String, Integer>> {
private final long readSize; // number of bytes this task reads
private final long skipSize; // byte offset where this task's slice starts
public Compute(long readSize, long skipSize) {
this.readSize = readSize;
this.skipSize = skipSize;
}
@Override
public HashMap<String, Integer> get() {
byte[] bytes = new byte[0];
// Memory-map only this task's slice of the file, then copy it into a heap array.
try (FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ)) {
MappedByteBuffer buffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, skipSize, readSize);
bytes = new byte[buffer.remaining()];
buffer.get(bytes);
} catch (IOException e) {
e.printStackTrace();
}
// Assemble the words into a map: key = the word, value = its count.
HashMap<String, Integer> map = new HashMap<>();
String[] words = new String(bytes).split("\\s+");
for (String word : words) {
map.compute(word.trim(), (key, value) -> value == null ? 1 : ++value);
}
return map;
}
}
}
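Before moving on to the optimization, here is a tiny, self-contained illustration of the boundary-snapping rule above (the SplitDemo class is mine, for demonstration only): each nominal offset is pushed forward until just past the next space, so no word is ever cut in half.

import java.util.ArrayList;
import java.util.List;

public class SplitDemo {
    public static void main(String[] args) {
        String data = "hive spark flink kafka ";
        int segment = 6; // nominal split size in characters
        List<Integer> pos = new ArrayList<>();
        pos.add(0);
        int curr = 0;
        while (curr + segment < data.length()) {
            curr += segment;
            // consume characters until we step just past a space, so the boundary lands between words
            while (curr < data.length() && data.charAt(curr++) != ' ') {
                // keep scanning
            }
            pos.add(curr);
        }
        if (pos.get(pos.size() - 1) < data.length()) {
            pos.add(data.length());
        }
        // prints: [0, 11) -> "hive spark "  then  [11, 23) -> "flink kafka "
        for (int i = 0; i < pos.size() - 1; i++) {
            System.out.println("[" + pos.get(i) + ", " + pos.get(i + 1) + ") -> \"" + data.substring(pos.get(i), pos.get(i + 1)) + "\"");
        }
    }
}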
3) Optimization
The code above takes about 65 seconds to process roughly 1 GB of data, and CPU and memory usage are both fairly high. A flame graph showed that most of the time was spent in the split() call.
Since the words are separated by plain spaces and no regular expression is involved, I switched to the StringTokenizer class for the splitting (this also needs an extra import of java.util.StringTokenizer), as follows:
StringTokenizer tokenizer = new StringTokenizer(new String(bytes), " ");
while (tokenizer.hasMoreTokens()) {
map.compute(tokenizer.nextToken().trim(), (key, value) -> value == null ? 1 : ++value);
}
Total time dropped to about 15 seconds, roughly four times faster. Of course, this only works because the parsing here is trivial; StringTokenizer cannot handle regular expressions.
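As a small aside (my own variation, not part of the interview answer): the single-argument StringTokenizer constructor already splits on all whitespace (space, tab, \n, \r, \f), which makes the trim() call unnecessary, and Map.merge reads a little cleaner than compute for counting:

StringTokenizer tokenizer = new StringTokenizer(new String(bytes)); // default delimiters cover space, \t, \n, \r, \f
while (tokenizer.hasMoreTokens()) {
    map.merge(tokenizer.nextToken(), 1, Integer::sum); // insert 1, or add 1 to the existing count
}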