曹健老师 TensorFlow2.1 —— 第四章 网络八股扩展

第一章

第二章

第三章

本章目的:扩展六步法功能,并实现应用. 

4.1 搭建网络八股总览

  • 利用自制数据集,解决本领域应用
  • 利用数据增强,解决数据量过少问题,扩展数据,提高泛化力
  • 利用断点续训,实时保存最优模型
  • 利用参数提取,可获取各层网络最优的参数,在任何平台实现前向推理复现模型,实现应用
  • 利用 acc 和 loss 曲线可视化,查看训练效果
  • 利用识物应用程序,输入神经网络一组新的、从未见过的特征,神经网络输出预测结果,实现学以致用

4.2 自制数据集

把 .data_load() 用自己的定义的函数替换掉

def generateds(图片路径, 标签文件):
"""
自制数据集 txt 文件中,value[0]对应特征,索引到每张图片,value[1]对应标签
把图片灰度值数据拼接到图片列表,标签数据拼接到标签列表,顺序一致即可.
"""
def generateds(path, txt):
    f = open(txt, 'r')
    contents = f.readlines()
    f.close()
    x, y_ = [], []
    for content in contents:
        value = content.split()     # 以空格分开,图片路径为value[0],标签文件为value[1],存入列表
        img_path = path + value[0]  # 图片路径 + 图片名 拼接出图片的索引路径
        img = Image.open(img_path)  # 读入图片
        img = np.array(img.convert('L'))    # 图片变为 8 位宽度的灰度值 np.array 格式
        img = img / 255             # 数据归一化
        x.append(img)
        y_.append(value[1])
        print("Loading : " + content)

    x = np.array(x)
    y_ = np.array(y_)
    y_ = y_.astype(np.int64)
    return x, y_
    # 返回输入特征和标签

4.3 数据增强

对图像的增强就是对图像的简单形变,用来应对因拍照角度不同引起的图片变形.

image_gen_train = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale = 所有数据将乘上该数值,
    rotation_range = 随机旋转角度数范围,
    width_shift_range = 随机宽度偏移量,
    height_shift_range = 随机高度偏移量,
    水平翻转: horizontal_flip = 是否随机水平翻转,
    随机缩放: zoom_range = 随机缩放的范围 [1-n, 1+n]
)
image_gen_train.fit(x_train)    # x_train 应为四维数据,第四维是 RGB 通道

4.4 断点续训

在进行神经网络训练过程中由于一些因素导致训练无法进行,需要保存当前的训练结果,下次可以接着训练.

  • 读取模型
""" load_weights(路径文件名) """
checkpoint_save_path = "./checkpoint/mnist.ckpt"    # 存放模型的路径和文件名,命名为 ckpt 文件
# 生成 ckpt 文件时,会同步生成索引表
# 因此,可以通过判断索引表存在与否,去判断模型是否保存完全
if os.path.exists(checkpoint_save_path + ".index"):
    model.load_weight(checkpoint_save_path)
  • 保存模型
tf.keras.callbacks.ModelCheckpoint(
    filepath = 路径文件名,
    save_weights_only = True / False,    # 是否只保留模型参数
    save_best_only = True / False        # 是否只保留最优结果
)
history = model.fit(callbacks = [cp_callback])
# history 中存储了 loss 和 metrics 的结果,用于可视化

4.5 参数提取

  • 提取模型中可训练参数
model.trainable_variables # 返回可训练参数
  • print 函数可直接将上一步提取的参数打印出来,不过其中会有很多数据被省略号替换掉,通过设置 print 函数的打印效果
“”“ np.set_printoptions(threshold=超过多少省略显示) ”“”
np.set_printoptions(threshold=np.inf)    # np.inf 表示正无穷
  • 可以利用 for 循环,把所有可训练参数存入文本
print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

4.6 acc & loss 可视化

  • 在 model.fit 执行训练过程时,同步记录了训练集 loss 、测试集 loss 、训练集 acc 和测试集 acc .可以用 history.history() 提取出.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
  • 提取后进行可视化
plt.subplot(1, 2, 1)        # subplot 将图像分为一行两列,现在画出第一列
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)        # 画出第二列
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

4.7 应用程序 ---- 给图识物

希望得到的模型可以识别手写数字图片. 利用向前传播执行应用:

“”“ predict(输入特征, batch_size=整数) ”“”
# 根据输入特征,返回向前传播计算结果

具体步骤为:

( 1 ) 复现模型 ( 前向传播 ) 

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu')
    tf.keras.layers.Dense(10, activation='softmax')
])

( 2 ) 加载参数

model.load_weights(model_save_path)

( 3 ) 输入特征预处理:由于训练模型时所用图片是 28 行 28 列的灰度图,但输入特征为任意尺寸的图片,因此需要先进行预处理.

image_path = input("the path of test picture:")
img = Image.open(image_path)
img = img.resize((28, 28), Image.ANTIALIAS)     # resize 为 28 行 28 列标准尺寸
img_arr = np.array(img.convert('L'))            # 转换为灰度图

“”“ 方法一 ”“”
# 训练时使用的是黑底白字,输入特征为白底黑自字,因此灰度值取反
img_arr = 255 - img_arr     

“”“ 方法二 ”“”
# 对于手写数字识别,还可以让输入图片变为只有黑色和白色的高对比度图片
# 使用嵌套 for 循环,遍历输入图像的每一个像素点,
for i in range(28):
    for j in range(28):
        if img_arr[i][j] < 200:
            img_arr[i][j] = 255    # 灰度值小于 200 的变为纯白色
        else:
            img_arr[i][j] = 0      # 其余变为纯黑色
# 该方法在保留图片信息的同时,滤去了背景噪声
# 以上两种方法选一即可

img_arr = img_arr / 255.0     # 归一化

( 4 ) 为了满足神经网络输入特征的 shape ( 图片总数, 宽, 高 ) ( 第一个维度是 batch ), 应为 image 的前面添加一个维度,由 28 行 28 列的二维数据变为一个 28 行 28 列的三维数据.

x_predict = img_arr[tf.newaxis, ...]

( 5 ) 预测结果

result = model.predict(x_predict)
pred = tf.argmax(result, axis=1)    # 输出最大概率值
tf.print(pred)    # 返回预测结果

 

##################### 完整八股  #####################

# 入门级神经网络可利用完善后的八股实现
# 完整的完善后八股
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt

np.set_printoptions(threshold=np.inf)

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

###############################################    show   ###############################################

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)        # subplot 将图像分为一行两列,现在画出第一列
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)        # 画出第二列
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值