利用深度学习实现验证码识别-2-使用Python导出ONNX模型并在Java中调用实现验证码识别

在这里插入图片描述

1. Python部分:导出ONNX模型

首先,我们需要在Python中定义并导出一个已经训练好的验证码识别模型。以下是完整的Python代码:

import string
import torch
import torch.nn as nn
import torch.nn.functional as F

CHAR_SET = string.digits

# 优化后的模型设计
class CaptchaModel(nn.Module):
    def __init__(self):
        super(CaptchaModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc1 = nn.Linear(128 * 5 * 12, 256)  # 调整为实际展平维度
        self.fc2 = nn.Linear(256, 4 * len(CHAR_SET))
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x.view(-1, 4, len(CHAR_SET))

# 使用CUDA,如果可用的话
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 假设你的模型已经训练好并保存在 'best_model.pth'
model = CaptchaModel().to(device)
model.load_state_dict(torch.load('best_model.pth'))

# 生成一个测试输入 (示例输入的形状应与模型输入形状一致)
dummy_input = torch.randn(1, 1, 40, 100).to(device)

# 导出模型为 ONNX 格式
torch.onnx.export(model, dummy_input, "captcha_model.onnx", 
                  input_names=["input"], output_names=["output"], 
                  dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})

print("Model exported to captcha_model.onnx")

这段代码定义了一个验证码识别模型,并将其导出为ONNX格式,以便在Java中使用。

2. Java部分:调用ONNX模型进行验证码识别

接下来,我们使用Java调用导出的ONNX模型进行验证码识别。以下是完整的Java代码:

  • 引用onnxruntime-1.19.0.jar
package com.tushuoit;

import ai.onnxruntime.*;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.List;

public class CaptchaInference {
    private static final String CHAR_SET = "0123456789";
    private static final int INPUT_WIDTH = 100;
    private static final int INPUT_HEIGHT = 40;
    private static final Random random = new Random();

    public static void main(String[] args) throws Exception {
        // 随机生成4个字符的验证码文本
        String captchaText = generateRandomText(4);
        System.out.println("Generated Captcha Text: " + captchaText);

        // 生成包含文本的Bitmap (BufferedImage)
        BufferedImage captchaImage = generateCaptcha(captchaText, 36, INPUT_WIDTH, INPUT_HEIGHT);

        // 将Bitmap保存为文件(仅用于查看生成的图像,实际使用中可以省略)
        ImageIO.write(captchaImage, "png", new File("generated_captcha.png"));

        // 将图像转换为浮点数数组,并进行归一化处理
        float[] inputData = imageToFloatArray(captchaImage);

        // 创建ONNX Runtime环境
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions opts = new OrtSession.SessionOptions();

        // 加载ONNX模型
        OrtSession session = env.createSession("captcha_model.onnx", opts);

        // 创建输入张量
        FloatBuffer inputBuffer = FloatBuffer.wrap(inputData);
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputBuffer,
                new long[] { 1, 1, INPUT_HEIGHT, INPUT_WIDTH });

        // 进行推理
        OrtSession.Result result = session.run(Collections.singletonMap("input", inputTensor));

        // Extract output tensor and decode it
        float[][][] outputData = (float[][][]) result.get(0).getValue();
        List<String> decodedTexts = decodeOutput(outputData);

        // Print the decoded captcha text
        for (String text : decodedTexts) {
            System.out.println("Predicted Captcha Text: " + text);
        }

        System.out.println("Inference completed.");
        // 释放资源
        session.close();
        env.close();
    }

    // 随机生成指定长度的验证码文本
    private static String generateRandomText(int length) {
        StringBuilder text = new StringBuilder(length);
        for (int i = 0; i < length; i++) {
            text.append(CHAR_SET.charAt(random.nextInt(CHAR_SET.length())));
        }
        return text.toString();
    }

    // 生成包含文本的BufferedImage
    private static BufferedImage generateCaptcha(String text, int fontSize, int width, int height) {
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics2D g2d = image.createGraphics();

        // 设置背景颜色为白色
        g2d.setColor(Color.WHITE);
        g2d.fillRect(0, 0, width, height);

        // 设置字体和颜色
        g2d.setFont(new Font("DroidSansMono", Font.PLAIN, fontSize));
        g2d.setColor(Color.BLACK);

        // 绘制文本
        FontMetrics fm = g2d.getFontMetrics();
        int x = 5; // 文字开始的X坐标
        int y = fm.getAscent() + 5; // 文字开始的Y坐标
        g2d.drawString(text, x, y);

        g2d.dispose();
        return image;
    }

    // 将BufferedImage转换为float数组,并进行归一化处理
    private static float[] imageToFloatArray(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        float[] floatArray = new float[width * height];

        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int rgb = image.getRGB(x, y);
                int gray = (rgb >> 16) & 0xFF; // 因为是灰度图,只需获取一个通道的值
                floatArray[y * width + x] = (gray / 255.0f - 0.5f) * 2.0f; // 归一化到[-1, 1]
            }
        }

        return floatArray;
    }

    private static List<String> decodeOutput(float[][][] outputData) {
        List<String> decodedTexts = new ArrayList<>();
        for (float[][] singleOutput : outputData) {
            StringBuilder decodedText = new StringBuilder();
            for (float[] charProbabilities : singleOutput) {
                int maxIndex = getMaxIndex(charProbabilities);
                decodedText.append(CHAR_SET.charAt(maxIndex));
            }
            decodedTexts.add(decodedText.toString());
        }
        return decodedTexts;
    }

    private static int getMaxIndex(float[] probabilities) {
        int maxIndex = 0;
        float maxProb = probabilities[0];
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > maxProb) {
                maxProb = probabilities[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }
}

这段Java代码首先生成一个随机的验证码图像,然后将其转换为模型输入格式,并通过ONNX Runtime调用导出的模型进行推理,最后解码模型的输出以获取识别的验证码文本。
在这里插入图片描述

总结

通过上述步骤,我们成功地在Python中导出了一个验证码识别模型,并在Java中调用该模型进行验证码识别。这种方法充分利用了Python在深度学习模型训练和导出方面的优势,以及Java在实际应用部署和性能方面的优势,实现了高效的验证码识别系统。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

@井九

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值