DL练习4—基于LSTM的验证码识别

一,基本版本

1.1 训练图片示例:

1.2 代码:

#-*- coding:utf-8 -*
import tensorflow as tf
import os
import random
import numpy as np
from PIL import Image

path = os.getcwd()  #项目所在路径

captcha_path = path + '/train_data'  #训练集-验证码所在路径
validation_path = path + '/validation_data' #验证集-验证码所在路径
test_data_path = path + '/test_data'    #测试集-验证码文件存放路径
output_path = path + '/result/result.txt'   #测试结果存放路径
model_path = path + '/model/model.ckpt' #模型存放路径

#要识别的字符
number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

batch_size = 64  #size of batch
time_steps = 26   #unrolled through 28 time steps #每个time_step是图像的一行像素 height
n_input = 80  #rows of 28 pixels  #width
image_channels = 1  # 图像的通道数
captcha_num = 4 # 验证码中字符个数
n_classes = len(number) + len(ALPHABET)    #类别分类

learning_rate = 0.001   #learning rate for adam
num_units = 128   #hidden LSTM units
layer_num = 2   #网络层数
iteration = 10000   #训练迭代次数

def computational_graph_lstm(x, y, batch_size = batch_size):

    #weights and biases of appropriate shape to accomplish above task
    out_weights = tf.Variable(tf.random_normal([num_units,n_classes]), name = 'out_weight')
    out_bias = tf.Variable(tf.random_normal([n_classes]),name = 'out_bias')

    #构建网络
    lstm_layer = [tf.nn.rnn_cell.LSTMCell(num_units, state_is_tuple=True) for _ in range(layer_num)]    #创建两层的lstm
    mlstm_cell = tf.nn.rnn_cell.MultiRNNCell(lstm_layer, state_is_tuple = True)   #将lstm连接在一起

    init_state = mlstm_cell.zero_state(batch_size, tf.float32)  #cell的初始状态

    outputs = list()    #每个cell的输出
    state = init_state
    with tf.variable_scope('RNN'):
        for timestep in range(time_steps):
            if timestep > 0:
                tf.get_variable_scope().reuse_variables()
            (cell_output, state) = mlstm_cell(x[:, timestep, :], state) # 这里的state保存了每一层 LSTM 的状态
            outputs.append(cell_output)
    # h_state = outputs[-1] #取最后一个cell输出

    #计算输出层的第一个元素
    prediction_1 = tf.nn.softmax(tf.matmul(outputs[-4],out_weights)+out_bias)    #获取最后time-step的输出,使用全连接, 得到第一个验证码输出结果
    #计算输出层的第二个元素
    prediction_2 = tf.nn.softmax(tf.matmul(outputs[-3],out_weights)+out_bias)   #输出第二个验证码预测结果
    #计算输出层的第三个元素
    prediction_3 = tf.nn.softmax(tf.matmul(outputs[-2],out_weights)+out_bias)   #输出第三个验证码预测结果
    #计算输出层的第四个元素
    prediction_4 = tf.nn.softmax(tf.matmul(outputs[-1],out_weights)+out_bias)   #输出第四个验证码预测结果,size:[batch,num_class]
    #输出连接
    prediction_all = tf.concat([prediction_1, prediction_2, prediction_3, prediction_4],1)  # 4 * [batch, num_class] => [batch, 4 * num_class]
    prediction_all = tf.reshape(prediction_all,[batch_size, captcha_num, n_classes],name ='prediction_merge') # [4, batch, num_class] => [batch, 4, num_class]

    #loss_function
    loss = -tf.reduce_mean(y * tf.log(prediction_all),name = 'loss')
    # loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction_all,labels=y))
    #optimization
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate, name = 'opt').minimize(loss)
    #model evaluation
    pre_arg = tf.argmax(prediction_all,2,name = 'predict')
    y_arg = tf.argmax(y,2)
    correct_prediction = tf.equal(pre_arg, y_arg)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32),name = 'accuracy')

    return opt, loss, accuracy, pre_arg, y_arg

#获取bacth_size数据集
def get_batch(data_path = captcha_path, is_training = True):
    target_file_list = os.listdir(data_path)    #读取路径下的所有文件名

    batch = batch_size if is_training else len(target_file_list)   # 确认batch 大小
    batch_x = np.zeros([batch, time_steps, n_input])   #batch 数据
    batch_y = np.zeros([batch, captcha_num, n_classes])   # batch 标签

    for i in range(batch):
        file_name = random.choice(target_file_list) if is_training else target_file_list[i] #确认要打开的文件名
        img = Image.open(data_path + '/' + file_name) #打开图片
        img = np.array(img)
        if len(img.shape) > 2: #彩色图
            img = np.mean(img, -1)  #转换成灰度图像:(26,80,3) =>(26,80)
            img = img / 255   #标准化,为了防止训练集的方差过大而导致的收敛过慢问题。
            # img = np.reshape(img,[time_steps,n_input])  #转换格式:(2080,) => (26,80)
        batch_x[i] = img

        label = np.zeros(captcha_num * n_classes)
        for num, char in enumerate(file_name.split('.')[0]):
            index = num * n_classes + char2index(char)
            label[index] = 1
        label = np.reshape(label,[captcha_num, n_classes])
        batch_y[i] = label
    return batch_x, batch_y

#字符转换成000100
def char2index(c):
    k = ord(c)
    index = -1
    if k >= 48 and k <= 57: #数字索引
        index = k - 48
    if k >= 65 and k <= 90: #大写字母索引
        index = k - 55
    if k >= 97 and k <= 122: #小写字母索引
        index = k - 61
    if index == -1:
        raise ValueError('No Map')
    return index

#000100转换成字符
def index2char(k):
    # k = chr(num)
    index = -1
    if k >= 0 and k < 10: #数字索引
        index = k + 48
    if k >= 10 and k < 36: #大写字母索引
        index = k + 55
    if k >= 36 and k < 62: #小写字母索引
        index = k + 61
    if index == -1:
        raise ValueError('No Map')
    return chr(index)

#训练
def train():
    # defining placeholders
    x = tf.placeholder("float",[None,time_steps,n_input], name = "x") #input image placeholder
    y = tf.placeholder("float",[None,captcha_num,n_classes], name = "y")  #input label placeholder

    # computational graph
    opt, loss, accuracy, pre_arg, y_arg = computational_graph_lstm(x, y)

    saver = tf.train.Saver()  # 创建训练模型保存类
    init = tf.global_variables_initializer()    #初始化变量值

    with tf.Session() as sess:  # 创建tensorflow session
        sess.run(init)
        iter = 1
        while iter < iteration:
            batch_x, batch_y = get_batch()
            sess.run(opt, feed_dict={x: batch_x, y: batch_y})   #只运行优化迭代计算图
            if iter %100==0:
                los, acc, parg, yarg = sess.run([loss, accuracy, pre_arg, y_arg],feed_dict={x:batch_x,y:batch_y})
                print("For iter ",iter)
                print("Accuracy ",acc)
                print("Loss ",los)
                if iter % 1000 ==0:
                    print("predict arg:",parg[0:10])
                    print("yarg :",yarg[0:10])
                print("__________________")
                # if acc > 0.95:
                #     print("training complete, accuracy:", acc)
                #     break
            if iter % 1000 == 0:   #保存模型
                saver.save(sess, model_path, global_step=iter)
            iter += 1
        # 计算验证集准确率
        valid_x, valid_y = get_batch(data_path=validation_path, is_training=False)
        print("Validation Accuracy:", sess.run(accuracy, feed_dict={x: valid_x, y: valid_y}))

#获取测试集
def get_test_set():
    target_file_list = os.listdir(test_data_path)   #获取测试集路径下的所有文件
    print("预测的验证码文件:",len(target_file_list))

    #判断条件
    flag = len(target_file_list) // batch_size  #计算待检测验证码个数能被batch size 整除的次数
    batch_len = flag if flag > 0 else 1  #共有多少个batch
    flag2 = len(target_file_list) % batch_size  #计算验证码被batch size整除后的取余
    batch_len = batch_len if flag2 == 0 else batch_len + 1  #若不能整除,则batch数量加1

    print("共生成batch数:",batch_len)
    print("验证码根据batch取余:",flag2)

    batch =  np.zeros([batch_len * batch_size, time_steps, n_input])
    for i, file in enumerate(target_file_list):
        batch[i] = open_iamge(file)
    batch = batch.reshape([batch_len, batch_size, time_steps, n_input])
    return batch, target_file_list #batch_file_name

#打开图像
def open_iamge(file):
    img = Image.open(test_data_path + '/' + file) #打开图片
    img = np.array(img)
    if len(img.shape) > 2:
        img = np.mean(img, -1)  #转换成灰度图像:(26,80,3) =>(26,80)
        img = img / 255
    return img

#预测
def predict():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(path + "/model/" + "model.ckpt-5000.meta")
        saver.restore(sess, tf.train.latest_checkpoint(path + "/model/")) #读取已训练模型

        graph = tf.get_default_graph()  #获取原始计算图,并读取其中的tensor
        x = graph.get_tensor_by_name("x:0")
        y = graph.get_tensor_by_name("y:0")
        pre_arg = graph.get_tensor_by_name("predict:0")

        test_x, file_list = get_test_set()  #获取测试集
        predict_result = []
        for i in range(len(test_x)):
            batch_test_x = test_x[i]
            batch_test_y = np.zeros([batch_size, captcha_num,n_classes])    #创建空的y输入
            test_predict = sess.run([pre_arg], feed_dict={x: batch_test_x, y:batch_test_y})
            # print(test_predict)
            # predict_result.extend(test_predict)

            for line in test_predict[0]:    #将预测结果转换为字符
                character = ""
                for each in line:
                    character += index2char(each)
                predict_result.append(character)

        predict_result = predict_result[:len(file_list)]    #预测结果
        write_to_file(predict_result, file_list)    #保存到文件

#写入文档
def write_to_file(predict_list, file_list):
    with open(output_path, 'a') as f:
        for i, res in enumerate(predict_list):
            if i == 0:
                f.write("id\tfile\tresult\n")
            f.write(str(i) + "\t" + file_list[i] + "\t" + res + "\n")
    print("预测结果保存在:",output_path)

#训练
train()

#预测
predict()

1.3 结果: 

 

二,改进版本

       改进版本改成了实时生成图片,不用读取本地图片,无限生成数据集。代码改动还是比较大的,有兴趣可以对比一下。同时添加了一些功能,比如:可视化指数下降学习率断点续训等等

2.1 训练图片示例:

     可以根据个人口味酌情更改验证码生成样式:

2.2 代码:

2.2.1 验证码生成代码一:

#encoding=utf-8
import random
# import matplotlib.pyplot as plt
import string
import sys
import math
from PIL import Image,ImageDraw,ImageFont,ImageFilter
filename="./My_captcha/"
#字体的位置,不同版本的系统会有不同BuxtonSketch.ttf
font_path = 'C:/Windows/Fonts/Georgia.ttf'
#font_path = 'C:/Windows/Fonts/默陌肥圆手写体.ttf'
#生成几位数的验证码
number = 4
#生成验证码图片的高度和宽度
size = (80,26)
#背景颜色,默认为白色
bgcolor = (255,255,255)
#字体颜色,默认为蓝色
fontcolor = (0,0,0)
#干扰线颜色。默认为红色
linecolor = (0,0,0)
#是否要加入干扰线
draw_line = True
#加入干扰线条数的上下限
line_number = (1,5)

#用来随机生成一个字符串
def gene_text():
    # source = list(string.letters)
    # for index in range(0,10):
    #     source.append(str(index))
    source = ['0','1','2','3','4','5','6','7','8','9','A', 'B', 'C', 'D', 'E', 'F', 'G', 'H','I','J', 'K','L', 'M', 'N','O','P','Q','R',
              'S', 'T', 'U', 'V', 'W', 'Z','X', 'Y']
    # source = [ 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H','I','J', 'K','L', 'M', 'N','O','P','Q','R',
    #           'S', 'T', 'U', 'V', 'W', 'Z','X', 'Y']
    return ''.join(random.sample(source,number))#number是生成验证码的位数
#用来绘制干扰线
def gene_line(draw,width,height):
    # begin = (random.randint(0, width), random.randint(0, height))
    # end = (random.randint(0, width), random.randint(0, height))
    begin = (0, random.randint(0, height))
    end = (74, random.randint(0, height))
    draw.line([begin, end], fill = linecolor,width=3)

#生成验证码
def gene_code():
    width,height = size #宽和高
    image = Image.new('RGBA',(width,height),bgcolor) #创建图片
    font = ImageFont.truetype(font_path,25) #验证码的字体
    draw = ImageDraw.Draw(image)  #创建画笔
    text = gene_text() #生成字符串
    font_width, font_height = font.getsize(text)
    draw.text(((width - font_width) / number, (height - font_height) / number),text,\
            font= font,fill=fontcolor) #填充字符串
    if draw_line:
        gene_line(draw,width,height)
    image = image.transform((width+30,height+10), Image.AFFINE, (1,-0.3,0,-0.1,1,0),Image.BILINEAR)  #创建扭曲
    # image = image.transform((width+20,height+10), Image.AFFINE, (1,-0.3,0,-0.1,1,0),Image.BILINEAR)  #创建扭曲
    image = image.filter(ImageFilter.EDGE_ENHANCE_MORE) #滤镜,边界加强
    # a = str(m)
    aa = str(".png")
    path = filename + text + aa
    # cv2.imwrite(path, I1)
    # image.save('idencode.jpg') #保存验证码图片
    image.save(path)


x=1
# if __name__ == "__main__":
# for k in(1,1000):
while x<10:
     gene_code()
     x+=1
     if(x%100==0):
        print("Iter:%d" % x)

2.2.2 验证码生成代码2:

from captcha.image import ImageCaptcha  # pip install captcha
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import random

# 验证码中的字符, 就不用汉字了
number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u',
            'v', 'w', 'x', 'y', 'z']
ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U',
            'V', 'W', 'X', 'Y', 'Z']


# 验证码一般都无视大小写;验证码长度4个字符
def random_captcha_text(char_set=number+ ALPHABET, captcha_size=4):
    captcha_text = []
    for i in range(captcha_size):
        c = random.choice(char_set)
        captcha_text.append(c)
    return captcha_text


# 生成字符对应的验证码
def gen_captcha_text_and_image():
    image = ImageCaptcha()

    captcha_text = random_captcha_text()
    captcha_text = ''.join(captcha_text)

    captcha = image.generate(captcha_text)
    captcha_image = Image.open(captcha)

    captcha_image = np.array(captcha_image)  #转成np.array
    captcha_image=Image.fromarray(np.uint8(captcha_image)) #转成PIL Image
    captcha_image = captcha_image.resize((80, 26), Image.ANTIALIAS) #缩放
    captcha_image.save("./imgs/"+captcha_text + '.jpg') #存图
    # captcha_image.write(captcha_text, "./imgs/"+captcha_text + '.jpg')  # 写到文件
    return captcha_text, captcha_image


if __name__ == '__main__':
    # 测试
    for i in range(10000):
        text, image = gen_captcha_text_and_image()
        if(i%100==0):
            print("生成第%s张图" % i)


    # f = plt.figure()
    # ax = f.add_subplot(111)
    # ax.text(0.1, 0.9, text, ha='center', va='center', transform=ax.transAxes)
    # plt.imshow(image)
    #
    # plt.show()

2.2.3 模型训练整体代码:

#-*- coding:utf-8 -*
import os
import random
import captcha
import numpy as np
import tensorflow as tf
from captcha.image import ImageCaptcha  # pip install captcha
from PIL import Image,ImageDraw,ImageFont,ImageFilter

#########全局变量###########################################
path = os.getcwd()  #项目所在路径
output_path = path + '/result/result.txt'   #测试结果存放路径
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "LSTM_Captcha"    # 保存模型名称

#要识别的字符
number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

batch_size = 64     # size of batch
time_steps = 26     # 每个time_step是图像的一行像素 height
n_input = 80        # rows of 28 pixels  #width
image_channels = 1  # 图像的通道数
captcha_num = 4     # 验证码中字符个数
n_classes = len(number) + len(ALPHABET)    #类别分类

learning_rate = 0.001   #learning rate for adam
decaystep = 5000  # 实现衰减的频率
decay_rate = 0.5  # 衰减率
num_units = 64   #hidden LSTM units
layer_num = 2   #网络层数
iteration = 20000   #训练迭代次数

#自动生成图像
IMAGE_HEIGHT = 26     # 图像高
IMAGE_WIDTH = 80      # 图像宽
#生成验证码图片的宽度和高度
size = (IMAGE_WIDTH,IMAGE_HEIGHT)
#背景颜色,默认为白色
bgcolor = (255,255,255)
#字体颜色,默认为黑色
fontcolor = (0,0,0)
#字体的位置,不同版本的系统会有不同BuxtonSketch.ttf
font_path = 'C:/Windows/Fonts/Georgia.ttf'
#########全局变量###########################################

# 随机生成4个数字+大小写字母的数组
def random_captcha_text(char_set=number+ALPHABET, captcha_size=4):  #数字
    captcha_text = []  # 初始化一个空列表
    for i in range(captcha_size):  # 产生字符的个数
        c = random.choice(char_set)  # 随机产生数字
        captcha_text.append(c)  # 加入列表
    return ''.join(captcha_text)  # 返回生成的字符

# 随机生成4个数字的图片
def gen_captcha_text_and_image():
    width,height = size #宽和高
    image = Image.new('RGBA',(width,height),bgcolor) #创建图片
    font = ImageFont.truetype(font_path,25) #验证码的字体
    draw = ImageDraw.Draw(image)  #创建画笔
    captcha_text = random_captcha_text()  # 随机生成4个数字的数组
    font_width, font_height = font.getsize(captcha_text) #字体大小
    draw.text(((width - font_width) / captcha_num, (height - font_height) / captcha_num),\
              captcha_text,font= font,fill=fontcolor) #填充字符串
    image = image.filter(ImageFilter.EDGE_ENHANCE_MORE)  # 滤镜,边界加强
    # aa = str(".png")
    # path = "./" + captcha_text + aa
    # image.save(path)
    captcha_image = np.array(image)  # 转化成array数组
    return captcha_text, captcha_image

# 转换成灰度图
def convert2gray(img):
    if len(img.shape) > 2:
        gray = np.mean(img, -1)
        # 上面的转法较快,正规转法如下
        # r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
        # gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray
    else:
        return img

# 字符串转换成0000100的数组
def text2vec(text):
    text_len = len(text)
    if text_len > captcha_num:
        raise ValueError('验证码最长4个字符')

    vector = np.zeros(captcha_num*n_classes)

    def char2pos(c):
        if c == '_':
            k = 62
            return k
        k = ord(c) - 48
        if k > 9:
            k = ord(c) - 55
            if k > 35:
                k = ord(c) - 61
                if k > 61:
                    raise ValueError('No Map')
        return k

    for i, c in enumerate(text):
        idx = i * n_classes + char2pos(c)
        vector[idx] = 1
    return vector

# 0000100的数组转换成字符串
def vec2text(vec):
    char_pos = vec.nonzero()[0]
    text = []
    for i, c in enumerate(char_pos):
        char_at_pos = i  # c/63
        char_idx = c % n_classes
        if char_idx < 10:
            char_code = char_idx + ord('0')
        elif char_idx < 36:
            char_code = char_idx - 10 + ord('A')
        elif char_idx < 62:
            char_code = char_idx - 36 + ord('a')
        elif char_idx == 62:
            char_code = ord('_')
        else:
            raise ValueError('error')
        text.append(chr(char_code))
    return "".join(text)

# [22,32,1,5]类型转换成字符
def index2char(vec):
    text=[]
    chr=''
    for i in range(len(vec[0])):
        subVec=vec[0][i]
        listChr=[]
        for id in range(captcha_num):
            if subVec[id]<10:
                chr=number[subVec[id]]
                listChr.append(chr)
            elif subVec[id]<36:
                chr=ALPHABET[subVec[id]-10]
                listChr.append(chr)
            elif subVec[id] < 62:
                chr = ALPHABET[subVec[id] - 36]
                listChr.append(chr)
            elif subVec[id] == 62:
                listChr.append('_')
            else:
                raise ValueError('error')
        str=''.join(listChr)
        text.append(str)
    return text

# 产生用于训练的bacth_size0大小的数据集
def get_next_batch(batch_size0=64):
    batch_x = np.zeros([batch_size0, time_steps, n_input])
    batch_y = np.zeros([batch_size0, captcha_num, n_classes])

    # 内部定义一个用于产生图片和标签的函数
    def wrap_gen_captcha_text_and_image():
        while True:
            text, image = gen_captcha_text_and_image()
            if image.shape == (IMAGE_HEIGHT, IMAGE_WIDTH, 4):
                return text, image

    for i in range(batch_size0):  # 按batch_size0大小循环产生图片
        text, image = wrap_gen_captcha_text_and_image()  # 产生图片
        image = convert2gray(image)  # 转化成灰度图
        image = np.array(image)
        image=image/255
        # image = image.flatten() / 255  # image.flatten()是转化为一行,除以255是归一化
        # image = np.reshape(np.array(image), [IMAGE_HEIGHT, IMAGE_WIDTH])  # 转换格式:(2080,) => (26,80)
        batch_x[i] =image
        ss=text2vec(text)
        batch_y[i] = np.reshape(text2vec(text), [captcha_num,n_classes])# 转换为标签
    return batch_x, batch_y

#构建lstm网络
def computational_graph_lstm(x, y, global_step):
    #weights and biases of appropriate shape to accomplish above task
    out_weights = tf.Variable(tf.random_normal([num_units,n_classes]), name = 'out_weight')
    out_bias = tf.Variable(tf.random_normal([n_classes]),name = 'out_bias')

    #构建网络
    lstm_layer = [tf.nn.rnn_cell.LSTMCell(num_units, state_is_tuple=True) for _ in range(layer_num)]    #创建两层的lstm
    mlstm_cell = tf.nn.rnn_cell.MultiRNNCell(lstm_layer, state_is_tuple = True)   #将lstm连接在一起
    init_state = mlstm_cell.zero_state(batch_size, tf.float32)  #cell的初始状态

    outputs = list()    #每个cell的输出
    state = init_state
    with tf.variable_scope('RNN'):
        for timestep in range(time_steps):
            if timestep > 0:
                tf.get_variable_scope().reuse_variables()
            (cell_output, state) = mlstm_cell(x[:, timestep, :], state) # 这里的state保存了每一层 LSTM 的状态
            outputs.append(cell_output)
    # h_state = outputs[-1] #取最后一个cell输出

    #计算输出层的第一个元素
    prediction_1 = tf.nn.softmax(tf.matmul(outputs[-4],out_weights)+out_bias)   #获取最后time-step的输出,使用全连接, 得到第一个验证码输出结果
    #计算输出层的第二个元素
    prediction_2 = tf.nn.softmax(tf.matmul(outputs[-3],out_weights)+out_bias)   #输出第二个验证码预测结果
    #计算输出层的第三个元素
    prediction_3 = tf.nn.softmax(tf.matmul(outputs[-2],out_weights)+out_bias)   #输出第三个验证码预测结果
    #计算输出层的第四个元素
    prediction_4 = tf.nn.softmax(tf.matmul(outputs[-1],out_weights)+out_bias)   #输出第四个验证码预测结果,size:[batch,num_class]
    #输出连接
    prediction_all = tf.concat([prediction_1, prediction_2, prediction_3, prediction_4],1)  # 4 * [batch, num_class] => [batch, 4 * num_class]
    prediction_all = tf.reshape(prediction_all,[batch_size, captcha_num, n_classes],name ='prediction_merge') # [4, batch, num_class] => [batch, 4, num_class]

    #loss_function
    # 损失
    with tf.name_scope('loss'):  # 损失
        loss = -tf.reduce_mean(y * tf.log(prediction_all),name = 'loss')
        tf.summary.scalar('loss', loss)
    # loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction_all,labels=y))
    #optimization
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate, name = 'opt').minimize(loss,global_step=global_step)  # 断点续训这里不加global_step=global_step会出错
    #model evaluation
    pre_arg = tf.argmax(prediction_all,2,name = 'predict')
    y_arg = tf.argmax(y,2)
    correct_prediction = tf.equal(pre_arg, y_arg)

    with tf.name_scope('accuracy'):  # 损失
        accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32),name = 'accuracy')
        tf.summary.scalar('accuracy', accuracy)

    return opt, loss, accuracy, pre_arg, y_arg

#训练
def train():
    # defining placeholders
    x = tf.placeholder("float",[None,time_steps,n_input], name = "x") #input image placeholder
    y = tf.placeholder("float",[None,captcha_num,n_classes], name = "y")  #input label placeholder

    # 运行了几轮batch_size的计数器,初值给0,设为不被训练
    global_step = tf.Variable(0, trainable=False)

    # 学习率自然指数衰减
    learing_rate_decay = tf.train.natural_exp_decay(learning_rate, global_step, decaystep, decay_rate, staircase=True)

    # computational graph
    opt, loss, accuracy, pre_arg, y_arg = computational_graph_lstm(x, y, global_step)

    # 创建训练模型保存类
    saver = tf.train.Saver(max_to_keep=1)

    # 初始化变量值
    init = tf.global_variables_initializer()

    # 将图形、训练过程等数据合并在一起
    merged = tf.summary.merge_all()

    with tf.Session() as sess:  # 创建tensorflow session
        sess.run(init)

        writer = tf.summary.FileWriter('logs', sess.graph)  # 将训练日志写入到logs文件夹下

        # ----------断点续训--------------------------
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        # ----------断点续训--------------------------

        iter = 1 #迭代次数计数器
        while iter < iteration:
            batch_x, batch_y = get_next_batch(batch_size)
            sess.run(opt, feed_dict={x: batch_x, y: batch_y})   #只运行优化迭代计算图


            if iter %100==0:
                result = sess.run(merged, feed_dict={x: batch_x, y: batch_y})  # 只运行优化迭代计算图
                writer.add_summary(result, iter)  # 将日志数据写入文件

                los, acc, parg, yarg, iter = sess.run([loss, accuracy, pre_arg, y_arg, global_step],feed_dict={x:batch_x,y:batch_y})
                print("iter:%d,Accuracy:%f,Loss:%f " % (iter, acc, los))

            if iter % 1000 == 0:   #保存模型
                # ----------指数衰减型学习率-------------------
                learning_rate_val = sess.run(learing_rate_decay)
                print("After %s steps,learing rate is %f" % (iter, learning_rate_val))
                # ----------指数衰减型学习率-------------------

                # ----------断点续训--------------------------
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
                # ----------断点续训--------------------------

            iter += 1
        # 计算验证集准确率
        valid_x, valid_y = get_next_batch(batch_size)
        print("Validation Accuracy:", sess.run(accuracy, feed_dict={x: valid_x, y: valid_y}))

#预测
def predict():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(path + "/model/" + "LSTM_Captcha-19000.meta")
        saver.restore(sess, tf.train.latest_checkpoint(path + "/model/")) #读取已训练模型

        graph = tf.get_default_graph()  #获取原始计算图,并读取其中的tensor
        x = graph.get_tensor_by_name("x:0")
        y = graph.get_tensor_by_name("y:0")
        pre_arg = graph.get_tensor_by_name("predict:0")

        # test_x, file_list = get_test_set()  #获取测试集
        test_x, test_y =get_next_batch(batch_size)
        batch_test_y = np.zeros([batch_size, captcha_num, n_classes])  # 创建空的y输入
        test_predict = sess.run([pre_arg], feed_dict={x: test_x, y: batch_test_y})
        predict_result=index2char(np.array(test_predict)) #转成字符串
        predict_result = predict_result[:len(test_y)]     #预测结果
        write_to_file(predict_result, test_y)             #保存到文件

#预测结果写入文档
def write_to_file(predict_list, test_y):
    label_y = np.reshape(test_y, [batch_size, captcha_num * n_classes])
    with open(output_path, 'w') as f:
        for i, res in enumerate(predict_list):
            y_ = vec2text(label_y[i]) #转成字符串
            if i == 0:
                f.write("id\tfile\tresult\n")
            f.write(str(i) + "\t" + y_ + "\t" + res + "\n")
            f.write("\n")
    print("预测结果保存在:",output_path)

#训练
train()

#预测
# predict()

2.3 结果:  

 

 

欢迎扫码关注我的微信公众号

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值