/**
 * HTTP entry point for text extraction from uploaded documents.
 *
 * <p>The single endpoint hands the uploaded file straight to {@link OcrService}
 * and returns whatever string the service produces.
 */
@Tag(name = "pdf/word/图片文字识别")
public class OcrController extends BaseController {

    /** Performs the actual extraction for {@link #recognizeText(MultipartFile)}. */
    @Autowired
    private OcrService ocrService;

    /** Injected but not referenced by any method of this controller. */
    @Autowired
    private BaiduOcrServiceImpl baiduOcrService;

    /**
     * Extracts the text of an uploaded file.
     *
     * @param file the multipart upload bound to the {@code file} request parameter
     * @return the value returned by {@link OcrService#recognizeText(MultipartFile)}
     */
    @PostMapping("/recognize-text")
    @Operation(summary = "pdf/word识别文字", description = "识别")
    public String recognizeText(@RequestParam("file") MultipartFile file) {
        final String extractedText = ocrService.recognizeText(file);
        return extractedText;
    }
}
package com.jt.console.service.impl;
import com.jt.common.beans.ServiceAssert;
import com.jt.console.service.OcrService;
import org.apache.pdfbox.cos.COSName;
import org.apache.pdfbox.pdmodel.PDDocument;
import org.apache.pdfbox.pdmodel.PDPage;
import org.apache.pdfbox.pdmodel.PDPageTree;
import org.apache.pdfbox.pdmodel.PDResources;
import org.apache.pdfbox.pdmodel.graphics.PDXObject;
import org.apache.pdfbox.pdmodel.graphics.image.PDImageXObject;
import org.apache.pdfbox.text.PDFTextStripper;
import org.apache.poi.hwpf.HWPFDocument;
import org.apache.poi.hwpf.extractor.WordExtractor;
import org.apache.poi.openxml4j.util.ZipSecureFile;
import org.apache.poi.xwpf.usermodel.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URLEncoder;
import java.util.Base64;
import static com.jt.console.service.impl.BaiduOcrServiceImpl.formatOcrResult;
/**
 * {@link OcrService} implementation that extracts the text of uploaded PDF and
 * Word documents with PDFBox and Apache POI.
 *
 * <p>{@link #recognizeText(MultipartFile)} only reads text that is already stored
 * in the document; it never calls the OCR service. The two image helpers
 * ({@code extractImagesFromPdf}, {@link #extractImagesFromDocx}) are not invoked
 * from anywhere inside this class.
 *
 * <p>Failures are reported through {@code ServiceAssert.isTrue(false, message)}.
 * {@code ServiceAssert} is declared elsewhere; this class assumes it throws when
 * the condition is false and keeps a {@code return null} after each call only to
 * satisfy the compiler.
 */
@Service
public class OcrServiceImpl implements OcrService {

    /** JDK platform logger; needs no additional dependency. */
    private static final System.Logger LOGGER = System.getLogger(OcrServiceImpl.class.getName());

    private static final String PDF_TYPE = "application/pdf";
    private static final String DOCX_TYPE =
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document";
    private static final String OOXML_TYPE = "application/x-tika-ooxml";
    private static final String DOC_TYPE = "application/msword";

    @Autowired
    private BaiduOcrServiceImpl baiduOcrService;

    /**
     * Extracts the text of an uploaded PDF, DOCX or DOC file, chosen by the
     * content type the upload reports.
     *
     * @param file the uploaded document
     * @return the extracted text; {@code null} is only reachable if
     *         {@code ServiceAssert} does not throw
     */
    @Override
    public String recognizeText(MultipartFile file) {
        String contentType = file.getContentType();
        if (contentType == null) {
            ServiceAssert.isTrue(false, "文件类型不支持");
            return null;
        }

        boolean isPdf = PDF_TYPE.equals(contentType);
        boolean isDocx = DOCX_TYPE.equals(contentType) || OOXML_TYPE.equals(contentType);
        boolean isDoc = DOC_TYPE.equals(contentType);

        // Reject unknown types before entering the try block: raised inside it, this
        // assertion would be caught by the catch-all below and replaced with the
        // generic "处理文件出错" message.
        if (!isPdf && !isDocx && !isDoc) {
            ServiceAssert.isTrue(false, "不支持的文件类型");
            return null;
        }

        InputStream inputStream = null;
        try {
            inputStream = file.getInputStream();
            if (isPdf) {
                return extractTextFromPdf(inputStream);
            }
            if (isDocx) {
                return extractTextFromDocx(inputStream);
            }
            return extractTextFromDoc(inputStream);
        } catch (Exception e) {
            // The assertion below carries only a message, so log the cause here.
            LOGGER.log(System.Logger.Level.ERROR,
                    "Failed to extract text from upload with content type " + contentType, e);
            ServiceAssert.isTrue(false, "处理文件出错");
            return null;
        } finally {
            closeQuietly(inputStream);
        }
    }

    /** Closes the stream, logging instead of propagating a failure to close. */
    private static void closeQuietly(InputStream stream) {
        if (stream == null) {
            return;
        }
        try {
            stream.close();
        } catch (IOException e) {
            LOGGER.log(System.Logger.Level.WARNING, "Failed to close upload stream", e);
        }
    }

    /**
     * Returns the text PDFBox's {@link PDFTextStripper} extracts from the PDF.
     *
     * @throws IOException if the document cannot be loaded or read
     */
    private String extractTextFromPdf(InputStream inputStream) throws IOException {
        try (PDDocument document = PDDocument.load(inputStream)) {
            // NOTE(review): presumably meant to quiet PDFBox logging; confirm that
            // PDFBox actually reads this system property.
            System.setProperty("org.apache.pdfbox.logging.SILENT", "true");
            PDFTextStripper pdfStripper = new PDFTextStripper();
            return pdfStripper.getText(document);
        }
    }

    /**
     * Returns the text of a DOCX file: every body paragraph on its own line,
     * followed by every table (cells separated by tabs, one line per row).
     * Tables are therefore emitted after all paragraphs, not at their position
     * in the document.
     *
     * @throws IOException if the document cannot be opened
     */
    private String extractTextFromDocx(InputStream inputStream) throws IOException {
        StringBuilder text = new StringBuilder();
        // Static POI setting: it applies to every later POI zip read in this JVM,
        // not only to this document. It adjusts POI's zip-bomb inflate-ratio guard.
        ZipSecureFile.setMinInflateRatio(0.001);
        try (XWPFDocument document = new XWPFDocument(inputStream)) {
            for (XWPFParagraph paragraph : document.getParagraphs()) {
                text.append(paragraph.getText()).append("\n");
            }
            for (XWPFTable table : document.getTables()) {
                for (XWPFTableRow row : table.getRows()) {
                    for (XWPFTableCell cell : row.getTableCells()) {
                        text.append(cell.getText()).append("\t");
                    }
                    text.append("\n");
                }
            }
        }
        return text.toString();
    }

    /**
     * Returns the paragraph texts of a legacy DOC file, each followed by a newline.
     *
     * @throws IOException if the document cannot be opened
     */
    private String extractTextFromDoc(InputStream inputStream) throws IOException {
        StringBuilder text = new StringBuilder();
        try (HWPFDocument document = new HWPFDocument(inputStream)) {
            WordExtractor extractor = new WordExtractor(document);
            for (String paragraph : extractor.getParagraphText()) {
                text.append(paragraph).append("\n");
            }
        }
        return text.toString();
    }

    /**
     * Writes every image XObject found on the document's pages to
     * {@code image<N>.png}, a relative path resolved against the process working
     * directory. {@code N} counts across the whole document, starting at 1.
     *
     * @throws IOException if an image cannot be decoded or written
     */
    private void extractImagesFromPdf(PDDocument document) throws IOException {
        PDPageTree pages = document.getPages();
        int imageCounter = 0;
        for (PDPage page : pages) {
            PDResources resources = page.getResources();
            if (resources == null) {
                // A page without a resource dictionary has no images to export.
                continue;
            }
            for (COSName xObjectName : resources.getXObjectNames()) {
                PDXObject xObject = resources.getXObject(xObjectName);
                if (!(xObject instanceof PDImageXObject)) {
                    continue;
                }
                BufferedImage bufferedImage = ((PDImageXObject) xObject).getImage();
                File imageFile = new File("image" + (++imageCounter) + ".png");
                try (FileOutputStream fos = new FileOutputStream(imageFile)) {
                    ImageIO.write(bufferedImage, "PNG", fos);
                }
            }
        }
    }

    /**
     * Runs OCR over every picture embedded in the document and returns one line
     * per picture in the form {@code Image <n>: <recognized text>}.
     *
     * @param document  the DOCX document whose pictures are recognized
     * @param urlEncode whether to URL-encode the Base64 image before sending it;
     *                  {@code recognizeImage} places the string verbatim into a
     *                  form-encoded request body, so callers normally want
     *                  {@code true}
     * @return the recognition results, one line per picture; empty if the document
     *         has no pictures
     * @throws IOException if the OCR request fails
     */
    public String extractImagesFromDocx(XWPFDocument document, boolean urlEncode) throws IOException {
        StringBuilder recognitionResults = new StringBuilder();
        int imageCounter = 0;
        for (XWPFPictureData pictureData : document.getAllPictures()) {
            String base64Image = Base64.getEncoder().encodeToString(pictureData.getData());
            if (urlEncode) {
                base64Image = URLEncoder.encode(base64Image, "utf-8");
            }
            // recognizeImage already returns the formatted plain text. Passing that
            // text through formatOcrResult again would try to parse it as JSON and
            // fail, so the result is used as-is.
            String recognizedText = baiduOcrService.recognizeImage(base64Image);
            recognitionResults.append("Image ").append(++imageCounter).append(": ")
                    .append(recognizedText).append("\n");
        }
        return recognitionResults.toString();
    }
}
package com.jt.console.service.impl;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.jt.common.beans.ServiceAssert;
import okhttp3.*;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import java.io.IOException;
import java.net.URLEncoder;
import java.util.Base64;
import java.util.List;
import java.util.Arrays;
/**
 * Client for the Baidu OCR HTTP API: obtains an access token, submits a
 * Base64-encoded image for recognition and flattens the JSON answer to text.
 *
 * <p>Failures are reported through {@code ServiceAssert.isTrue(false, message)}.
 * {@code ServiceAssert} is declared elsewhere; this class assumes it throws when
 * the condition is false.
 */
@Service("baiduOcrServiceImpl")
public class BaiduOcrServiceImpl {

    @Value("${baidu.ocr.apiKey}")
    private String API_KEY;

    @Value("${baidu.ocr.secretKey}")
    private String SECRET_KEY;

    /** Lower-case file extensions accepted by {@link #convertToBase64}. */
    private static final List<String> SUPPORTED_FORMATS = Arrays.asList("png", "jpg", "jpeg", "bmp", "gif");

    private static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();

    /**
     * Requests an access token with the configured API key and secret key.
     * A new token is requested on every call; nothing is cached.
     *
     * @return the {@code access_token} value of the token response
     * @throws IOException if the HTTP call fails
     */
    private String getAccessToken() throws IOException {
        MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
        RequestBody body = RequestBody.create(mediaType,
                "grant_type=client_credentials&client_id=" + API_KEY + "&client_secret=" + SECRET_KEY);
        Request request = new Request.Builder()
                .url("https://aip.baidubce.com/oauth/2.0/token")
                .method("POST", body)
                .addHeader("Content-Type", "application/x-www-form-urlencoded")
                .build();
        // The response must be closed, otherwise its connection is never released.
        try (Response response = HTTP_CLIENT.newCall(request).execute()) {
            if (!response.isSuccessful()) {
                String errorMessage = "Failed to obtain access token. Status code: " + response.code()
                        + ", Message: " + response.message();
                ServiceAssert.isTrue(false, errorMessage);
            }
            JSONObject tokenJson = JSON.parseObject(response.body().string());
            String accessToken = tokenJson == null ? null : tokenJson.getString("access_token");
            if (accessToken == null || accessToken.isEmpty()) {
                // Without this check a missing token would end up in the OCR URL as
                // "access_token=null" and the failure would surface far from its cause.
                ServiceAssert.isTrue(false, "Failed to obtain access token: response contains no access_token");
            }
            return accessToken;
        }
    }

    /**
     * Sends an image to the {@code accurate_basic} OCR endpoint and returns the
     * recognized text.
     *
     * @param base64Image the Base64-encoded image; it is placed verbatim into a
     *                    form-encoded body, so the caller is responsible for
     *                    URL-encoding it
     * @return the recognized words joined by single spaces, already passed through
     *         {@link #formatOcrResult(String)}; empty if nothing was recognized
     * @throws IOException if the token request or the OCR request fails
     */
    public String recognizeImage(String base64Image) throws IOException {
        MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
        RequestBody body = RequestBody.create(mediaType,
                "image=" + base64Image + "&detect_direction=false&paragraph=false&probability=false");
        Request request = new Request.Builder()
                .url("https://aip.baidubce.com/rest/2.0/ocr/v1/accurate_basic?access_token=" + getAccessToken())
                .method("POST", body)
                .addHeader("Content-Type", "application/x-www-form-urlencoded")
                .addHeader("Accept", "application/json")
                .build();
        try (Response response = HTTP_CLIENT.newCall(request).execute()) {
            if (!response.isSuccessful()) {
                String errorMessage = "OCR request failed. Status code: " + response.code()
                        + ", Message: " + response.message();
                ServiceAssert.isTrue(false, errorMessage);
            }
            return formatOcrResult(response.body().string());
        }
    }

    /**
     * Base64-encodes an uploaded image after checking its file extension against
     * {@code SUPPORTED_FORMATS}.
     *
     * @param file      the uploaded image
     * @param urlEncode whether to URL-encode the Base64 string
     * @return the Base64 (optionally URL-encoded) content of the file
     * @throws IOException if the file content cannot be read
     */
    public String convertToBase64(MultipartFile file, boolean urlEncode) throws IOException {
        String filename = file.getOriginalFilename();
        if (filename == null) {
            ServiceAssert.isTrue(false, "文件名为空");
        }
        // Locale.ROOT keeps the comparison independent of the JVM default locale
        // (under a Turkish locale "GIF" would otherwise lower-case to "gıf").
        // A name without a dot yields the whole name, which is then rejected below.
        String extension = filename.substring(filename.lastIndexOf('.') + 1)
                .toLowerCase(java.util.Locale.ROOT);
        if (!SUPPORTED_FORMATS.contains(extension)) {
            ServiceAssert.isTrue(false, "不支持的图片格式: " + extension);
        }
        String base64 = Base64.getEncoder().encodeToString(file.getBytes());
        if (urlEncode) {
            base64 = URLEncoder.encode(base64, "utf-8");
        }
        return base64;
    }

    /**
     * Flattens an OCR JSON response to plain text by joining the {@code words}
     * value of every entry in {@code words_result} with single spaces.
     *
     * @param ocrResult the raw JSON response; may be {@code null} or blank
     * @return the joined words without leading or trailing whitespace; an empty
     *         string if the input is null or blank, or if {@code words_result} is
     *         missing or empty (which includes responses that carry no result
     *         array at all)
     */
    public static String formatOcrResult(String ocrResult) {
        if (ocrResult == null || ocrResult.trim().isEmpty()) {
            // Handled explicitly rather than through a caught NullPointerException.
            return "";
        }
        StringBuilder resultText = new StringBuilder();
        try {
            JSONObject jsonObject = JSON.parseObject(ocrResult);
            if (jsonObject == null) {
                return "";
            }
            var wordsResult = jsonObject.getJSONArray("words_result");
            if (wordsResult == null || wordsResult.isEmpty()) {
                return "";
            }
            for (int i = 0; i < wordsResult.size(); i++) {
                String word = wordsResult.getJSONObject(i).getString("words");
                if (word != null && !word.isEmpty()) {
                    resultText.append(word).append(" ");
                }
            }
        } catch (Exception e) {
            ServiceAssert.isTrue(false, e.getMessage());
        }
        return resultText.toString().trim();
    }
}