概念
用于处理序列问题:翻译(N vs N)、信息提取(N vs 1)、生成(1 vs N)。
RNN 要求输入队列和输出队列等长,Seq2Seq 可以解决输入队列与输出队列不等长的问题。
实验(验证码识别)
数据集:生成 4 位数字的验证码图片(测试集和训练集各 1000 张),图片名称为 index.code.jpg,截取 code 作为标签。
网络结构:
- 编码:全连接 + 标准化(BN)+ 激活(ReLU)+ LSTM。
- 解码:LSTM + 全连接 + softmax(多分类)。
优化器:Adam。
损失函数:均方差(MSELoss)。
输出:4 个 one-hot 类型,结果为最大的索引值。
生成验证码
import random
from PIL import Image, ImageDraw, ImageFont
# 随机数字
def rand_char():
return chr(random.randint(48, 57))
# 随机背景颜色
def rand_bg():
return (random.randint(50, 150), random.randint(50, 150), random.randint(50, 150))
# 随机数字颜色
def rand_color():
return (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255))
width = 240
height = 60
font = ImageFont.truetype("arial.ttf", size=36)
for i in range(1000):
img = Image.new("RGB", (width, height), (255, 255, 255))
draw = ImageDraw.ImageDraw(img)
# 画背景
for x in range(width):
for y in range(height):
draw.point((x, y), rand_bg())
# 写数字
chrs = []
for n in range(4):
each = rand_char()
chrs.append(each)
draw.text((n * 60 + 10, 10