// 1. 导入必要的库 (Import the required libraries)
import * as tf from '@tensorflow/tfjs';
import { ImageCaptcha } from 'captcha-generator';
import { plot, Plot } from 'nodeplotlib';
// Captcha image size in pixels.
const height = 80;
const width = 170;
// Alphabet the captcha text is drawn from: the ten digits followed by A-Z.
const characters = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ';
// Number of characters in every captcha string.
const n_len = 4;
// One class per alphabet character, plus one extra class reserved for the blank.
const n_class = 1 + characters.length;
// 2. 生成验证码图像 (Generate a captcha image)
/**
 * Generate one captcha image for a freshly drawn random string and plot it.
 *
 * The string has `n_len` characters, each picked uniformly from `characters`.
 *
 * @returns {{img: *, randomStr: string}} The generated image object and the
 *   text it was generated from.
 */
function generateCaptcha() {
  const pickChar = () => {
    const position = Math.floor(Math.random() * characters.length);
    return characters[position];
  };

  const captcha = new ImageCaptcha({ width, height });
  const randomStr = Array.from({ length: n_len }, pickChar).join('');
  const img = captcha.generate(randomStr);

  // Show the picture, using the captcha text as the plot title.
  const trace = { data: img.getPixels(), type: 'image' };
  plot([trace], { title: randomStr });

  return { img, randomStr };
}
// Generate and plot one sample captcha as soon as this statement runs.
const { img, randomStr } = generateCaptcha();
// 3. 定义CTC Loss (Define the CTC loss)
/**
 * Thin wrapper that forwards to `tf.losses.ctcLoss`.
 *
 * Note the argument order flips: this function takes the predictions first,
 * while the wrapped call receives the labels first. The final argument passed
 * on is `n_class - 1`, i.e. the index of the last class.
 *
 * NOTE(review): `tf.losses.ctcLoss` does not appear in the published
 * TensorFlow.js API reference — confirm it exists in the tfjs build in use.
 *
 * @param {*} y_pred - Model predictions.
 * @param {*} labels - Target label sequences.
 * @param {*} input_lengths - Forwarded unchanged as the third argument.
 * @param {*} label_lengths - Forwarded unchanged as the fourth argument.
 * @returns {*} Whatever `tf.losses.ctcLoss` returns.
 */
function ctcLoss(y_pred, labels, input_lengths, label_lengths) {
  const lastClassIndex = n_class - 1;
  return tf.losses.ctcLoss(labels, y_pred, input_lengths, label_lengths, lastClassIndex);
}
// 4. 构建模型 (Build the model)
/**
 * Build the captcha recognition network: three conv + max-pool stages, two
 * stacked GRUs, and a per-time-step softmax over `n_class` classes.
 *
 * The input is laid out as [height, width, 3]. Captcha text runs left to
 * right, so the recurrent layers must step along the image width. The feature
 * map is therefore permuted to put the column axis first before it is
 * flattened into a sequence; without the permute the GRUs would scan the
 * image rows (top to bottom) instead.
 *
 * @returns {tf.LayersModel} Uncompiled model whose output has shape
 *   [batch, timeSteps, n_class], with one time step per feature-map column.
 */
function createModel() {
  const input = tf.input({ shape: [height, width, 3] });

  // Convolutional feature extractor; each stage halves both spatial dimensions.
  let x = input;
  for (const filters of [32, 64, 128]) {
    x = tf.layers.conv2d({ filters, kernelSize: 3, padding: 'same', activation: 'relu' }).apply(x);
    x = tf.layers.maxPooling2d({ poolSize: [2, 2] }).apply(x);
  }

  // [batch, rows, cols, channels] -> [batch, cols, rows, channels], so that the
  // sequence axis follows the reading direction of the text. `dims` is 1-based
  // and excludes the batch axis.
  x = tf.layers.permute({ dims: [2, 1, 3] }).apply(x);

  // One time step per column; each step sees all rows and channels of that column.
  const [, rnn_length, rows, channels] = x.shape;
  const rnn_dimen = rows * channels;
  x = tf.layers.reshape({ targetShape: [rnn_length, rnn_dimen] }).apply(x);

  x = tf.layers.gru({ units: 128, returnSequences: true }).apply(x);
  x = tf.layers.gru({ units: 128, returnSequences: true }).apply(x);
  x = tf.layers.dense({ units: n_class, activation: 'softmax' }).apply(x);

  return tf.model({ inputs: input, outputs: x });
}
// Build the network, then attach the Adam optimizer and the CTC loss wrapper.
const model = createModel();
model.compile({
optimizer: tf.train.adam(),
// NOTE(review): a custom loss passed to compile() is presumably invoked as
// (yTrue, yPred), whereas ctcLoss declares (y_pred, labels, input_lengths,
// label_lengths) — confirm the argument order and where the two length
// arguments are supposed to come from.
loss: ctcLoss
});
// 5. 定义数据生成器 (Define the data generator)
/**
 * Endless generator of random captcha batches.
 *
 * Every batch holds `batchSize` freshly generated captchas. Each captcha text
 * has `n_len` characters drawn uniformly from `characters`; its label is the
 * list of those characters' indices in `characters`.
 *
 * @param {number} [batchSize=32] - Number of captchas per yielded batch.
 * @param {number} [inputLength] - Number of time steps the model emits per
 *   image. Defaults to the second dimension of the module-level `model`'s
 *   output shape, so `model` must already exist when the generator is created
 *   with the default.
 * @yields {{xs: tf.Tensor, ys: tf.Tensor, inputLengths: number[], labelLengths: number[]}}
 *   `xs` is built with shape [batchSize, height, width, 3]; `ys` holds one row
 *   of `n_len` label indices per sample; the two length arrays carry one entry
 *   per sample (`inputLength` and `n_len` respectively).
 */
function* dataGenerator(batchSize = 32, inputLength = model.outputs[0].shape[1]) {
  const generator = new ImageCaptcha({ width, height });
  while (true) {
    const X = [];
    const y = [];
    for (let i = 0; i < batchSize; i++) {
      // Draw the label indices first and derive the text from them, so the
      // labels never have to be looked up again with indexOf.
      const labelIndices = Array.from({ length: n_len }, () => Math.floor(Math.random() * characters.length));
      const randomStr = labelIndices.map((index) => characters[index]).join('');
      X.push(generator.generate(randomStr).getPixels());
      y.push(labelIndices);
    }
    yield {
      xs: tf.tensor(X, [batchSize, height, width, 3]),
      ys: tf.tensor(y),
      // Fresh arrays per batch so a consumer may mutate them safely.
      inputLengths: Array.from({ length: batchSize }, () => inputLength),
      labelLengths: Array.from({ length: batchSize }, () => n_len),
    };
  }
}
// Batch generator created with the default batch size; the training and
// evaluation code below pull their batches from it.
const dataGen = dataGenerator();
// 6. 训练模型 (Train the model)
// Train for 20 epochs of 400 generated batches each, validating on 20
// generated batches after every epoch.
//
// The batches come from `dataGenerator`, whose `{ xs, ys }` elements are the
// element format `fitDataset` consumes. Streaming them through
// `tf.data.generator` gives every step a new batch, instead of fitting
// repeatedly on a single pre-drawn batch. It also hands the validation data
// over in a form the API accepts (a dataset), rather than as a bare
// `{ xs, ys }` object.
(async () => {
  const trainBatches = tf.data.generator(() => dataGenerator());
  const validationBatches = tf.data.generator(() => dataGenerator());

  await model.fitDataset(trainBatches, {
    epochs: 20,
    // Both datasets are endless, so the per-epoch batch counts must be explicit.
    batchesPerEpoch: 400,
    validationData: validationBatches,
    validationBatches: 20,
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        console.log(`Epoch ${epoch + 1}, Loss: ${logs.loss}, Validation Loss: ${logs.val_loss}`);
      },
    },
  });
})().catch((err) => {
  // Do not leave the training promise floating: report the failure and mark
  // the process as failed.
  console.error('Training failed:', err);
  process.exitCode = 1;
});
// 7. 评估模型 (Evaluate the model)
/**
 * Collapse a per-time-step best path into a label sequence (greedy CTC
 * decoding): consecutive repeats are merged, then blanks are removed.
 *
 * @param {number[]} path - Most likely class index at each time step.
 * @param {number} blankIndex - Class index that represents the blank.
 * @returns {number[]} The decoded label indices.
 */
function collapseCtcPath(path, blankIndex) {
  const labels = [];
  let previous = blankIndex;
  for (const current of path) {
    if (current !== previous && current !== blankIndex) {
      labels.push(current);
    }
    previous = current;
  }
  return labels;
}

/**
 * Measure whole-sequence accuracy over 10 batches and log it.
 *
 * For every sample the prediction is reduced to its most likely class per time
 * step, collapsed with greedy CTC decoding, and counted as correct only if the
 * decoded sequence equals the label sequence exactly. The blank is taken to be
 * the last class of the prediction's final axis.
 *
 * Each batch pulled from `dataGen` is consumed: its `xs`/`ys` tensors and the
 * tensors created here are disposed once their values have been read.
 *
 * @param {{predict: Function}} model - Model whose `predict(xs)` returns a
 *   [batch, timeSteps, classes] tensor.
 * @param {Iterator<{xs: *, ys: *}>} dataGen - Iterator of batches; `ys` holds
 *   one row of label indices per sample.
 * @returns {Promise<number>} Accuracy in percent (0 when no sample was seen).
 */
async function evaluate(model, dataGen) {
  const batchCount = 10;
  let correct = 0;
  let total = 0;

  for (let i = 0; i < batchCount; i++) {
    const { xs, ys } = dataGen.next().value;
    const y_pred = model.predict(xs);
    const blankIndex = y_pred.shape[y_pred.shape.length - 1] - 1;
    const bestPath = y_pred.argMax(-1);

    let paths;
    let labels;
    try {
      // Tensors cannot be indexed like arrays; download the values first.
      [paths, labels] = await Promise.all([bestPath.array(), ys.array()]);
    } finally {
      for (const tensor of [xs, ys, y_pred, bestPath]) {
        tensor.dispose();
      }
    }

    paths.forEach((path, j) => {
      const decoded = collapseCtcPath(path, blankIndex);
      const expected = labels[j];
      const isMatch =
        decoded.length === expected.length && decoded.every((value, k) => value === expected[k]);
      if (isMatch) {
        correct++;
      }
      total++;
    });
  }

  const accuracy = total === 0 ? 0 : (correct / total) * 100;
  console.log(`Accuracy: ${accuracy.toFixed(2)}%`);
  return accuracy;
}
// Run the evaluation and report a failure instead of leaving the promise
// floating.
// NOTE(review): this call starts as soon as the module is evaluated; nothing
// here waits for the training run above to finish — confirm that is intended.
evaluate(model, dataGen).catch((err) => {
  console.error('Evaluation failed:', err);
  process.exitCode = 1;
});