验证码识别是一项实用的图像识别任务。虽然深度学习主流工具大多基于 Python,但在生产环境中,Go 语言以其高性能和简洁性,常用于构建后端服务。本文将演示如何使用 Go 加载 TensorFlow 模型并实现验证码图片识别。

1. 准备模型(Python 部分)

由于 Go 本身不擅长训练模型,我们可以用 Python 和 TensorFlow/Keras 训练一个 CNN 模型,并保存为 .pb 格式。

Python 示例代码(训练后保存模型):

# train_and_export.py
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
import tensorflow as tf

model = Sequential([
    Conv2D(32, (3,3), activation='relu', input_shape=(60,160,1)),
    MaxPooling2D(2,2),
    Conv2D(64, (3,3), activation='relu'),
    MaxPooling2D(2,2),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(4 * 36, activation='softmax')  # 4位验证码,每位36类(0-9A-Z)
])
更多内容访问ttocr.com或联系1436423940
model.compile(optimizer='adam', loss='categorical_crossentropy')
# 假设你已经完成训练...

# 保存模型
tf.saved_model.save(model, "saved_model/captcha_model")
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.

注意:这里保存的是 TensorFlow 2.x 的 SavedModel 格式,Go 可直接读取。


2. 安装 Go TensorFlow 库

go get github.com/tensorflow/tensorflow/tensorflow/go
  • 1.

确保你已安装 TensorFlow C 库(例如使用 conda 或 apt 安装)。


3. 加载模型并预测验证码

package main

import (
	"fmt"
	"image"
	"io/ioutil"
	"os"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/nfnt/resize"
	_ "image/png"
)

func preprocessImage(imagePath string) (*tf.Tensor, error) {
	file, err := os.Open(imagePath)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	img, _, err := image.Decode(file)
	if err != nil {
		return nil, err
	}

	resized := resize.Resize(160, 60, img, resize.Lanczos3)
	// 转换为灰度 + float32 + 正则化
	data := make([]float32, 160*60)
	i := 0
	for y := 0; y < 60; y++ {
		for x := 0; x < 160; x++ {
			r, g, b, _ := resized.At(x, y).RGBA()
			gray := float32((r + g + b) / 3 >> 8)
			data[i] = gray / 255.0
			i++
		}
	}

	// 构建 Tensor
	tensor, err := tf.NewTensor([1][60][160][1]float32{})
	if err != nil {
		return nil, err
	}
	copy(tensor.Value().([1][60][160][1]float32)[0][:][0][:], data)
	return tensor, nil
}

func main() {
	model, err := tf.LoadSavedModel("saved_model/captcha_model", []string{"serve"}, nil)
	if err != nil {
		panic(err)
	}
	defer model.Session.Close()

	input, err := preprocessImage("captcha_samples/A7K9_1.png")
	if err != nil {
		panic(err)
	}

	result, err := model.Session.Run(
		map[tf.Output]*tf.Tensor{
			model.Graph.Operation("serving_default_input_1").Output(0): input,
		},
		[]tf.Output{
			model.Graph.Operation("StatefulPartitionedCall").Output(0),
		},
		nil,
	)
	if err != nil {
		panic(err)
	}

	// 输出预测
	prediction := result[0].Value().([][]float32)[0]
	alphabet := "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
	fmt.Print("Predicted: ")
	for i := 0; i < 4; i++ {
		// 每个字符有36个概率值
		offset := i * 36
		maxIdx := 0
		maxVal := float32(0)
		for j := 0; j < 36; j++ {
			if prediction[offset+j] > maxVal {
				maxVal = prediction[offset+j]
				maxIdx = j
			}
		}
		fmt.Print(string(alphabet[maxIdx]))
	}
	fmt.Println()
}
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.
  • 70.
  • 71.
  • 72.
  • 73.
  • 74.
  • 75.
  • 76.
  • 77.
  • 78.
  • 79.
  • 80.
  • 81.
  • 82.
  • 83.
  • 84.
  • 85.
  • 86.
  • 87.
  • 88.
  • 89.
  • 90.
  • 91.