import tensorflow as tf
from sklearn import datasets
import numpy as np
from matplotlib import pyplot as plt
import os
############################################ Load Data ############################################
# iris dataset: the iris flower dataset, 150 samples, 4 features (sepal length, sepal width, petal length, petal width), 3 classes
iris = datasets.load_iris()
x_train1, y_train1 = iris.data, iris.target
# mnist dataset: handwritten digits, 60000/10000 train/test split, 28x28 grayscale images
(x_train2, y_train2), (x_test2, y_test2) = tf.keras.datasets.mnist.load_data()
x_train2, x_test2 = x_train2 / 255.0, x_test2 / 255.0  # scale pixel values to [0, 1]
# fashion_mnist dataset: images of 70,000 clothing items, 60000/10000 train/test split, 28x28 grayscale images, 10 classes
(x_train3, y_train3), (x_test3, y_test3) = tf.keras.datasets.fashion_mnist.load_data()
x_train3, x_test3 = x_train3 / 255.0, x_test3 / 255.0
# cifar10 dataset: 50000/10000 train/test split, 32x32 RGB color images, 10 classes
# (use new variable names so the fashion_mnist arrays above are not overwritten)
(x_train4, y_train4), (x_test4, y_test4) = tf.keras.datasets.cifar10.load_data()
x_train4, x_test4 = x_train4 / 255.0, x_test4 / 255.0
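# Optional sanity check (added for illustration, not part of the original flow):
# confirm the shapes of the loaded arrays before using them.
print(x_train1.shape, y_train1.shape)  # (150, 4) (150,)
print(x_train2.shape, x_test2.shape)   # (60000, 28, 28) (10000, 28, 28)
print(x_train4.shape, x_test4.shape)   # (50000, 32, 32, 3) (10000, 32, 32, 3)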
############################################ Preprocess Data ############################################
# Visualize the first training image
plt.imshow(x_train2[0], cmap='gray')  # draw as a grayscale image
plt.show()
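# Optional check: print the label that belongs to the image shown above.
print('label of the first training image:', y_train2[0])  # mnist labels are the digits 0-9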
# Shuffle the training data; seeding identically before each shuffle keeps
# features and labels aligned after shuffling.
np.random.seed(116)
np.random.shuffle(x_train1)
np.random.seed(116)
np.random.shuffle(y_train1)
tf.random.set_seed(116)  # fix TensorFlow's own random seed for reproducibility
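# Equivalent sketch (an alternative to re-seeding, kept commented out so it does
# not shuffle the data a second time): draw a single permutation of indices,
# which keeps features and labels aligned by construction.
# idx = np.random.permutation(len(x_train1))
# x_train1, y_train1 = x_train1[idx], y_train1[idx]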
############################################ Build Model ############################################
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
'''
Flatten layer: tf.keras.layers.Flatten()
    Flattens the input features into a one-dimensional array; this layer only reshapes the tensor and has no trainable parameters.
Dense (fully connected) layer: tf.keras.layers.Dense(number of neurons,
                                                     activation="activation function",
                                                     kernel_regularizer=regularization method)
    activation: 'relu', 'softmax', 'sigmoid', 'tanh', etc.
    kernel_regularizer: tf.keras.regularizers.l1()
                        tf.keras.regularizers.l2()
Convolutional layer: tf.keras.layers.Conv2D(filters=number of kernels,
                                            kernel_size=kernel size,
                                            strides=stride,
                                            padding="valid" or "same")
LSTM layer: tf.keras.layers.LSTM()
'''
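# Minimal sketch (illustrative only, not used by the pipeline below): a small CNN
# for the 32x32x3 cifar10 images, combining the Conv2D arguments described above.
cnn_example = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=16, kernel_size=3, strides=1, padding='same',
                           activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPool2D(pool_size=2),  # downsample the feature maps
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='softmax')
])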
############################################ Compile Model ############################################
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
'''
optimizer
    e.g.: "sgd"
          tf.keras.optimizers.SGD(learning_rate=learning rate,
                                  decay=learning-rate decay,
                                  momentum=momentum)
          "adagrad"
          tf.keras.optimizers.Adagrad(learning_rate=learning rate,
                                      decay=learning-rate decay)
          "adadelta"
          tf.keras.optimizers.Adadelta(learning_rate=learning rate,
                                       decay=learning-rate decay)
          "adam"
          tf.keras.optimizers.Adam(learning_rate=learning rate,
                                   decay=learning-rate decay)
loss
    e.g.: "mse"
          tf.keras.losses.MeanSquaredError()
          "sparse_categorical_crossentropy"
          tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
          from_logits=False: the network output is already a probability distribution (e.g. softmax applied);
          from_logits=True: the network outputs raw logits
metrics
    e.g.: "accuracy"
          y_ and y are both plain values, e.g. y_ = [1], y = [1]  # y_ is the ground truth, y the prediction
          "categorical_accuracy"
          y_ and y are given as a one-hot code and a probability distribution, e.g. y_ = [0, 1, 0], y = [0.256, 0.695, 0.048]
          "sparse_categorical_accuracy"
          y_ is given as a plain value and y as a probability distribution, e.g. y_ = [1], y = [0.256, 0.695, 0.048]
'''
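# Equivalent sketch (commented out so it does not override the compile() call above):
# pass explicit objects instead of string aliases to control hyperparameters
# such as the learning rate directly.
# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
#               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
#               metrics=['sparse_categorical_accuracy'])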
############################################ Resume From Checkpoint ############################################
checkpoint_save_path = "./checkpoint/mnist.ckpt"
# TensorFlow stores a checkpoint as several files; the '.index' file signals that one exists
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
history = model.fit(x_train2, y_train2, batch_size=32, epochs=500,
                    validation_data=(x_test2, y_test2),
                    validation_freq=20,  # run validation once every 20 epochs
                    callbacks=[cp_callback])
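# Optional sketch (assumption: early stopping is wanted): stop before all 500
# epochs once validation loss stops improving. Note that EarlyStopping checks
# val_loss each epoch, so it needs validation_freq=1 in model.fit to work as intended.
# es_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5,
#                                                restore_best_weights=True)
# history = model.fit(..., validation_freq=1, callbacks=[cp_callback, es_callback])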
model.summary()
############################################ Extract Parameters ############################################
# Print full arrays instead of truncating long output
np.set_printoptions(threshold=np.inf)
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')
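# Alternative sketch (assumption: a binary dump is acceptable): store the same
# arrays in NumPy's .npz format, which is easier to reload than the text file above.
np.savez('./weights.npz', *[v.numpy() for v in model.trainable_variables])  # saved as arr_0, arr_1, ...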
############################################ Visualization ############################################
# Plot accuracy and loss curves for the training and validation sets.
# Note: validation metrics are recorded only every validation_freq epochs,
# so the validation curves have fewer points than the training curves.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
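# Optional (assumption: a saved copy of the plot is wanted): write the figure to
# disk before plt.show(), since show() may clear the figure in some backends.
plt.savefig('./training_curves.png')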
plt.show()