利用 Keras 框架（CNN + CTC Loss）识别不定长中英文图片

（第 1 部分：模型定义与训练）

import keras
import tensorflow as tf
from keras.utils import multi_gpu_model
from keras.models import Model
from keras import backend as K
from keras.optimizers import Adam,SGD
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Reshape, Dense,Dropout,Lambda
from keras.layers import Input,Conv2D,BatchNormalization,MaxPooling2D,AveragePooling2D,Activation,Softmax

def ctc_lambda_func(args):
    """Keras ``Lambda`` wrapper that computes the per-sample CTC batch cost.

    Args:
        args: tuple ``(y_pred, y_true, input_length, label_length)`` --
            the softmax output of the network, the padded ground-truth
            label sequences, the length of each prediction along the time
            axis, and the true (unpadded) length of each label.

    Returns:
        Tensor of shape ``(batch, 1)`` holding the CTC loss of each sample.
    """
    y_pred, y_true, input_length, label_length = args
    # NOTE(review): the original code sliced y_pred[:, :, :] (a no-op) next
    # to a comment claiming the first 2 RNN outputs are garbage.  If that is
    # actually the case for this network, slice y_pred[:, 2:, :] and shrink
    # input_length by 2 -- confirm against the feature network before doing so.
    return keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)

def ocr_params():
    """Return the default hyper-parameters for the OCR model.

    Uses the TF1.x ``tf.contrib.training.HParams`` container, consistent
    with the rest of this file.
    """
    # Bug fix: 'tf.contrib.trainging' was a typo and raised AttributeError
    # as soon as this function was called.
    params = tf.contrib.training.HParams(
        vocab_size=66,     # number of output character classes
        lr=1e-3,           # learning rate
        gpu_nums=1,        # number of GPUs (not used yet in this file)
        is_training=True,
    )
    return params

class OcrModel():
    """CNN + CTC model for recognising variable-length Chinese/English text images.

    Typical usage::

        ocr = OcrModel(ocr_params())
        ocr.fetature_net()    # build the inference (feature) model
        ocr.ctc_network()     # wrap it with the CTC-loss training graph
        ocr.optimizer_start() # compile
    """

    def __init__(self, args):
        self.vocab_size = args.vocab_size
        # Bug fix: the builder methods below read these attributes, but the
        # original __init__ never assigned them (AttributeError at runtime).
        self.lr = args.lr
        self.gpu_nums = args.gpu_nums
        self.is_training = args.is_training
        # Maximum padded label length fed to the CTC label input.
        self.label_max_string_length = 8

    def fetature_net(self):
        """Build the convolutional feature extractor and softmax head.

        NOTE(review): the conv stack here reconstructs a body that was
        truncated in the original (which referenced an undefined ``x`` and
        ``m``); confirm layer sizes against any trained weights before use.
        """
        # Bug fix: Input(..., shape(224,224,3)) was missing the '=' and was a call.
        self.inputs = Input(name='the_inputs', shape=(224, 224, 3))
        x = Conv2D(64, (3, 3), padding='same')(self.inputs)
        x = BatchNormalization()(x)
        x = LeakyReLU()(x)
        x = MaxPooling2D((2, 2))(x)
        x = Conv2D(self.vocab_size, (3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU()(x)
        # Bug fix: stray extra ')' in the original AveragePooling2D line.
        x = AveragePooling2D((8, 1), name='avg_pool')(x)
        # Collapse the remaining spatial grid into a (time_steps, vocab) sequence.
        x = Reshape((-1, self.vocab_size))(x)

        self.outputs = Activation('softmax')(x)
        self.model = Model(inputs=self.inputs, outputs=self.outputs)
        self.model.summary()

    def ctc_network(self):
        """Attach CTC inputs and the loss Lambda to form the training model."""
        self.labels = Input(name='the_labels',
                            shape=[self.label_max_string_length], dtype='float32')
        self.label_length = Input(name='label_length', shape=[1], dtype='int64')
        # Bug fix: 'shape[1]' -> 'shape=[1]' and the line was missing its ')'.
        self.input_length = Input(name='input_length', shape=[1], dtype='int64')
        # Bug fixes: '=' instead of '-', and the Lambda layer must be *called*
        # on the tensor list (the original juxtaposed '[...]' with no '(').
        self.loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
            [self.outputs, self.labels, self.input_length, self.label_length])
        # Bug fix: 'sefl.label_length' typo.
        self.ctc_model = Model(
            inputs=[self.inputs, self.labels, self.input_length, self.label_length],
            outputs=self.loss_out)
        self.ctc_model.summary()

    def optimizer_start(self):
        """Compile the CTC model.

        The 'ctc' Lambda layer already outputs the loss value, so the
        compiled loss function simply passes that output through.
        """
        opt = SGD(lr=self.lr, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
        self.ctc_model.compile(loss={'ctc': lambda y_true, output: output},
                               optimizer=opt)



import os
import tensorflow as tf
from utils import *
from keras.callbacks import ModelCheckpoint,Callback,LearningRateScheduler
from keras import backend as K
from model.cnn_ctc import OcrModel,ocr_params

# ---- training script -------------------------------------------------------
data_args = ocr_params()
data_args.vocab_size = 66
data_args.gpu_nums = 1
data_args.lr = 1e-03
data_args.is_training = True

ocr = OcrModel(data_args)
# Bug fix: the model graphs were never built before ctc_model was used below.
ocr.fetature_net()
ocr.ctc_network()
ocr.optimizer_start()

nb_epoch = 100
train_batch_size = 16
val_batch_size = 8

if os.path.exists('logs_am/model.h5'):
    print('load acoustic model...')
    # Bug fix: the model object is named `ocr`; `am` was undefined (leftover
    # from an acoustic-model script this code was adapted from).
    ocr.ctc_model.load_weights('logs_am/model.h5')

# (Removed: `epochs = 10` was dead, and `batch_num` read the undefined
# `train_data.wav_lst` -- another leftover from the speech project.)

# checkpoint -- Bug fix: the filename template used {val_acc} while only
# val_loss is monitored; Keras raises KeyError when formatting the name.
ckpt = "model_{epoch:02d}-{val_loss:.2f}.hdf5"
checkpoint = ModelCheckpoint(os.path.join('./checkpoint', ckpt),
                             monitor='val_loss', save_weights_only=False,
                             verbose=1, save_best_only=True)

batch_val_gen = SampleDataGenerator(val_sample_lst, path, val_batch_size, shuffle=False)
cnt_val_steps = int(np.floor(len(val_sample_lst) / val_batch_size))

batch_train_gen = SampleDataGenerator(train_sample_lst, path, train_batch_size, shuffle=True)
cnt_train_steps = int(np.floor(len(train_sample_lst) / train_batch_size))

# Train; epochs now uses nb_epoch instead of a second hard-coded 100.
History = ocr.ctc_model.fit_generator(
    generator=batch_train_gen, steps_per_epoch=cnt_train_steps,
    epochs=nb_epoch, callbacks=[checkpoint],
    workers=1, use_multiprocessing=False,
    validation_data=batch_val_gen, validation_steps=cnt_val_steps)
ocr.ctc_model.save_weights('logs_am/model.h5')


from keras.preprocessing.image import img_to_array
from keras.utils import Sequence
from PIL import Image
import difflib
import cv2
import matplotlib.pyplot as plt

def get_img_data(file, w, h):
    """Load an image, resize it to (w, h) and scale pixels to [0, 1].

    Args:
        file: path of the image file readable by OpenCV.
        w: target width in pixels.
        h: target height in pixels.

    Returns:
        Float array of shape (h, w, 3) with values in [0, 1].
    """
    data = cv2.imread(file)
    # Bug fix: the original line was missing its closing parenthesis.
    data = cv2.resize(data, (w, h))
    data = img_to_array(data)
    data /= 255.0
    return data

class SampleDataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of (image, label) data for CTC training.

    Each element of ``list_IDs`` is a tab-separated line of the form
    ``"<image file>\\t<comma-separated label chars>\\t<...>"``.

    The batch output matches the CTC model's four inputs:
    ``[X, Y, input_length, label_length]`` plus a dummy target array
    (the 'ctc' Lambda layer computes the actual loss).
    """

    def __init__(self, list_IDs, data_path, batch_size=1, shuffle=False):
        self.sample_list = list_IDs        # file-name + label lines to draw from
        self.data_path = data_path         # directory prefix for image files
        self.batch_size = batch_size
        # Bug fix: `shuffle` was read via self.shuffle but never stored.
        self.shuffle = shuffle
        # Bug fix: np.arrange -> np.arange (AttributeError).
        self.indexes = np.arange(len(self.sample_list))
        if self.shuffle:
            # Use numpy's shuffler; the original called random.shuffle but
            # `random` is not imported in this file.
            np.random.shuffle(self.sample_list)
        # char2num_dict comes from `utils` (star-imported): char -> class index.
        self.char2num_dict = char2num_dict

    def __len__(self):
        """Number of batches per epoch."""
        # Bug fix: bare `ceil` was undefined; use np.ceil.
        return int(np.ceil(len(self.sample_list) / self.batch_size))

    def __getitem__(self, index):
        """Return the ``index``-th processed batch."""
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        batch_datas = [self.sample_list[k] for k in indexes]
        X, Y = self.__data_generation(batch_datas)
        return X, Y

    def on_epoch_end(self):
        """Reshuffle the samples between epochs when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.sample_list)

    def __data_generation(self, batch_data):
        """Build one batch of images, numeric labels and CTC length arrays."""
        H, W = 64, 160
        X = np.zeros((self.batch_size, H, W, 3), dtype=np.float32)
        Y = np.zeros((self.batch_size, 8), dtype=np.uint8)

        # Per-sample network output length and true label length for CTC.
        # NOTE(review): 20 must equal the model's output time steps -- confirm.
        input_length = np.ones(self.batch_size) * 20
        label_length = np.ones(self.batch_size) * 8

        image_lst = []
        label_lst = []
        for line in batch_data:
            # Bug fix: 'img_flie' typo made img_file undefined below.
            img_file, label, end = line.split('\t')
            image_lst.append(img_file)
            # Bug fix: 'lable_lst' typo (NameError).
            label_lst.append(label.split(','))

        for index in range(self.batch_size):
            img_data = get_img_data(self.data_path + image_lst[index], W, H)
            str_label = label_lst[index]
            # Bug fix: 'str_lable' typo.  Pad short labels with '_' chars
            # (list += str appends one character per iteration).
            while len(str_label) < 8:
                str_label += '_'
            num_label = [self.char2num_dict[ch] for ch in str_label]
            X[index] = img_data
            Y[index] = num_label

        # Bug fix: the return was mis-indented outside the method body.
        return [X, Y, input_length, label_length], np.ones(self.batch_size)

# 定义解码器------------------------------------
# 定义解码器------------------------------------
def decode_ctc(num_result):
    """Greedy CTC-decode one softmax sequence into a vector of class indices.

    Args:
        num_result: array of shape (1, time_steps, vocab_size) -- the
            softmax output of the feature model for a single sample.

    Returns:
        1-D array of decoded class indices for that sample.
    """
    # (Cleanup: dropped the original's no-op slice num_result[:, :, :].)
    in_len = np.zeros((1), dtype=np.int32)
    in_len[0] = num_result.shape[1]
    decoded = K.ctc_decode(num_result, in_len, greedy=True, beam_width=5, top_paths=1)
    # ctc_decode returns ([paths], log_probs); take the best path's first row.
    return K.get_value(decoded[0][0])[0]



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值