最近大模型比较流行,后端结合Dify做了些应用。
现在我有两个Java类,分别是
DocumentImportDatasetTask:文档导入知识库任务
DocumentClassificationTask:文档分类任务
他们的逻辑很相似,首先他们都要扫描一个目录下的所有文件,DocumentImportDatasetTask是将其中的文档筛选出来,然后倒入Dify的知识库中进行向量化,DocumentClassificationTask也是一样,先扫描出所有的文档,然后组个发送给本地的Ollama搭建的qwen模型,对其进行分类。
大致逻辑
- DocumentImportDatasetTask
- DocumentClassificationTask
可以看到这两个类被我最终优化后的代码是这样的,其中scanDirectory()方法和consumerThread()方法都是我提取出的藕荷方法
扫描目录
他们都需要扫描指定目录下的所有文件,然后放入一个队列中,我们首先先将扫描文件的逻辑提取出来。
@Slf4j
public class FileUtil {
/**
* 计算给定文件的MD5哈希值。
*
* @param file 需要计算MD5的文件
* @return 文件的MD5哈希值的字符串表示
*/
public static String getFileMd5(File file) {
try (FileInputStream fis = new FileInputStream(file)) {
// 创建MessageDigest实例,使用MD5算法
MessageDigest md = MessageDigest.getInstance("MD5");
byte[] buffer = new byte[1024];
int len;
// 读取文件数据并更新摘要
while ((len = fis.read(buffer)) != -1) {
md.update(buffer, 0, len);
}
// 获取MD5哈希值,结果为字节数组
byte[] digest = md.digest();
// 将字节数组转换为十六进制字符串表示
return bytesToHex(digest);
} catch (Exception e) {
log.error("获取文件MD5失败", e);
return null;
}
}
/**
* 将字节数组转换为十六进制字符串。
*
* @param bytes 要转换的字节数组
* @return 转换后的十六进制字符串
*/
private static String bytesToHex(byte[] bytes) {
StringBuilder sb = new StringBuilder();
for (byte b : bytes) {
// 将每个字节转换为十六进制字符串,并连接到结果中
sb.append(String.format("%02x", b));
}
return sb.toString();
}
/**
* 扫描目录
*
* @param directoryPath 目录路径
*/
public static void scanDirectory(String directoryPath, Queue<File> queue) {
File directory = new File(directoryPath);
// 调用递归方法
if (directory.exists() && directory.isDirectory()) {
scanDirectoryRecursive(directory, queue);
}
}
/**
* 递归扫描目录并过滤掉文件夹
*
* @param directory 当前目录
* @param queue 文件队列
*/
private static void scanDirectoryRecursive(File directory, Queue<File> queue) {
File[] files = directory.listFiles();
if (files != null) {
for (File file : files) {
if (file.isDirectory()) {
// 递归扫描子目录
scanDirectoryRecursive(file, queue);
} else {
if (queue instanceof ConcurrentLinkedQueue) {
queue.offer(file);
} else if (queue instanceof LinkedBlockingQueue) {
try {
((LinkedBlockingQueue<File>) queue).put(file);
} catch (InterruptedException e) {
log.error("文件队列阻塞", e);
}
}
}
}
}
}
}
这里的扫描文件后入队时我做了一下判断,是因为DocumentImportDatasetTask和DocumentClassificationTask场景不一样。DocumentImportDatasetTask是一边扫描目录,一边调用Dify的API进行文档向量化的操作,因为入队和出队都是很频繁的,基本上不会造成队列容量太大(消息积压),所以用的是ConcurrentLinkedQueue这么一种无锁无阻塞的高性能队列。
而DocumentClassificationTask是一边扫描目录,一边需要调用大模型回答,大模型的响应速度相比于操作系统扫盘来说慢多了,所以必然会引起消息的堆积,为了防止OOM,所以我采用了LinkedBlockingQueue这么一种阻塞的队列,当队列的容量满了的时候,再执行入队操作时就会阻塞当前线程,直到队列有空闲在继续入队。
创建线程池
这两个类都需要使用多线程的方式来消费队列里面的消息,多线程调用Dify和Ollama的API,那就意味着提交线程池任务的逻辑也会高度的类似,唯一不同的就是线程池里面的任务逻辑,我们需要做的就是将提交线程池任务的代码抽象出来,具体的任务由具体的业务来实现,于是我采用了抽象类的方式解决我的需求。
@Slf4j
public abstract class FileConsumerThread {
public abstract void execute(Queue<File> queue) throws Exception;
public void consumerThread(Queue<File> queue, AtomicBoolean running, int threadNum, ExecutorService threadPoolExecutor) {
for (int i = 0; i < threadNum; i++) {
int finalI = i;
threadPoolExecutor.execute(() -> {
try {
Thread.sleep(1000); // 等待1秒
} catch (InterruptedException e) {
log.error("线程" + finalI + ":线程异常", e);
}
log.info("线程" + finalI + "开始执行");
while (true) {
if (queue.isEmpty() && !running.get()) {
log.info("线程" + finalI + ":结束");
break;
} else {
try {
execute(queue);
} catch (Exception e) {
log.error("线程" + finalI + ":线程异常", e);
}
}
}
});
}
}
}
这里我定义了一个文件消费线程的抽象类,里面有一个抽象方法execute()需要子类自己实现,consumerThread()类是公共的线程池提交任务逻辑,继承这个父类就获得此段方法,实现了解耦。
通过这么一种方式,这两个类都有了consumerThread()方法,就避免了重复的提交任务代码
其他姿势
使用接口也可以实现这一功能,Java8在接口允许使用default关键字对接口的方法做一个默认的实现,就类似抽象类的非抽象方法
类似this
public interface FileConsumerThread {
public void execute(Queue<File> queue) throws Exception;
default void consumerThread(Queue<File> queue, AtomicBoolean running, int threadNum, ExecutorService threadPoolExecutor) {
for (int i = 0; i < threadNum; i++) {
int finalI = i;
threadPoolExecutor.execute(() -> {
try {
Thread.sleep(1000); // 等待1秒
} catch (InterruptedException e) {
log.error("线程" + finalI + ":线程异常", e);
}
log.info("线程" + finalI + "开始执行");
while (true) {
if (queue.isEmpty() && !running.get()) {
log.info("线程" + finalI + ":结束");
break;
} else {
try {
execute(queue);
} catch (Exception e) {
log.error("线程" + finalI + ":线程异常", e);
}
}
}
});
}
}
}
但是我个人不太喜欢这么做,我认为接口就应该高度的抽象,接口和抽象类不能相互代替
- 接口通常用于定义类应该具有的方法,但不涉及具体实现的情况。接口提供了一种约定,让类表明自己能够做什么,而不必关心如何做。
- 抽象类用于定义一些子类的通用特性,它可以包含具体的方法实现,子类可以直接继承并重写其中的抽象方法,同时可以继承实现的方法。
好了,分享到这,感兴趣的也看看我的其他文章,蟹蟹~