// Java implementation of LangChain's RecursiveCharacterTextSplitter class.
import com.sun.deploy.util.StringUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import opennlp.tools.util.StringUtil;
/**
 * Splits text into chunks of at most {@code chunkSize} characters.
 *
 * <p>Separators are tried in priority order (by default {@code "\n\n"}, {@code "\n"},
 * {@code " "}, {@code ""}). Pieces that are still too long are split again recursively.
 * Adjacent short pieces are merged back together, with up to {@code chunkOverlap} characters
 * carried over between consecutive chunks.
 *
 * <p>Modeled on LangChain's {@code RecursiveCharacterTextSplitter}. Separators are matched
 * literally, not as regular expressions.
 */
public class RecursiveCharacterTextSplitter {
    private final int chunkSize;
    private final int chunkOverlap;
    private final List<String> separators;

    /**
     * Creates a splitter.
     *
     * @param separators   separators to try, in priority order, matched literally; {@code null}
     *                     selects the default {@code ["\n\n", "\n", " ", ""]}
     * @param chunkSize    maximum chunk length in characters; must be positive
     * @param chunkOverlap target number of characters shared by consecutive chunks; must be in
     *                     {@code [0, chunkSize]}
     * @throws IllegalArgumentException if {@code chunkSize} or {@code chunkOverlap} is out of range
     */
    public RecursiveCharacterTextSplitter(List<String> separators, int chunkSize, int chunkOverlap) {
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be > 0, got " + chunkSize);
        }
        if (chunkOverlap < 0) {
            throw new IllegalArgumentException("chunkOverlap must be >= 0, got " + chunkOverlap);
        }
        if (chunkOverlap > chunkSize) {
            throw new IllegalArgumentException(
                    "chunkOverlap (" + chunkOverlap + ") must not exceed chunkSize (" + chunkSize + ")");
        }
        // Copy so that later changes to the caller's list cannot alter splitting behaviour.
        this.separators = separators != null
                ? new ArrayList<>(separators)
                : Arrays.asList("\n\n", "\n", " ", "");
        this.chunkSize = chunkSize;
        this.chunkOverlap = chunkOverlap;
    }

    /**
     * Splits {@code text} into chunks.
     *
     * <p>The first separator that occurs in {@code text} is used. If none occurs, the text is split
     * into single characters ({@code ""}). Pieces shorter than {@code chunkSize} are merged into
     * chunks. Longer pieces are split again recursively. A single character that is still not
     * shorter than {@code chunkSize} is emitted unchanged.
     *
     * @param text the text to split
     * @return the chunks, in document order
     */
    public List<String> splitText(String text) {
        List<String> finalChunks = new ArrayList<>();
        String separator = "";
        for (String s : separators) {
            if (s.isEmpty() || text.contains(s)) {
                separator = s;
                break;
            }
        }
        // String.split interprets its argument as a regex. Quote the separator so it is matched
        // literally; otherwise "." would match every character and drop all text. Limit -1 keeps
        // trailing empty strings.
        List<String> splits = separator.isEmpty()
                ? Arrays.asList(text.split(""))
                : Arrays.asList(text.split(Pattern.quote(separator), -1));
        List<String> goodSplits = new ArrayList<>();
        for (String s : splits) {
            if (s.length() < chunkSize) {
                goodSplits.add(s);
                continue;
            }
            if (!goodSplits.isEmpty()) {
                finalChunks.addAll(mergeSplits(goodSplits, separator));
                goodSplits.clear();
            }
            if (separator.isEmpty()) {
                // Already split into single characters. Recursing again would produce the same
                // piece and never terminate (this happens when chunkSize <= 1).
                finalChunks.add(s);
            } else {
                finalChunks.addAll(splitText(s));
            }
        }
        if (!goodSplits.isEmpty()) {
            finalChunks.addAll(mergeSplits(goodSplits, separator));
        }
        return finalChunks;
    }

    /**
     * Joins the parts with the separator and trims surrounding whitespace.
     *
     * @param docPartsList parts to join
     * @param separator    separator placed between parts
     * @return the joined, trimmed text
     */
    private String joinTextParts(List<String> docPartsList, String separator) {
        return String.join(separator, docPartsList).trim();
    }

    /**
     * Greedily merges short splits into chunks of at most {@code chunkSize} characters.
     *
     * <p>The length of each chunk counts the separators between its parts. When a chunk is
     * emitted, parts are dropped from its front until the remainder is no longer than
     * {@code chunkOverlap} and the next split fits. The remainder starts the next chunk.
     * Chunks that are empty after trimming are discarded.
     *
     * @param splits    pieces to merge, each shorter than {@code chunkSize}
     * @param separator separator used to rejoin the pieces
     * @return merged chunks, in order
     */
    private List<String> mergeSplits(List<String> splits, String separator) {
        int separatorLen = separator.length();
        List<String> docList = new ArrayList<>();
        List<String> curPartList = new ArrayList<>();
        int totalLen = 0;
        for (String split : splits) {
            int curLen = split.length();
            if (totalLen + curLen + (curPartList.isEmpty() ? 0 : separatorLen) > chunkSize) {
                if (totalLen > chunkSize) {
                    // Print a warning.
                    System.out.println("Warn:Created a chunk of size " + totalLen + " which is longer than the specified " + chunkSize);
                }
                if (!curPartList.isEmpty()) {
                    String doc = joinTextParts(curPartList, separator);
                    if (!doc.isEmpty()) {
                        docList.add(doc);
                    }
                    // Drop parts from the front to form the overlap for the next chunk.
                    while (totalLen > chunkOverlap
                            || (totalLen + curLen + (curPartList.isEmpty() ? 0 : separatorLen) > chunkSize
                                    && totalLen > 0)) {
                        totalLen -= curPartList.get(0).length() + (curPartList.size() > 1 ? separatorLen : 0);
                        // Remove in place. Reassigning to subList(1, n) would build an ever-deeper
                        // chain of views that slows every later add/remove.
                        curPartList.remove(0);
                    }
                }
            }
            curPartList.add(split);
            totalLen += curLen + (curPartList.size() > 1 ? separatorLen : 0);
        }
        String doc = joinTextParts(curPartList, separator);
        if (!doc.isEmpty()) {
            docList.add(doc);
        }
        return docList;
    }
}