tensorflow2.0 中lstm的实现

tensorflow2.0 中lstm的实现

tensorflow2.0 中lstm的实现

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import matplotlib.pyplot as plt
import numpy as np

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf

from tensorflow.keras import layers

设计模型

batch_size = 128   # 大小可变,一次处理多少个样本
# Each MNIST image batch is a tensor of shape (batch_size, 28, 28).
# Each input sequence will be of size (28, 28) (height is treated like time).
input_dim = 28  # 必须和特征的列数目相等

units = 64   #  LSTM 中
output_size = 10  # labels are from 0 to 9

Build the RNN model

def build_model(allow_cudnn_kernel=True):
  # CuDNN is only available at the layer level, and not at the cell level.
  # This means `LSTM(units)` will use the CuDNN kernel,
  # while RNN(LSTMCell(units)) will run on non-CuDNN kernel.
  if allow_cudnn_kernel:
    # The LSTM layer with default options uses CuDNN.
    lstm_layer1 = tf.keras.layers.LSTM(units, input_shape=(None, input_dim))
    #lstm_layer2 = tf.keras.layers.LSTM(lstm_layer1)
    
  else:
    # Wrapping a LSTMCell in a RNN layer will not use CuDNN.
    lstm_layer = tf.keras.layers.RNN(
         tf.keras.layers.LSTMCell(units),
        input_shape=(None, input_dim))
    
  model = tf.keras.models.Sequential([
  
  lstm_layer1,
  #lstm_layer2,
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dense(output_size, activation='softmax')]

)
return model

观察模型结构

model = build_model(allow_cudnn_kernel=True)
model.summary()

载入数据

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
sample, sample_label = x_train[0], y_train[0]

训练

model.compile(loss='sparse_categorical_crossentropy', 
              optimizer='sgd',
              metrics=['accuracy'])

绘图

history = model.fit(x_train, y_train,
          validation_data=(x_test, y_test),
          batch_size=batch_size,
          epochs=5)
  • 3
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值