构建神经网络进行时装模型训练与预测
class SinleNN(object):
# 建立神经网络模型
model = keras.Sequential([
# 将输入数据的形状进行修改为扁平化
keras.layers.Flatten(input_shape=[28, 28]),
# 定义隐藏层,128个神经元的网络层
keras.layers.Dense(128, activation=tf.nn.relu),
# 10个类别的分类问题,输出神经元的个数和输出类别个数相同
keras.layers.Dense(10, activation=tf.nn.softmax)
])
def __init__(self):
# 返回两个元组 (50000, 784) (50000, 1)
(self.x_train, self.y_train), (self.x_test, self.y_test) = fashion_mnist.load_data()
# 进行数据的归一化操作 必须进行归一化,不然无法拟合
self.x_train = self.x_train / 255.0
self.x_test = self.x_test / 255.0
def singlenn_compile(self):
"""
编译模型优化器 使用的损失是交叉熵损失
sparse_categorical_crossentropy 自动进行one-hot编码后求交叉熵
:return:
"""
SinleNN.model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.sparse_categorical_crossentropy, metrics=['accuracy'])
return None
def sinlenn_fit(self):
"""
进行fit训练 tensorboard 和 保存模型都在callback中
:return:
"""
# fit当中添加回调函数,记录训练模型过程
# modelcheck = keras.callbacks.ModelCheckpoint(
# filepath="./ckpt/singlenn_{epoch:02d}-{val_acc:.2f}.h5",
# monitor="val_acc", # 指定保存损失还是准确率
# save_best_only=True, # 指定保存的这一次比上一次好
# save_weights_only=True, # 保存模型参数
# mode="auto",
# period=1 # 每次迭代保存一次
# )
# 调用tensorboard回调函数
board = keras.callbacks.TensorBoard(
log_dir="./graph/",
write_graph=True,
)
# 训练样本的特征值和目标值 epochs迭代的次数 batch_size 一次所选的个数
SinleNN.model.fit(self.x_train, self.y_train, epochs=5, batch_size=64, callbacks=[board])
return None
def single_evaluate(self):
# 评估模型测试效果
test_loss, test_acc = SinleNN.model.evaluate(self.x_test, self.y_test)
print(test_loss, test_acc)
return None
def single_predict(self):
"""
加载模型 预测结果
:return:
"""
# 首先加载模型
if os.path.exists("./ckpt/*.h5"):
SinleNN.model.load_weights("./ckpt/SingleNN.h5")
predictions = SinleNN.model.predict(self.x_test)
return predictions
if name == ‘main’:
snn = SinleNN()
snn.singlenn_compile()
snn.sinlenn_fit()
snn.single_evaluate()
# 保存模型 使用ckpt进行保存 只需目录加名字
# SinleNN.model.save_weights("./ckpt/SingleNN.h5")
# 进行模型预测
# predictions = snn.single_predict()
# 对第二个维度 取最大值
# print(np.argmax(predictions, axis=1))
# 保存成h5模型,读取和加载速度比较快,一个文件就够