一,基本版本
1.1 训练图片示例:
1.2 代码:
#-*- coding:utf-8 -*
import tensorflow as tf
import os
import random
import numpy as np
from PIL import Image
path = os.getcwd() # project root: the current working directory
captcha_path = path + '/train_data' # training-set captcha images
validation_path = path + '/validation_data' # validation-set captcha images
test_data_path = path + '/test_data' # test-set captcha images to predict
output_path = path + '/result/result.txt' # file that prediction results are written to
model_path = path + '/model/model.ckpt' # checkpoint path prefix for saved models
# Characters to recognise
number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] # lowercase letters; NOT counted in n_classes below
batch_size = 64 #size of batch
time_steps = 26 # one LSTM time step per image row, i.e. the image height (26)
n_input = 80 # pixels per row fed at each time step, i.e. the image width (80)
image_channels = 1 # number of image channels (not referenced elsewhere in this listing)
captcha_num = 4 # number of characters per captcha
n_classes = len(number) + len(ALPHABET) # 36 classes: digits + uppercase letters only (no lowercase)
learning_rate = 0.001 #learning rate for adam
num_units = 128 #hidden LSTM units
layer_num = 2 # number of stacked LSTM layers
iteration = 10000 # training-loop bound
def computational_graph_lstm(x, y, batch_size = batch_size):
    """Build the stacked-LSTM captcha recogniser and its training ops.

    Each image row is one time step (``x[:, t, :]``). The outputs of the last
    ``captcha_num`` time steps each pass through one shared dense layer,
    giving a distribution over ``n_classes`` for every captcha character.

    Args:
        x: float input of shape [batch, time_steps, n_input].
        y: one-hot labels of shape [batch, captcha_num, n_classes].
        batch_size: kept for backward compatibility only. The batch dimension
            is taken from ``tf.shape(x)[0]``, so the graph accepts any batch
            length (e.g. a whole validation set), not just ``batch_size``.

    Returns:
        (opt, loss, accuracy, pre_arg, y_arg). The tensors are named 'loss',
        'accuracy', 'predict' and 'prediction_merge' so that a restored meta
        graph can look them up by name.
    """
    # Shared output layer applied to each of the last captcha_num time steps.
    out_weights = tf.Variable(tf.random_normal([num_units, n_classes]), name='out_weight')
    out_bias = tf.Variable(tf.random_normal([n_classes]), name='out_bias')
    # layer_num stacked LSTM layers.
    lstm_layer = [tf.nn.rnn_cell.LSTMCell(num_units, state_is_tuple=True) for _ in range(layer_num)]
    mlstm_cell = tf.nn.rnn_cell.MultiRNNCell(lstm_layer, state_is_tuple=True)
    # The initial state follows the actual batch size of x instead of a constant.
    state = mlstm_cell.zero_state(tf.shape(x)[0], tf.float32)
    outputs = []
    with tf.variable_scope('RNN'):
        for timestep in range(time_steps):
            if timestep > 0:
                tf.get_variable_scope().reuse_variables()
            # state carries the state of every LSTM layer.
            cell_output, state = mlstm_cell(x[:, timestep, :], state)
            outputs.append(cell_output)
    # One logit vector per character, taken from time steps -captcha_num .. -1.
    char_logits = [tf.matmul(outputs[t], out_weights) + out_bias for t in range(-captcha_num, 0)]
    logits = tf.stack(char_logits, axis=1)  # [batch, captcha_num, n_classes]
    prediction_all = tf.nn.softmax(logits, name='prediction_merge')
    # log_softmax avoids log(0) = -inf (a NaN loss) when a softmax probability
    # underflows. Mathematically this is the same mean cross-entropy as before.
    loss = tf.negative(tf.reduce_mean(y * tf.nn.log_softmax(logits)), name='loss')
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate, name='opt').minimize(loss)
    # Model evaluation: per-character argmax and per-character accuracy.
    pre_arg = tf.argmax(prediction_all, 2, name='predict')
    y_arg = tf.argmax(y, 2)
    correct_prediction = tf.equal(pre_arg, y_arg)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')
    return opt, loss, accuracy, pre_arg, y_arg
# Load one batch of images and labels from disk
def get_batch(data_path = captcha_path, is_training = True):
    """Load captcha images and one-hot labels from ``data_path``.

    The label is the file name without its extension (``ABCD.png`` -> "ABCD").

    Args:
        data_path: directory containing the captcha images.
        is_training: if True, sample ``batch_size`` files with replacement.
            Otherwise load every file in the directory, in listing order.

    Returns:
        batch_x: array [batch, time_steps, n_input], pixels divided by 255.
        batch_y: one-hot array [batch, captcha_num, n_classes].

    Raises:
        ValueError: if a label is longer than ``captcha_num`` or contains a
            character whose class index is not below ``n_classes``. Without
            this check, the bit would spill into the next character's slot.
    """
    target_file_list = os.listdir(data_path)
    batch = batch_size if is_training else len(target_file_list)
    batch_x = np.zeros([batch, time_steps, n_input])
    batch_y = np.zeros([batch, captcha_num, n_classes])
    for i in range(batch):
        file_name = random.choice(target_file_list) if is_training else target_file_list[i]
        # `with` closes the file handle PIL keeps open for lazy loading.
        with Image.open(os.path.join(data_path, file_name)) as img:
            pixels = np.array(img)
        if pixels.ndim > 2:  # colour image -> grey, e.g. (26, 80, 3) -> (26, 80)
            pixels = np.mean(pixels, -1)
        batch_x[i] = pixels / 255  # scale to [0, 1] to help convergence
        text = file_name.split('.')[0]
        if len(text) > captcha_num:
            raise ValueError('label %r of %s is longer than %d characters'
                             % (text, file_name, captcha_num))
        for pos, char in enumerate(text):
            index = char2index(char)
            if index >= n_classes:
                raise ValueError('character %r in %s maps to class %d, but n_classes is %d'
                                 % (char, file_name, index, n_classes))
            batch_y[i, pos, index] = 1
    return batch_x, batch_y
# Map a captcha character to its class index
def char2index(c):
    """Return the class index of the single character ``c``.

    '0'-'9' map to 0-9, 'A'-'Z' to 10-35 and 'a'-'z' to 36-61.
    Any other character raises ``ValueError('No Map')``.
    """
    code = ord(c)
    # (first code point, last code point, offset subtracted to get the index)
    for first, last, offset in ((48, 57, 48), (65, 90, 55), (97, 122, 61)):
        if first <= code <= last:
            return code - offset
    raise ValueError('No Map')
# Map a class index back to its character (inverse of char2index)
def index2char(k):
    """Return the character for class index ``k``.

    0-9 -> '0'-'9', 10-35 -> 'A'-'Z', 36-61 -> 'a'-'z'.
    Any other index raises ``ValueError('No Map')``.
    """
    if 0 <= k < 10:
        return chr(k + 48)
    if 10 <= k < 36:
        return chr(k + 55)
    if 36 <= k < 62:
        return chr(k + 61)
    raise ValueError('No Map')
# Training
def train():
    """Train the LSTM recogniser on images from ``captcha_path``.

    Behaviour:
    - Runs optimisation steps 1 .. iteration-1 on randomly sampled batches.
    - Every 100 steps, prints metrics.
    - Every 1000 steps, also prints sample predictions and saves a checkpoint
      to ``model_path``.
    - Finally, reports per-character accuracy on every image in
      ``validation_path``.
    """
    x = tf.placeholder("float", [None, time_steps, n_input], name="x")  # input images
    y = tf.placeholder("float", [None, captcha_num, n_classes], name="y")  # one-hot labels
    opt, loss, accuracy, pre_arg, y_arg = computational_graph_lstm(x, y)
    saver = tf.train.Saver()  # checkpoint writer
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Same bound as before (steps 1 .. iteration-1), so the saved
        # checkpoint step numbers are unchanged.
        for step in range(1, iteration):
            batch_x, batch_y = get_batch()
            feed = {x: batch_x, y: batch_y}
            sess.run(opt, feed_dict=feed)  # optimisation step only
            if step % 100 == 0:
                los, acc, parg, yarg = sess.run([loss, accuracy, pre_arg, y_arg], feed_dict=feed)
                print("For iter ", step)
                print("Accuracy ", acc)
                print("Loss ", los)
                if step % 1000 == 0:
                    print("predict arg:", parg[0:10])
                    print("yarg :", yarg[0:10])
                print("__________________")
            if step % 1000 == 0:
                saver.save(sess, model_path, global_step=step)
        # Validation accuracy. The set is fed in batch_size chunks, and the last
        # chunk is zero-padded, so this works even when the graph only accepts
        # exactly batch_size samples. Padded rows are excluded from the count.
        valid_x, valid_y = get_batch(data_path=validation_path, is_training=False)
        total = len(valid_x)
        correct = 0
        for start in range(0, total, batch_size):
            chunk_x = valid_x[start:start + batch_size]
            chunk_y = valid_y[start:start + batch_size]
            n_real = len(chunk_x)
            if n_real < batch_size:
                pad = batch_size - n_real
                chunk_x = np.concatenate([chunk_x, np.zeros((pad,) + chunk_x.shape[1:])])
                chunk_y = np.concatenate([chunk_y, np.zeros((pad,) + chunk_y.shape[1:])])
            parg, yarg = sess.run([pre_arg, y_arg], feed_dict={x: chunk_x, y: chunk_y})
            correct += int(np.sum(parg[:n_real] == yarg[:n_real]))
        # Same metric as the `accuracy` tensor: the fraction of correct characters.
        valid_acc = correct / (total * captcha_num) if total else float('nan')
        print("Validation Accuracy:", valid_acc)
# Load the test set
def get_test_set():
    """Load every image in ``test_data_path`` and group them into batches.

    The image array is zero-padded up to a whole number of ``batch_size``
    batches (at least one), because predict() feeds one batch of
    ``batch_size`` images at a time.

    Returns:
        batch: array [batch_len, batch_size, time_steps, n_input].
        target_file_list: file names, in the same order as the images.
    """
    target_file_list = os.listdir(test_data_path)
    n_files = len(target_file_list)
    print("预测的验证码文件:",n_files)
    flag2 = n_files % batch_size  # images in the last, partial batch
    # ceil(n_files / batch_size), but never fewer than one batch. The old logic
    # produced an extra all-zero batch whenever 0 < n_files < batch_size.
    batch_len = max(1, -(-n_files // batch_size))
    print("共生成batch数:",batch_len)
    print("验证码根据batch取余:",flag2)
    batch = np.zeros([batch_len * batch_size, time_steps, n_input])
    for i, file in enumerate(target_file_list):
        batch[i] = open_iamge(file)
    batch = batch.reshape([batch_len, batch_size, time_steps, n_input])
    return batch, target_file_list
# Open one test image
def open_iamge(file):
    """Load ``file`` from ``test_data_path`` as a grey array scaled to [0, 1].

    Colour images are converted to grey by averaging the channel axis.
    (The misspelt name is kept because get_test_set() calls it.)
    """
    # `with` releases the file handle PIL keeps open for lazy loading.
    with Image.open(os.path.join(test_data_path, file)) as img:
        pixels = np.array(img)
    if pixels.ndim > 2:  # e.g. (26, 80, 3) -> (26, 80)
        pixels = np.mean(pixels, -1)
    return pixels / 255
# Prediction
def predict():
    """Label every image in ``test_data_path`` with the latest saved model.

    The meta graph is loaded from the latest checkpoint's own ``.meta`` file,
    not a hard-coded step. It is imported into a fresh ``tf.Graph``, so
    "x:0", "y:0" and "predict:0" do not clash with the graph ``train()`` may
    already have built in this process. Results go to ``write_to_file``.

    Raises:
        FileNotFoundError: if the model directory holds no checkpoint.
    """
    model_dir = path + "/model/"
    checkpoint = tf.train.latest_checkpoint(model_dir)
    if checkpoint is None:
        raise FileNotFoundError("no checkpoint found in " + model_dir)
    test_x, file_list = get_test_set()
    predict_result = []
    with tf.Graph().as_default(), tf.Session() as sess:
        saver = tf.train.import_meta_graph(checkpoint + ".meta")
        saver.restore(sess, checkpoint)  # load the trained weights
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("x:0")
        y = graph.get_tensor_by_name("y:0")
        pre_arg = graph.get_tensor_by_name("predict:0")
        # Dummy labels for the y placeholder.
        empty_y = np.zeros([batch_size, captcha_num, n_classes])
        for batch_test_x in test_x:
            test_predict = sess.run(pre_arg, feed_dict={x: batch_test_x, y: empty_y})
            for line in test_predict:  # class indices -> characters
                predict_result.append("".join(index2char(each) for each in line))
    # Drop the predictions made for the zero-padding images.
    predict_result = predict_result[:len(file_list)]
    write_to_file(predict_result, file_list)
# Write predictions to the results file
def write_to_file(predict_list, file_list):
    """Append predictions to ``output_path`` as a tab-separated table.

    A header line ("id", "file", "result") is written first when there is at
    least one prediction. Row i pairs ``file_list[i]`` with ``predict_list[i]``.
    The output directory is created if missing.
    """
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Explicit encoding so non-ASCII file names are written the same on every OS.
    with open(output_path, 'a', encoding='utf-8') as f:
        if predict_list:
            f.write("id\tfile\tresult\n")
        for i, (file_name, res) in enumerate(zip(file_list, predict_list)):
            f.write(str(i) + "\t" + file_name + "\t" + res + "\n")
    print("预测结果保存在:",output_path)
# Train the model (checkpoints are saved under path + '/model/')
train()
# Predict the test set and write the results to output_path
predict()
1.3 结果:
二,改进版本
改进版本改为实时生成验证码图片,无需读取本地图片,可以无限生成训练数据。代码改动较大,有兴趣可以与基本版本对比。同时新增了一些功能,例如:TensorBoard 可视化、学习率指数衰减、断点续训等。
2.1 训练图片示例:
可以根据个人口味酌情更改验证码生成样式:
2.2 代码:
2.2.1 验证码生成代码一:
#encoding=utf-8
import random
# import matplotlib.pyplot as plt
import string
import sys
import math
from PIL import Image,ImageDraw,ImageFont,ImageFilter
filename="./My_captcha/" # output directory for generated images (used as a path prefix)
# Font file location; it differs between systems (e.g. BuxtonSketch.ttf)
font_path = 'C:/Windows/Fonts/Georgia.ttf'
#font_path = 'C:/Windows/Fonts/默陌肥圆手写体.ttf'
# Number of characters per captcha
number = 4
# Image size as (width, height)
size = (80,26)
# Background colour: white
bgcolor = (255,255,255)
# Text colour: black
fontcolor = (0,0,0)
# Interference-line colour: black
linecolor = (0,0,0)
# Whether to draw an interference line
draw_line = True
# Lower/upper bound on the number of interference lines (not used below: gene_code draws one line)
line_number = (1,5)
# Build a random captcha string
def gene_text():
    """Return ``number`` distinct characters sampled from digits and uppercase letters.

    Sampling is without replacement, so no character repeats within a captcha.
    The pool order (note "...VWZXY") is preserved so seeded runs are reproducible.
    """
    pool = list('0123456789ABCDEFGHIJKLMNOPQRSTUVWZXY')
    return ''.join(random.sample(pool, number))
# Draw one interference line
def gene_line(draw,width,height):
    """Draw a 3-px line in ``linecolor`` from x=0 to x=74.

    The start and end heights are drawn independently from [0, height].
    ``width`` is accepted but not used.
    """
    start_y = random.randint(0, height)
    end_y = random.randint(0, height)
    draw.line([(0, start_y), (74, end_y)], fill=linecolor, width=3)
# Generate one captcha image and save it
def gene_code():
    """Render a random captcha and save it as ``<filename><text>.png``.

    Steps: draw the text, optionally one interference line, an affine skew,
    then an edge-enhance filter. The output directory is created if missing.
    """
    import os  # local import: this script has no module-level os import

    width, height = size
    image = Image.new('RGBA', (width, height), bgcolor)  # blank canvas
    font = ImageFont.truetype(font_path, 25)
    draw = ImageDraw.Draw(image)
    text = gene_text()
    # FreeTypeFont.getsize() was removed in Pillow 10. The right/bottom edges
    # of getbbox() correspond to the (width, height) getsize() returned.
    if hasattr(font, 'getbbox'):
        _, _, font_width, font_height = font.getbbox(text)
    else:
        font_width, font_height = font.getsize(text)
    draw.text(((width - font_width) / number, (height - font_height) / number),
              text, font=font, fill=fontcolor)
    if draw_line:
        gene_line(draw, width, height)
    # Skew the image (affine transform onto a slightly larger canvas).
    image = image.transform((width + 30, height + 10), Image.AFFINE,
                            (1, -0.3, 0, -0.1, 1, 0), Image.BILINEAR)
    image = image.filter(ImageFilter.EDGE_ENHANCE_MORE)  # sharpen edges
    os.makedirs(filename, exist_ok=True)
    image.save(filename + text + ".png")
# Script body: render captchas while the counter x (starting at 1) is below 10,
# i.e. 9 images. x is checked after each increment, so it takes the values
# 2..10 here. The progress print only fires on multiples of 100, so raise the
# bound to generate more images.
for x in range(2, 11):
    gene_code()
    if x % 100 == 0:
        print("Iter:%d" % x)
2.2.2 验证码生成代码二:
from captcha.image import ImageCaptcha # pip install captcha
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import random
# Characters allowed in the captcha text (no Chinese characters)
number = list('0123456789')
alphabet = list('abcdefghijklmnopqrstuvwxyz')
ALPHABET = [letter.upper() for letter in alphabet]
# Captchas usually ignore case; default length is 4 characters
def random_captcha_text(char_set=number+ ALPHABET, captcha_size=4):
    """Return a list of ``captcha_size`` characters drawn with replacement from ``char_set``."""
    return [random.choice(char_set) for _ in range(captcha_size)]
# Generate a captcha image for a random text
def gen_captcha_text_and_image():
    """Render a random 4-character captcha with ``ImageCaptcha``, shrink it to
    80x26 and save it as ``./imgs/<text>.jpg``.

    The ``./imgs/`` directory is created if missing.

    Returns:
        (captcha_text, captcha_image): the string and the resized PIL image.
    """
    import os  # local import: this script has no module-level os import

    image = ImageCaptcha()
    captcha_text = ''.join(random_captcha_text())
    captcha = image.generate(captcha_text)  # in-memory image file
    captcha_image = Image.open(captcha)
    # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
    captcha_image = captcha_image.resize((80, 26), Image.LANCZOS)
    os.makedirs("./imgs/", exist_ok=True)
    captcha_image.save("./imgs/"+captcha_text + '.jpg')
    return captcha_text, captcha_image
if __name__ == '__main__':
    # Generate 10000 captcha images, reporting progress every 100 images.
    for count in range(10000):
        text, image = gen_captcha_text_and_image()
        if count % 100 == 0:
            print("生成第%s张图" % count)
2.2.3 模型训练整体代码:
#-*- coding:utf-8 -*
import os
import random
import captcha
import numpy as np
import tensorflow as tf
from captcha.image import ImageCaptcha # pip install captcha
from PIL import Image,ImageDraw,ImageFont,ImageFilter
#########Global settings###########################################
path = os.getcwd() # project root: the current working directory
output_path = path + '/result/result.txt' # file that prediction results are written to
MODEL_SAVE_PATH = "./model/" # checkpoint directory, relative to the working directory
MODEL_NAME = "LSTM_Captcha" # checkpoint file-name prefix
# Characters to recognise
number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] # lowercase letters; NOT counted in n_classes below
batch_size = 64 # size of batch
time_steps = 26 # one LSTM time step per image row, i.e. the image height
n_input = 80 # pixels per row fed at each time step, i.e. the image width
image_channels = 1 # number of image channels (not referenced elsewhere in this listing)
captcha_num = 4 # number of characters per captcha
n_classes = len(number) + len(ALPHABET) # 36 classes: digits + uppercase letters only (no lowercase)
learning_rate = 0.001 # base learning rate for Adam
decaystep = 5000 # decay_steps passed to tf.train.natural_exp_decay (staircase schedule)
decay_rate = 0.5 # decay_rate passed to tf.train.natural_exp_decay
num_units = 64 #hidden LSTM units
layer_num = 2 # number of stacked LSTM layers
iteration = 20000 # training-loop bound
# Settings for generating images on the fly
IMAGE_HEIGHT = 26 # image height
IMAGE_WIDTH = 80 # image width
# Generated captcha size as (width, height), the order Image.new expects
size = (IMAGE_WIDTH,IMAGE_HEIGHT)
# Background colour: white
bgcolor = (255,255,255)
# Text colour: black
fontcolor = (0,0,0)
# Font file location; it differs between systems (this path is a Windows font)
font_path = 'C:/Windows/Fonts/Georgia.ttf'
#########Global settings###########################################
# Random captcha string of digits and uppercase letters
def random_captcha_text(char_set=number+ALPHABET, captcha_size=4):
    """Return a string of ``captcha_size`` characters drawn with replacement from ``char_set``."""
    return ''.join(random.choice(char_set) for _ in range(captcha_size))
# Render a random captcha image in memory
def gen_captcha_text_and_image():
    """Draw a random captcha string on a blank RGBA canvas of size ``size``.

    Returns:
        captcha_text: the ``captcha_num``-character string that was drawn.
        captcha_image: the image as a numpy array (RGBA, so the expected shape
            is (IMAGE_HEIGHT, IMAGE_WIDTH, 4)).
    """
    width, height = size
    image = Image.new('RGBA', (width, height), bgcolor)
    font = ImageFont.truetype(font_path, 25)
    draw = ImageDraw.Draw(image)
    captcha_text = random_captcha_text()
    # FreeTypeFont.getsize() was removed in Pillow 10. The right/bottom edges
    # of getbbox() correspond to the (width, height) getsize() returned.
    if hasattr(font, 'getbbox'):
        _, _, font_width, font_height = font.getbbox(captcha_text)
    else:
        font_width, font_height = font.getsize(captcha_text)
    # NOTE(review): dividing by captcha_num (not 2) puts the text off-centre.
    # It is kept so the generated training-image layout does not change.
    draw.text(((width - font_width) / captcha_num, (height - font_height) / captcha_num),
              captcha_text, font=font, fill=fontcolor)
    image = image.filter(ImageFilter.EDGE_ENHANCE_MORE)  # sharpen edges
    return captcha_text, np.array(image)
# Convert an image array to grayscale
def convert2gray(img):
    """Average ``img`` over its last axis when it has more than two dimensions.

    Two-dimensional input is returned unchanged (the same object). The plain
    channel mean is used because it is fast. A weighted luma conversion would
    be 0.2989 * r + 0.5870 * g + 0.1140 * b.
    """
    if len(img.shape) <= 2:
        return img
    return np.mean(img, -1)
# Encode a string as a flat one-hot vector
def text2vec(text):
    """Encode ``text`` as a flat one-hot vector of length captcha_num * n_classes.

    Character i sets position ``i * n_classes + class(c)``, where '0'-'9'
    map to 0-9, 'A'-'Z' to 10-35, 'a'-'z' to 36-61 and '_' to 62.

    Raises:
        ValueError: if ``text`` is longer than ``captcha_num``, or contains a
            character with no class, or whose class is not below ``n_classes``.
            Such a class would otherwise set a bit inside the neighbouring
            character's slot.
    """
    text_len = len(text)
    if text_len > captcha_num:
        raise ValueError('验证码最长4个字符')
    vector = np.zeros(captcha_num*n_classes)

    def char2pos(c):
        # Explicit ranges: the old offset arithmetic mapped e.g. ':' to the
        # class of '3' and punctuation to negative indices.
        if c == '_':
            return 62
        if '0' <= c <= '9':
            return ord(c) - 48
        if 'A' <= c <= 'Z':
            return ord(c) - 55
        if 'a' <= c <= 'z':
            return ord(c) - 61
        raise ValueError('No Map')

    for i, c in enumerate(text):
        pos = char2pos(c)
        if pos >= n_classes:
            raise ValueError('character %r maps to class %d, but n_classes is %d'
                             % (c, pos, n_classes))
        vector[i * n_classes + pos] = 1
    return vector
# Decode a flat one-hot vector back to its string
def vec2text(vec):
    """Return the text encoded by the non-zero entries of the flat vector ``vec``.

    Each non-zero position is reduced modulo ``n_classes`` to a class index.
    0-9 -> '0'-'9', 10-35 -> 'A'-'Z', 36-61 -> 'a'-'z', 62 -> '_'; any other
    class raises ``ValueError('error')``.
    """
    def decode(class_idx):
        if class_idx < 10:
            return chr(class_idx + ord('0'))
        if class_idx < 36:
            return chr(class_idx - 10 + ord('A'))
        if class_idx < 62:
            return chr(class_idx - 36 + ord('a'))
        if class_idx == 62:
            return '_'
        raise ValueError('error')

    return "".join(decode(pos % n_classes) for pos in vec.nonzero()[0])
# Convert predicted class indices (e.g. [22, 32, 1, 5]) to strings
def index2char(vec):
    """Convert predicted class indices to captcha strings.

    Args:
        vec: indexed as ``vec[0][sample][position]``. predict() passes
            ``np.array([pre_arg output])``, i.e. shape [1, batch, captcha_num].

    Returns:
        One string of ``captcha_num`` characters per sample in ``vec[0]``.

    Raises:
        ValueError: for a class index above 62.
    """
    texts = []
    for sample in vec[0]:
        chars = []
        for pos in range(captcha_num):
            idx = sample[pos]
            if idx < 10:
                chars.append(number[idx])
            elif idx < 36:
                chars.append(ALPHABET[idx - 10])
            elif idx < 62:
                # Classes 36-61 are lowercase letters, so look them up in `alphabet`.
                chars.append(alphabet[idx - 36])
            elif idx == 62:
                chars.append('_')
            else:
                raise ValueError('error')
        texts.append(''.join(chars))
    return texts
# Generate a batch of synthetic captchas for training/evaluation
def get_next_batch(batch_size0=64):
    """Generate ``batch_size0`` fresh captchas with their labels.

    Returns:
        batch_x: array [batch_size0, time_steps, n_input], grey pixels / 255.
        batch_y: one-hot array [batch_size0, captcha_num, n_classes].
    """
    batch_x = np.zeros([batch_size0, time_steps, n_input])
    batch_y = np.zeros([batch_size0, captcha_num, n_classes])

    def wrap_gen_captcha_text_and_image():
        # Retry until the rendered image has the expected RGBA shape.
        while True:
            text, image = gen_captcha_text_and_image()
            if image.shape == (IMAGE_HEIGHT, IMAGE_WIDTH, 4):
                return text, image

    for i in range(batch_size0):
        text, image = wrap_gen_captcha_text_and_image()
        # NOTE(review): convert2gray averages the last axis of this RGBA array,
        # so the alpha channel is included in the grey value. Training and
        # predict() share this preprocessing, so it is left unchanged.
        batch_x[i] = np.array(convert2gray(image)) / 255
        batch_y[i] = np.reshape(text2vec(text), [captcha_num, n_classes])
    return batch_x, batch_y
# Build the LSTM network
def computational_graph_lstm(x, y, global_step):
    """Build the stacked-LSTM captcha recogniser and its training ops.

    Each image row is one time step (``x[:, t, :]``). The outputs of the last
    ``captcha_num`` time steps each pass through one shared dense layer,
    giving a distribution over ``n_classes`` per character.

    Args:
        x: float input of shape [batch, time_steps, n_input]. Any batch
            length is accepted.
        y: one-hot labels of shape [batch, captcha_num, n_classes].
        global_step: step-counter variable. It is incremented by the
            optimiser and drives the learning-rate schedule; it is needed for
            resumable training.

    Returns:
        (opt, loss, accuracy, pre_arg, y_arg). ``pre_arg`` is named
        'predict' so predict() can fetch it from a restored meta graph.
    """
    # Shared output layer applied to each of the last captcha_num time steps.
    out_weights = tf.Variable(tf.random_normal([num_units, n_classes]), name='out_weight')
    out_bias = tf.Variable(tf.random_normal([n_classes]), name='out_bias')
    # layer_num stacked LSTM layers.
    lstm_layer = [tf.nn.rnn_cell.LSTMCell(num_units, state_is_tuple=True) for _ in range(layer_num)]
    mlstm_cell = tf.nn.rnn_cell.MultiRNNCell(lstm_layer, state_is_tuple=True)
    # The initial state follows the actual batch size of x.
    state = mlstm_cell.zero_state(tf.shape(x)[0], tf.float32)
    outputs = []
    with tf.variable_scope('RNN'):
        for timestep in range(time_steps):
            if timestep > 0:
                tf.get_variable_scope().reuse_variables()
            # state carries the state of every LSTM layer.
            cell_output, state = mlstm_cell(x[:, timestep, :], state)
            outputs.append(cell_output)
    # One logit vector per character, from time steps -captcha_num .. -1.
    char_logits = [tf.matmul(outputs[t], out_weights) + out_bias for t in range(-captcha_num, 0)]
    logits = tf.stack(char_logits, axis=1)  # [batch, captcha_num, n_classes]
    prediction_all = tf.nn.softmax(logits, name='prediction_merge')
    with tf.name_scope('loss'):
        # log_softmax avoids log(0) = -inf (a NaN loss) when a softmax
        # probability underflows; it is the same mean cross-entropy.
        loss = tf.negative(tf.reduce_mean(y * tf.nn.log_softmax(logits)), name='loss')
        tf.summary.scalar('loss', loss)
    # Actually apply the natural-exponential learning-rate decay to Adam.
    decayed_lr = tf.train.natural_exp_decay(learning_rate, global_step, decaystep,
                                            decay_rate, staircase=True)
    # global_step must be passed so it is incremented and saved in checkpoints.
    opt = tf.train.AdamOptimizer(learning_rate=decayed_lr, name='opt').minimize(
        loss, global_step=global_step)
    # Model evaluation: per-character argmax and per-character accuracy.
    pre_arg = tf.argmax(prediction_all, 2, name='predict')
    y_arg = tf.argmax(y, 2)
    correct_prediction = tf.equal(pre_arg, y_arg)
    with tf.name_scope('accuracy'):
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')
        tf.summary.scalar('accuracy', accuracy)
    return opt, loss, accuracy, pre_arg, y_arg
# Training (with TensorBoard summaries and checkpoint resume)
def train():
    """Train on freshly generated batches, resuming from MODEL_SAVE_PATH if possible.

    ``global_step`` is the single step counter. Training stops once it reaches
    ``iteration - 1``, the same final step as the original ``iter = 1;
    while iter < iteration`` loop.
    - Every 100 steps, summaries are written to ``logs`` and metrics printed.
    - Every 1000 steps, the scheduled learning rate is printed and a
      checkpoint is saved.
    """
    x = tf.placeholder("float", [None, time_steps, n_input], name="x")  # input images
    y = tf.placeholder("float", [None, captcha_num, n_classes], name="y")  # one-hot labels
    # Counts optimiser steps; not trainable, and stored in checkpoints.
    global_step = tf.Variable(0, trainable=False)
    # Natural-exponential learning-rate schedule, evaluated for logging.
    learning_rate_decay = tf.train.natural_exp_decay(
        learning_rate, global_step, decaystep, decay_rate, staircase=True)
    opt, loss, accuracy, pre_arg, y_arg = computational_graph_lstm(x, y, global_step)
    saver = tf.train.Saver(max_to_keep=1)
    init = tf.global_variables_initializer()
    merged = tf.summary.merge_all()
    # Saver.save does not create a missing parent directory.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    with tf.Session() as sess:
        sess.run(init)
        writer = tf.summary.FileWriter('logs', sess.graph)  # TensorBoard logs
        try:
            # Resume from the latest checkpoint if there is one.
            ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            step = int(sess.run(global_step))  # 0 on a fresh run
            while step + 1 < iteration:
                batch_x, batch_y = get_next_batch(batch_size)
                feed = {x: batch_x, y: batch_y}
                sess.run(opt, feed_dict=feed)
                step += 1  # minimize() increments global_step once per run
                if step % 100 == 0:
                    summary, los, acc = sess.run([merged, loss, accuracy], feed_dict=feed)
                    writer.add_summary(summary, step)
                    print("iter:%d,Accuracy:%f,Loss:%f " % (step, acc, los))
                    if step % 1000 == 0:
                        learning_rate_val = sess.run(learning_rate_decay)
                        print("After %s steps,learing rate is %f" % (step, learning_rate_val))
                        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                                   global_step=global_step)
            # Accuracy on a fresh, unseen batch.
            valid_x, valid_y = get_next_batch(batch_size)
            print("Validation Accuracy:", sess.run(accuracy, feed_dict={x: valid_x, y: valid_y}))
        finally:
            writer.close()  # flush pending summaries
# Prediction on a freshly generated batch
def predict():
    """Predict one generated batch with the latest checkpoint and save the results.

    The meta graph comes from the latest checkpoint's own ``.meta`` file.
    Because ``Saver(max_to_keep=1)`` keeps only one step, a hard-coded step
    number can point at a deleted file. The graph is imported into a fresh
    ``tf.Graph`` so its tensor names cannot clash with a graph already built in
    this process.

    Raises:
        FileNotFoundError: if the model directory holds no checkpoint.
    """
    model_dir = path + "/model/"
    checkpoint = tf.train.latest_checkpoint(model_dir)
    if checkpoint is None:
        raise FileNotFoundError("no checkpoint found in " + model_dir)
    with tf.Graph().as_default(), tf.Session() as sess:
        saver = tf.train.import_meta_graph(checkpoint + ".meta")
        saver.restore(sess, checkpoint)  # load the trained weights
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("x:0")
        y = graph.get_tensor_by_name("y:0")
        pre_arg = graph.get_tensor_by_name("predict:0")
        test_x, test_y = get_next_batch(batch_size)
        batch_test_y = np.zeros([batch_size, captcha_num, n_classes])  # dummy labels
        test_predict = sess.run([pre_arg], feed_dict={x: test_x, y: batch_test_y})
    # [1, batch, captcha_num] class indices -> strings.
    predict_result = index2char(np.array(test_predict))
    write_to_file(predict_result, test_y)
# Write predictions and true labels to the results file
def write_to_file(predict_list, test_y):
    """Write predictions next to their true labels in ``output_path`` (overwrites).

    Args:
        predict_list: predicted strings, one per sample.
        test_y: one-hot labels [n, captcha_num, n_classes] for the same samples.
            n may differ from ``batch_size``.

    A header line is written when there is at least one prediction. The file
    ends with one blank line. The output directory is created if missing.
    """
    label_y = np.reshape(test_y, [len(test_y), captcha_num * n_classes])
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, 'w') as f:
        if predict_list:
            f.write("id\tfile\tresult\n")
        for i, res in enumerate(predict_list):
            y_ = vec2text(label_y[i])  # true label as a string
            f.write(str(i) + "\t" + y_ + "\t" + res + "\n")
        f.write("\n")
    print("预测结果保存在:",output_path)
# Train (resumes from MODEL_SAVE_PATH when a checkpoint exists)
train()
# Predict: uncomment to run inference with the saved model
# predict()
2.3 结果: