🎯 图像描述生成器(Image Captioning)
🔍 给定一张图像,让模型自动输出一段自然语言描述,类似于 “这是一只在沙滩上奔跑的金毛”。
本项目将构建一个图像编码 + Transformer 解码的双模块模型,完成从训练到推理与部署的完整流程。
✅ 项目亮点
- 使用 预训练 CNN(ResNet50)提取图像特征
- 使用 Transformer Decoder 生成文本描述
- 支持自回归推理 + Greedy 或 Beam Search
- 可部署为 Flask API、TF Serving 或 TF.js
- 可拓展用于图文匹配、视觉问答、AIGC 模块
📦 数据准备
使用 [Flickr8k / Flickr30k / MS COCO Captioning 数据集]
示例格式:
图片 | 描述(多条) |
---|---|
123.jpg | [“A dog running on the grass.”, “The puppy is playing.”] |
✅ 文本预处理
- 构建词表(使用
Tokenizer
) - 特殊 token:
<start>
、<end>
、<pad>
、<unk>
- 最大长度 padding(如
max_len=25
)
from tensorflow.keras.preprocessing.text import Tokenizer
tokenizer = Tokenizer(num_words=10000, oov_token='<unk>')
tokenizer.fit_on_texts(caption_texts)
seqs = tokenizer.texts_to_sequences(caption_texts)
padded_seqs = tf.keras.preprocessing.sequence.pad_sequences(seqs, padding='post', maxlen=25)
🧠 模型结构
1. 图像编码器(CNN)
def build_cnn_encoder():
base = tf.keras.applications.ResNet50(include_top=False, weights='imagenet')
base.trainable = False
output = base.get_layer('conv5_block3_out').output # shape: (7, 7, 2048)
model = tf.keras.Model(inputs=base.input, outputs=output)
return model
- 输出形状:
[batch, 49, 2048]
(展平空间维度)
2. Transformer Decoder(简化版)
class CaptionDecoder(tf.keras.Model):
def __init__(self, vocab_size, d_model, num_heads, num_layers, max_len):
super().__init__()
self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
self.pos_encoding = positional_encoding(max_len, d_model)
self.dec_layers = [DecoderLayer(d_model, num_heads, d_model*4) for _ in range(num_layers)]
self.final = tf.keras.layers.Dense(vocab_size)
def call(self, x, enc_output, mask):
x = self.embedding(x) + self.pos_encoding[:, :tf.shape(x)[1], :]
for layer in self.dec_layers:
x = layer(x, enc_output, mask)
return self.final(x)
🔁 训练流程
训练目标:
输入:
- 图像特征 (from CNN)
- 已有 caption token(如 "<start> a dog is")
目标:
- 下一 token(如 "running")
✅ Loss 计算(带 mask)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
def masked_loss(y_true, y_pred):
loss = loss_object(y_true, y_pred)
mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
return tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)
📈 推理流程(逐 token 生成)
def generate_caption(img, tokenizer, cnn, decoder, max_len=25):
features = cnn(img) # shape: [1, 49, 2048]
output = [tokenizer.word_index['<start>']]
for i in range(max_len):
seq = tf.constant([output])
logits = decoder(seq, features, mask=None)
next_token = tf.argmax(logits[:, -1, :], axis=-1).numpy()[0]
output.append(next_token)
if tokenizer.index_word.get(next_token) == '<end>':
break
return ' '.join([tokenizer.index_word.get(i, '') for i in output[1:-1]])
🧪 示例推理输出
输入图片: 🐕 在草地上奔跑的狗
输出文本: "A brown dog is running through the field."
🗃️ 项目结构建议
image-captioning/
├── data/
│ ├── images/
│ └── captions.csv
├── models/
│ ├── encoder.py
│ ├── decoder.py
├── train.py
├── inference.py
├── app.py (Flask API)
├── export/
│ ├── saved_model/
│ ├── caption_model.tflite
└── requirements.txt
🚀 模型导出与部署
✅ 保存模型(SavedModel)
tf.saved_model.save(caption_model, "export/saved_model")
✅ TFLite 转换(不含 CNN,仅解码器)
converter = tf.lite.TFLiteConverter.from_saved_model("export/saved_model")
tflite_model = converter.convert()
✅ TF.js 转换(可部署到浏览器)
tensorflowjs_converter \
--input_format=tf_saved_model \
export/saved_model tfjs_model/
✅ Flask 接口推理
@app.route('/caption', methods=['POST'])
def caption():
img = request.files['file']
tensor = preprocess(img)
text = generate_caption(tensor, tokenizer, cnn_model, decoder)
return jsonify({"caption": text})
💡 应用场景拓展
- 图像搜索系统
- 图文匹配(匹配图像与文案)
- 视觉问答 VQA 模型前置模块
- 多模态大模型(图像 + 文本预训练)
✅ 项目小结
模块 | 技术 |
---|---|
特征提取 | CNN(ResNet50) |
文本生成 | Transformer Decoder |
数据处理 | tf.data , Tokenizer , padding |
部署支持 | TF Serving / TFLite / TF.js |
应用方向 | 图像理解、AIGC 前置、AI 辅助写作 |