[深度学习从入门到女装]keras实战-LSTM(mnist)

本文使用 LSTM 对 MNIST 手写数字进行分类测试:每张 28×28 的图像被看作一个长度为 28 的序列(28 个时间步,每个时间步 28 个像素特征)。

 

首先搭建 LSTM 模型(即 rnn_model 模块中的 lstm_model 函数):

def lstm_model(input_shape, n_labels, n_units):
    """Build and compile a stacked two-layer LSTM classifier.

    The first LSTM returns its whole output sequence so the second LSTM
    receives sequence input; the second LSTM returns only its final output,
    which a softmax ``Dense`` layer maps to class probabilities. The tensor
    shape after the input and after every layer is printed for inspection.

    :param input_shape: Per-sample input shape without the batch axis,
        e.g. ``[timesteps, features]``.
    :param n_labels: Number of output classes (units of the softmax layer).
    :param n_units: Number of hidden units in each LSTM layer.
    :return: A Keras ``Model`` compiled with the Adam optimizer,
        categorical cross-entropy loss and an accuracy metric.
    """
    inputs = Input(input_shape)
    print(str(inputs.get_shape()))

    # Layers are created and applied in the same order as a hand-written chain.
    stack = (
        lambda: LSTM(n_units, return_sequences=True),  # emit every time step
        lambda: LSTM(n_units),                         # emit only the last time step
        lambda: Dense(n_labels, activation='softmax'),
    )
    tensor = inputs
    for make_layer in stack:
        tensor = make_layer()(tensor)
        print(str(tensor.get_shape()))

    model = Model(inputs=inputs, outputs=tensor)
    model.compile(optimizer=Adam(), loss=categorical_crossentropy, metrics=['accuracy'])
    return model

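return_sequences 参数的含义:为 True 时,LSTM 返回每个时间步(time step)的输出,形状为 (batch, timesteps, units);为 False(默认值)时,只返回最后一个时间步的输出,形状为 (batch, units)。因此第一层 LSTM 设为 True,第二层 LSTM 才能接收到完整的序列作为输入;第二层使用默认的 False,把最后一个时间步的输出交给 Dense 分类层。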

 

from tensorflow.examples.tutorials.mnist import input_data
import train
import numpy as np
import keras
import rnn_model


def main(argv=None):
    """Train the LSTM classifier on MNIST.

    Builds a ``[28, 28]``-input, 10-class LSTM with 50 units and trains it for
    20 epochs of 10 steps, validating on 10 batches per epoch, with batches
    of 50 images. ``"model_file"`` is passed to ``train.train_model`` as the
    model file path.

    :param argv: Unused.
    """
    mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)
    model = rnn_model.lstm_model([28, 28], 10, 50)

    batch_size = 50
    training_batches = train.train_generator_data_2d(mnist, batch_size)
    validation_batches = train.val_generator_data_2d(mnist, batch_size)

    train.train_model(
        model,
        "model_file",
        training_batches,
        validation_batches,
        steps_per_epoch=10,
        validation_steps=10,
        n_epochs=20,
    )


def test():
    """Evaluate the saved model on a single batch of 50 MNIST test images.

    Loads the model from ``"model_file"`` and prints, in order: the values
    returned by ``model.evaluate``, the raw predictions from
    ``model.predict``, and the predicted class index (argmax over axis 1)
    for each image.
    """
    batch = 50
    mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)
    model = train.load_old_model("model_file")

    images, labels = mnist.test.next_batch(batch)
    images = np.reshape(images, [batch, 28, 28])  # one 28x28 sequence per image
    labels = np.reshape(labels, [batch, 10])

    scores = model.evaluate(images, labels)
    predictions = model.predict(images)

    print(scores)
    print(predictions)
    print(np.argmax(predictions, 1))



if __name__ == '__main__':
    # Only evaluation runs here; call main() instead to train the model.
    test()

 

train.py 与之前文章中使用的版本通用,代码如下:

from functools import partial
import math
from keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler, ReduceLROnPlateau, EarlyStopping, \
    TensorBoard
from keras.models import load_model
from keras.losses import categorical_crossentropy
import numpy as np



def train_generator_data_2d(mnist, batch_size):
    """Yield MNIST training batches forever, reshaped for a sequence model.

    :param mnist: Object whose ``train`` attribute provides
        ``next_batch(batch_size)`` returning ``(images, labels)``.
    :param batch_size: Number of samples per yielded batch.
    :yield: ``(images, labels)`` with ``images`` reshaped to
        ``(batch_size, 28, 28)`` and ``labels`` to ``(batch_size, 10)``.
    """
    image_shape = [batch_size, 28, 28]
    label_shape = [batch_size, 10]
    while True:
        images, labels = mnist.train.next_batch(batch_size)
        yield (np.reshape(images, image_shape), np.reshape(labels, label_shape))


def val_generator_data_2d(mnist, batch_size):
    """Yield MNIST validation batches forever, reshaped for a sequence model.

    :param mnist: Object whose ``validation`` attribute provides
        ``next_batch(batch_size)`` returning ``(images, labels)``.
    :param batch_size: Number of samples per yielded batch.
    :yield: ``(images, labels)`` with ``images`` reshaped to
        ``(batch_size, 28, 28)`` and ``labels`` to ``(batch_size, 10)``.
    """
    image_shape = [batch_size, 28, 28]
    label_shape = [batch_size, 10]
    while True:
        images, labels = mnist.validation.next_batch(batch_size)
        yield (np.reshape(images, image_shape), np.reshape(labels, label_shape))


def step_decay(epoch, initial_lrate, drop, epochs_drop):
    """Step-decay learning-rate schedule.

    The rate starts at ``initial_lrate`` and is multiplied by ``drop`` each
    time ``(1 + epoch)`` crosses a multiple of ``epochs_drop``::

        lr = initial_lrate * drop ** floor((1 + epoch) / epochs_drop)

    With a 0-based ``epoch`` the first reduction therefore applies at epoch
    index ``epochs_drop - 1``.

    :param epoch: Epoch index.
    :param initial_lrate: Learning rate before any decay.
    :param drop: Multiplicative decay factor applied at each step.
    :param epochs_drop: Number of epochs between consecutive decays.
    :return: The learning rate for ``epoch``.
    """
    # The previous form multiplied by the step count instead of raising
    # ``drop`` to it, which ignored ``drop`` and yielded a rate of 0 early on.
    n_drops = math.floor((1 + epoch) / float(epochs_drop))
    return initial_lrate * drop ** n_drops


def get_callbacks(model_file, initial_learning_rate=0.0001, learning_rate_drop=0.5, learning_rate_epochs=None,
                  learning_rate_patience=50, logging_file="training.log", verbosity=1, early_stopping_patience=None):
    """Assemble the list of Keras callbacks used for training.

    Always included, in this order: a ``ModelCheckpoint`` writing to
    ``model_file`` with ``save_best_only=True``, a ``CSVLogger`` appending to
    ``logging_file``, and a default ``TensorBoard`` callback.

    Learning-rate control is then added as exactly one of:

    * ``LearningRateScheduler`` running ``step_decay`` (start at
      ``initial_learning_rate``, multiply by ``learning_rate_drop`` every
      ``learning_rate_epochs`` epochs) when ``learning_rate_epochs`` is truthy;
    * otherwise ``ReduceLROnPlateau`` with factor ``learning_rate_drop`` and
      patience ``learning_rate_patience``.

    An ``EarlyStopping`` callback with patience ``early_stopping_patience`` is
    appended last when that argument is truthy.

    :return: The list of callback instances.
    """
    # ``model_file`` could also be a per-epoch template such as
    # "weights.{epoch:02d}-{val_loss:.2f}.hdf5".
    callbacks = [
        ModelCheckpoint(model_file, save_best_only=True),
        CSVLogger(logging_file, append=True),
        TensorBoard(),
    ]

    if not learning_rate_epochs:
        callbacks.append(ReduceLROnPlateau(factor=learning_rate_drop,
                                           patience=learning_rate_patience,
                                           verbose=verbosity))
    else:
        schedule = partial(step_decay,
                           initial_lrate=initial_learning_rate,
                           drop=learning_rate_drop,
                           epochs_drop=learning_rate_epochs)
        callbacks.append(LearningRateScheduler(schedule))

    if early_stopping_patience:
        callbacks.append(EarlyStopping(verbose=verbosity, patience=early_stopping_patience))
    return callbacks


def load_old_model(model_file):
    """Load a saved Keras model from ``model_file``.

    If ``keras_contrib`` is installed, its ``InstanceNormalization`` layer is
    registered as a custom object so that models containing it can be
    deserialized.

    :param model_file: Path of the saved model.
    :return: The loaded Keras model.
    :raises ValueError: Re-raised from ``load_model``; when the message
        mentions ``InstanceNormalization`` it is extended with instructions
        for installing keras-contrib.
    """
    print("Loading pre-trained model")
    custom_objects = {}
    try:
        from keras_contrib.layers import InstanceNormalization
    except ImportError:
        pass  # optional: only needed for models that contain this layer
    else:
        # Previously the layer was imported but never handed to load_model,
        # so the import had no effect on deserialization.
        custom_objects["InstanceNormalization"] = InstanceNormalization

    try:
        return load_model(model_file, custom_objects=custom_objects)
    except ValueError as error:
        if 'InstanceNormalization' in str(error):
            raise ValueError(str(error) + "\n\nPlease install keras-contrib to use InstanceNormalization:\n"
                                          "'pip install git+https://www.github.com/keras-team/keras-contrib.git'")
        raise


def train_model(model, model_file, training_generator, validation_generator, steps_per_epoch, validation_steps,
                initial_learning_rate=0.001, learning_rate_drop=0.5, learning_rate_epochs=None, n_epochs=500,
                learning_rate_patience=20, early_stopping_patience=None):
    """Train a Keras model from generators using ``model.fit_generator``.

    :param model: Compiled Keras model to train.
    :param model_file: Path passed to ``get_callbacks``, where the checkpoint
        callback saves the model.
    :param training_generator: Generator yielding training batches.
    :param validation_generator: Generator yielding validation batches.
    :param steps_per_epoch: Number of training batches drawn per epoch.
    :param validation_steps: Number of validation batches drawn per epoch.
    :param initial_learning_rate: Starting learning rate for the step-decay
        schedule (used only when ``learning_rate_epochs`` is set).
    :param learning_rate_drop: Factor by which the learning rate is multiplied
        each time it is reduced.
    :param learning_rate_epochs: If set, reduce the learning rate on a fixed
        schedule every this many epochs; otherwise reduce it on plateau.
    :param n_epochs: Total number of epochs to train.
    :param learning_rate_patience: When ``learning_rate_epochs`` is not set,
        number of epochs without improvement before the learning rate is
        reduced (default 20).
    :param early_stopping_patience: If set, stop training early after this
        many epochs without improvement.
    :return: None.
    """
    training_callbacks = get_callbacks(
        model_file,
        initial_learning_rate=initial_learning_rate,
        learning_rate_drop=learning_rate_drop,
        learning_rate_epochs=learning_rate_epochs,
        learning_rate_patience=learning_rate_patience,
        early_stopping_patience=early_stopping_patience,
    )
    model.fit_generator(
        generator=training_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=n_epochs,
        validation_data=validation_generator,
        validation_steps=validation_steps,
        callbacks=training_callbacks,
    )

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值