验证码识别3(整体训练)——训练部分

我们在上一篇中生成了tfrecord文件,下面我们就要对其进行训练了。

首先讲一下训练的方法,我们这里是把一个图片的名字转换成one-hot格式,每个数字为10位,每个样本一共4个数字,也就是每个标签对应40位。训练的方法上使用了alexnet网络,你需要下载一个文件,地址是链接:https://pan.baidu.com/s/17aMOksxsOyUIN_XmqL8rBg 提取码:mcym,将其解压到你的写代码的目录下,名字为nets,接下来给出代码(需要改动的是tfrecord文件的路径,和最后模型保存的路径):

# coding: utf-8
import os
import tensorflow as tf
from PIL import Image
from nets import nets_factory
import numpy as np
import numpy as np
np.set_printoptions(threshold=np.inf)


# In[2]:

# 不同字符数量
CHAR_SET_LEN = 40
# 图片高度
IMAGE_HEIGHT = 60
# 图片宽度
IMAGE_WIDTH = 160
# 批次
BATCH_SIZE = 32
# tfrecord文件存放路径
TFRECORD_FILE = "D:/验证码识别/使用一个标签来训练/train.tfrecords"

# placeholder
x = tf.placeholder(tf.float32, [None, 224, 224])

y = tf.placeholder(tf.float32, [None])
print(y)
# 学习率
lr = tf.Variable(0.003, dtype=tf.float32)


# 从tfrecord读出数据
def read_and_decode(filename):
    # 根据文件名生成一个队列
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    # 返回文件名和文件
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'image': tf.FixedLenFeature([], tf.string),
                                           'label': tf.FixedLenFeature([], tf.int64)})
    # 获取图片数据
    image = tf.decode_raw(features['image'], tf.uint8)
    img = tf.reshape(image,[224,224])
    # tf.train.shuffle_batch必须确定shape
    image = tf.reshape(image, [224, 224])
    # 图片预处理
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    # 获取label
    label= tf.cast(features['label'], tf.float64)
    return img, image, label
# 获取图片数据和标签
img,image, label = read_and_decode(TFRECORD_FILE)
la = label
# 使用shuffle_batch可以随机打乱
image_batch, label_batch = tf.train.shuffle_batch(
    [image, label], batch_size=BATCH_SIZE,
    capacity=50000, min_after_dequeue=10000, num_threads=1)

# 定义网络结构
train_network_fn = nets_factory.get_network_fn(
    'alexnet_v2',
    num_classes=CHAR_SET_LEN,
    weight_decay=0.0005,
    is_training=True)

with tf.Session() as sess:
    # inputs: a tensor of size [batch_size, height, width, channels]
    X = tf.reshape(x, [BATCH_SIZE, 224, 224, 1])
    # 数据输入网络得到输出值
    logits, end_points = train_network_fn(X)
    #print(logits)
    #print(label_batch)

    qian = y / 1000 % 10
    print(qian)
    bai = y / 100 % 10
    shi = y / 10 % 10
    ge = y % 10
    one_hot_labels0 = tf.one_hot(indices=tf.cast(qian, tf.int32), depth=10)
    one_hot_labels1 = tf.one_hot(indices=tf.cast(bai, tf.int32), depth=10)
    one_hot_labels2 = tf.one_hot(indices=tf.cast(shi, tf.int32), depth=10)
    one_hot_labels3 = tf.one_hot(indices=tf.cast(ge, tf.int32), depth=10)
    label_40 = tf.concat([one_hot_labels0, one_hot_labels1, one_hot_labels2, one_hot_labels3], axis=1)

    #label_40 = tf.reshape(label_40,[50,40])
    # 计算loss
    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=label_40))
    # 优化loss
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss)
    # 计算准确率
    correct_prediction = tf.equal(tf.argmax(label_40, 1), tf.argmax(logits, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


    # 用于保存模型
    saver = tf.train.Saver()
    # 初始化
    sess.run(tf.global_variables_initializer())

    # 创建一个协调器,管理线程
    coord = tf.train.Coordinator()
    # 启动QueueRunner, 此时文件名队列已经进队
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(13001):
        # 获取一个批次的数据和标签
        b_image, b_label = sess.run([image_batch, label_batch])
        # 优化模型
        sess.run(optimizer, feed_dict={x: b_image, y: b_label})
        #print(sess.run(label_40,feed_dict={x: b_image, y: b_label}))
        #a = tf.concat([tf.argmax(one_hot_labels0, 1), tf.argmax(one_hot_labels1, 1), tf.argmax(one_hot_labels2, 1),
                       #tf.argmax(one_hot_labels3, 1)], 0)
        # imge = Image.fromarray(img, 'L')
        # imge.save('./' + '_''Label_' + str(a) + '.jpg')
        # 每迭代20次计算一次loss和准确率
        if i % 20 == 0:
            # 每迭代2000次降低一次学习率
            if i % 3000 == 0:
                sess.run(tf.assign(lr, lr / 3))
            acc0,loss_ = sess.run([accuracy,loss],feed_dict={x: b_image,y: b_label})
            learning_rate = sess.run(lr)
            print("Iter:%d  Loss:%.3f  Accuracy:%.2f  Learning_rate:%.4f" %(i,loss_,acc0,learning_rate))
            if i == 13000:
                saver.save(sess, "D:/验证码识别/使用一个标签来训练/", global_step=i)
                break
                # 通知其他线程关闭
    coord.request_stop()
    # 其他所有线程关闭之后,这一函数才能返回
    coord.join(threads)




训练时间很长,我这里已经训练好了,有需要的可以下载:

另外还多了模型文件

接下来我们需要验证我们的网络准不准确。请看下一篇文章验证码识别4——验证部分
 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值