1. Import libraries
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2'
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import models, layers, optimizers
2. Load the data
(x_train_all, y_train_all), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train_all, x_test = x_train_all / 255.0, x_test / 255.0
# Use small slices of the (already normalized) training images to keep the demo fast:
# 4,000 samples for training and the next 1,000 as a held-out test set.
x_train, x_test = x_train_all[:4000], x_train_all[4000:5000]
y_train, y_test = y_train_all[:4000], y_train_all[4000:5000]
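As a quick sanity check, the sliced arrays should have the following shapes (28x28 grayscale images with integer labels):

print(x_train.shape, y_train.shape)  # (4000, 28, 28) (4000,)
print(x_test.shape, y_test.shape)    # (1000, 28, 28) (1000,)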
3. Define the MogrifierLSTMCell
Before the standard LSTM update, the Mogrifier LSTM (Melis et al., 2019) lets the input and the previous hidden state gate each other for a few rounds: odd rounds update the input as x ← 2·sigmoid(Q·h) ⊙ x, and even rounds update the hidden state as h ← 2·sigmoid(R·x) ⊙ h. The MogrifierCell below subclasses keras.layers.LSTMCell and applies this mogrification before delegating to the parent cell.
def weight(units, factorize_k=None):
    # Full-rank projection by default; with factorize_k, use two low-rank
    # projections (units -> factorize_k -> units) to reduce the parameter count.
    if factorize_k is None:
        return keras.layers.Dense(units, activation=None, use_bias=False)
    assert factorize_k < units
    return keras.models.Sequential([
        keras.layers.Dense(factorize_k, activation=None, use_bias=False),
        keras.layers.Dense(units, activation=None, use_bias=False)
    ])
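For illustration, the factorized projection is noticeably smaller than the full-rank one; the sizes below are made up for this example and are not used by the model:

w_full = weight(256)                  # one 256x256 projection
w_low = weight(256, factorize_k=32)   # 256 -> 32 -> 256
_ = w_full(tf.zeros((1, 256)))        # call once so the weights are built
_ = w_low(tf.zeros((1, 256)))
print(w_full.count_params())          # 65536 = 256*256
print(w_low.count_params())           # 16384 = 256*32 + 32*256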
class MogrifierCell(keras.layers.LSTMCell):
    def __init__(self, units, iters, factorize_k=None, **kwargs):
        super().__init__(units, **kwargs)
        self.iters = iters
        self.factorize_k = factorize_k
        # Q gates the input with the hidden state; R gates the hidden state
        # with the input. R is only needed when there is more than one round.
        self.Q = weight(units, factorize_k)
        self.R = weight(units, factorize_k) if iters > 1 else None

    def call(self, input_at_t, state_at_t, **kwargs):
        if input_at_t is not None and state_at_t is not None:
            # Note: the mogrification assumes the input dimension equals `units`.
            *_, units = input_at_t.shape
            h = state_at_t[0]
            c = state_at_t[1]
            x = tf.reshape(input_at_t, (-1, units))
            h = tf.reshape(h, (-1, units))
            # Alternate the gating before the standard LSTM update:
            # odd rounds update x with Q(h), even rounds update h with R(x).
            for ind in range(self.iters):
                if (ind % 2) == 0:
                    x = 2 * tf.sigmoid(self.Q(h)) * x
                else:
                    h = 2 * tf.sigmoid(self.R(x)) * h
            input_at_t = tf.reshape(x, [-1, units])
            state_at_t = (tf.reshape(h, [-1, units]), c)
        return super().call(inputs=input_at_t, states=state_at_t)
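To make the alternating update concrete, here is a minimal standalone sketch of the mogrification step on dummy tensors; the names q, r, rounds, x_demo, and h_demo are illustrative and not part of the model above:

q = keras.layers.Dense(8, use_bias=False)   # plays the role of Q
r = keras.layers.Dense(8, use_bias=False)   # plays the role of R

def mogrify(x, h, rounds=2):
    # Odd rounds gate the input x with h; even rounds gate the hidden state h with x.
    for ind in range(rounds):
        if ind % 2 == 0:
            x = 2 * tf.sigmoid(q(h)) * x
        else:
            h = 2 * tf.sigmoid(r(x)) * h
    return x, h

x_demo = tf.random.normal((4, 8))           # dummy batch of inputs
h_demo = tf.random.normal((4, 8))           # dummy previous hidden state
x_new, h_new = mogrify(x_demo, h_demo)
print(x_new.shape, h_new.shape)             # (4, 8) (4, 8)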
4. Build the model
# Each 28x28 image is fed to the recurrent stack as a sequence of 28 time
# steps with 28 features per step.
inputs = layers.Input(shape=x_train.shape[1:], name='inputs')
lstm1_out = layers.LSTM(units=256, return_sequences=True, name='lstm_1')(inputs)
Mogrifier = layers.RNN(MogrifierCell(256, 2), return_sequences=True, name='mogrifier_lstm')(lstm1_out)
lstm2_out = layers.LSTM(units=128, return_sequences=False, name='lstm_2')(Mogrifier)
outputs = layers.Dense(10, activation='softmax', name='dense')(lstm2_out)
lstm = keras.Model(inputs, outputs)
lstm.summary()
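The functional API reports how the tensor shape changes through these layers, which can be printed directly as a quick check:

print(inputs.shape)     # (None, 28, 28)
print(lstm1_out.shape)  # (None, 28, 256)
print(Mogrifier.shape)  # (None, 28, 256)
print(lstm2_out.shape)  # (None, 128), only the last time step is kept
print(outputs.shape)    # (None, 10), class probabilities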
5. Train and evaluate
lstm.compile(optimizer=keras.optimizers.Adam(0.001),
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])
# A single epoch is enough to demonstrate the pipeline; increase epochs for better accuracy.
history = lstm.fit(x_train, y_train, batch_size=32, epochs=1, validation_split=0.1)
lstm.evaluate(x_test, y_test, verbose=2)
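After training, individual predictions can be inspected on a few test samples; this is a minimal usage sketch, and the slice of five samples is arbitrary:

import numpy as np
probs = lstm.predict(x_test[:5])   # (5, 10) array of class probabilities
print(np.argmax(probs, axis=1))    # predicted digits
print(y_test[:5])                  # ground-truth digits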
Full code
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2'
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import models, layers, optimizers
# Load the dataset and apply simple normalization
(x_train_all, y_train_all), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train_all, x_test = x_train_all / 255.0, x_test / 255.0
# Use small slices of the (already normalized) training images to keep the demo fast:
# 4,000 samples for training and the next 1,000 as a held-out test set.
x_train, x_test = x_train_all[:4000], x_train_all[4000:5000]
y_train, y_test = y_train_all[:4000], y_train_all[4000:5000]
print(x_train.shape)
# Weight initialization
def weight(units, factorize_k=None):
    # Full-rank projection by default; with factorize_k, use two low-rank
    # projections (units -> factorize_k -> units) to reduce the parameter count.
    if factorize_k is None:
        return keras.layers.Dense(units, activation=None, use_bias=False)
    assert factorize_k < units
    return keras.models.Sequential([
        keras.layers.Dense(factorize_k, activation=None, use_bias=False),
        keras.layers.Dense(units, activation=None, use_bias=False)
    ])
# Define the MogrifierCell
class MogrifierCell(keras.layers.LSTMCell):
    def __init__(self, units, iters, factorize_k=None, **kwargs):
        super().__init__(units, **kwargs)
        self.iters = iters
        self.factorize_k = factorize_k
        # Q gates the input with the hidden state; R gates the hidden state
        # with the input. R is only needed when there is more than one round.
        self.Q = weight(units, factorize_k)
        self.R = weight(units, factorize_k) if iters > 1 else None

    def call(self, input_at_t, state_at_t, **kwargs):
        if input_at_t is not None and state_at_t is not None:
            # Note: the mogrification assumes the input dimension equals `units`.
            *_, units = input_at_t.shape
            h = state_at_t[0]
            c = state_at_t[1]
            x = tf.reshape(input_at_t, (-1, units))
            h = tf.reshape(h, (-1, units))
            # Alternate the gating before the standard LSTM update:
            # odd rounds update x with Q(h), even rounds update h with R(x).
            for ind in range(self.iters):
                if (ind % 2) == 0:
                    x = 2 * tf.sigmoid(self.Q(h)) * x
                else:
                    h = 2 * tf.sigmoid(self.R(x)) * h
            input_at_t = tf.reshape(x, [-1, units])
            state_at_t = (tf.reshape(h, [-1, units]), c)
        return super().call(inputs=input_at_t, states=state_at_t)
# Build the model
# Each 28x28 image is fed to the recurrent stack as a sequence of 28 time
# steps with 28 features per step.
inputs = layers.Input(shape=x_train.shape[1:], name='inputs')
lstm1_out = layers.LSTM(units=256, return_sequences=True, name='lstm_1')(inputs)
Mogrifier = layers.RNN(MogrifierCell(256, 2), return_sequences=True, name='mogrifier_lstm')(lstm1_out)
lstm2_out = layers.LSTM(units=128, return_sequences=False, name='lstm_2')(Mogrifier)
outputs = layers.Dense(10, activation='softmax', name='dense')(lstm2_out)
lstm = keras.Model(inputs, outputs)
lstm.summary()
# Train the model
lstm.compile(optimizer=keras.optimizers.Adam(0.001),
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])
history = lstm.fit(x_train, y_train, batch_size=32, epochs=1, validation_split=0.1)
# Evaluate the model
lstm.evaluate(x_test, y_test, verbose=2)
References
https://github.com/drk-knght/Mogrifier-LSTM