Tensorflow搭建神经网络(保存模型、断点续训、手写数字实战)

网络八股扩展

总览:

  1. 自制数据集,解决本领域问题
  2. 数据增强,扩充数据集
  3. 断点续训,存取模型
  4. 参数提取,把参数存入文本
  5. acc/locc可视化,查看训练效果
  6. 应用程序,给图识物

自制数据集

自制数据集首先要自定义一个generateds函数,通过传入图片路径和标签文件来返回图片矩阵和对应的标签

标签文件长这样

mnist_train_jpg_xxxxx.txt:
	value[0]			value[1]
  0_5.jpg					5
  1_0.jpg					0
  2_4.jpg					4
  3_1.jpg					1
  4_9.jpg					9
def generateds(path, txt):
    f = open(txt, "r")
    contents = f.readlines()  # 读取文件的所有行
    f.close()
    x, y_ = [], []  # 建立空列表
    for content in contents:  # 逐行读出
        value = content.split()  # 以空格为分隔符
        img_path = path + value[0]
        img = Image.open(img_path)
        img = np.array(img.convert('L'))  # 转化为八位宽灰度值的np.array格式
        img = img / 255  # 数据归一化
        x.append(img)  # 将新数据粘贴到列表
        y_.append(value[1])
        print('Loading : ' + content)
    x = np.array(x)
    y_ = np.array(y_)
    y_ = y_.astype(np.int64)
    return x, y_

数据增强(增大数据量)

避免因拍照角度不同引起的错误

image_gen_train = tf.keras.preprocessing.image.ImageDataGenerator (
    rescale=所有数据将乘以该数值
    rotation_range=随机旋转角度数范围
    width_shift_range=随机宽度偏移量
    height_shift_range=随机高度偏移量
    horizontal_flip=是否随机水平翻转
    zoom_range=缩放的范围[1-n, 1+n]
)
image_gen_train.fit(x_train)

举个例子

image_gen_train = tf.keras.preprocessing.image.ImageDataGenerator (
    rescale=1. / 1.,
    rotation_range=45,
    width_shift_range=.15,
    height_shift_range=.15,
    horizontal_flip=False,
    zoom_range=0.5,
)
image_gen_train.fit(x_train)

需要注意的是这里的x_train需要的是一个四维数据,所以要对x_train进行reshape

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
# (60000, 28, 28)		=>		(60000, 28, 28, 1)

与此同时model.fit也需要进行改动

model.fit(x_train, y_train, batch_size=32, ...)
# 改为
model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), ...)

断点续训

读取保存模型
load_weights(路径文件名)

# Example
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'): #  模型在保存时会同步生成index文件,于是通过index文件的存在与否就可以判断模型是否保存过
  print('---------------load the model---------------')
  model.load_weights(checkpoint_save_path)
保存模型
tf.keras.callbacks.ModelCheckpoint(
	filepath=路径文件名,
  save_weights_only=True/False, #  是否进保存权值
  save_best_only=True/False  # 是否仅保存当前最优模型
)
history = model.fit(callbacks=[cp_callback])

# Example
cp_callback = tf.keras.callback.ModelCheckpoint(filepath=checkpoint_save_path,
                                  save_weights_only=True,
                                  save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5,
                   	validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

参数提取

# 提取可训练参数
model.trainable_variables # 返回模型中可训练的参数

# 设置print输出格式
np.set_printoptions(threshold=超过多少省略显示)

np.set_printoptions(threshold=np.inf) # np.inf表示无限大

# Example
print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
  file.write(str(v.name)+'\n')
  file.write(str(v.shape)+'\n')
  file.write(str(v.numpy())+'\n')
file.close()

acc/loss可视化

history = model.fit(训练集数据, 训练集标签, batch_size=, eopchs=,
                    validation_split=用作测试数据的比例, calidation_data=测试集,
                    validation_freq=测试频率)

model.fit执行过程中,其返回值history同步记录了训练集loss、测试集loss、训练集准确率和测试集准确率

history:

  • 训练集loss: loss
  • 测试集loss: val_loss
  • 训练集准确率: sparse_categorical_accuracy
  • 测试集准确率: val_sparse_categorical_accuracy
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
画图
plt.subplot(1, 2, 1) # 将图像分为一行两列, 这句话画出第一列
plt.plot(acc, label="Training Accuracy")
plt.plot(val_acc, label="Validation Accuracy")
plt.title("Training and Validation Accuracy")
plt.legend()

plt.subplot(1, 2, 2) # 画出第二列
plt.plot(loss, label="Training Loss")
plt.plot(val_loss, label="Validation Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.show()

应用程序!给图识物

希望输入一张手写数字图片,输出识别结果

前向传播执行应用

predict(输入特征, batch_size=整数)
# 返回前向传播计算结果

# 复现模型
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10, activation'softmax')
])
# 加载参数
model.load_weights(model_save_path)
# 预测结果
result = model.predict(x_predict)
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值