激活函数有很大的作用,这里推荐tanh,leaky_relu(重要结论)
from tensorflow.keras import layers
from matplotlib import pyplot as pt
from sklearn.metrics import confusion_matrix
import tensorflow as tf #导入tensorflow库
# Load the MNIST dataset through the Keras dataset helper; load_data() yields
# an (images, labels) pair for the training split and another for the test split.
mnist = tf.keras.datasets.mnist
train_pair, test_pair = mnist.load_data()
x_train, y_train = train_pair
x_test, y_test = test_pair

# Preview two raw training images before any preprocessing.
pt.matshow(x_train[1])
pt.matshow(x_train[2])
<matplotlib.image.AxesImage at 0x1e3a0ee11c8>
# Preprocess both splits identically: normalize along axis 1
# (tf.keras.utils.normalize uses the L2 norm by default), flatten every image
# into a single row of 28*28 = 784 features, then convert to a float32 tensor.
x_train = tf.cast(
    tf.keras.utils.normalize(x_train, axis=1).reshape([-1, 28 * 28]),
    dtype=tf.float32,
)
x_test = tf.cast(
    tf.keras.utils.normalize(x_test, axis=1).reshape([-1, 28 * 28]),
    dtype=tf.float32,
)

# Sanity check: show one preprocessed sample together with its label.
print(x_train[1])
print(y_train[1])
tf.Tensor(
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.08216044 0.2286589 0.3728098 0.30506548 0.08583808
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.08087653 0.3834154
0.3624028 0.37133625 0.4835 0.4068725 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.08861609 0.3824786 0.40758026 0.3624028 0.35218
0.44704565 0.43262392 0.06832372 0.00859123 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0.01621743 0.095788 0.36759266
0.4246018 0.40758026 0.3624028 0.2976584 0.16116667 0.43262392
0.3032614 0.17468832 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.26434407 0.4023096 0.41354173 0.4246018 0.40758026
0.3624028 0.37133625 0.18419048 0.32446793 0.3032614 0.23912254
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.08411834 0.38597476
0.40390608 0.41518277 0.32013628 0.18365276 0.36384088 0.33597088
0.09017659 0.13562417 0.30565873 0.2405544 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.07427511 0.39255226 0.40867916 0.4023096 0.2937459
0.02021913 0.12082418 0.17401086 0.03094469 0. 0.
0.3032614 0.34794477 0.12263192 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.04890249 0.2553207
0.41729295 0.37786606 0.33206508 0.13784724 0. 0.
0. 0. 0. 0. 0.3032614 0.36083162
0.40468535 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.00966301 0.22906953 0.38994434 0.39585102 0.11514373
0.03033287 0.04594908 0. 0. 0. 0.
0. 0. 0.3032614 0.36083162 0.4782645 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.07868449
0.32430068 0.38994434 0.10391089 0. 0. 0.
0. 0. 0. 0. 0. 0.
0.3032614 0.36083162 0.4782645 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0.27332506 0.3255876 0.29400566
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.30565873 0.36226347
0.48071715 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0.33960736 0.3395857 0.32430068 0.1733086 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0.3032614 0.36083162 0.3629905 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.379824 0.34786826
0.29598874 0.03868495 0. 0. 0. 0.
0. 0. 0. 0. 0.01343056 0.23176281
0.3032614 0.2663281 0.02943166 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0.379824 0.34786826 0.28698036 0.
0. 0. 0. 0. 0. 0.
0. 0.0103149 0.25134325 0.43262392 0.2696989 0.10166287
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0.379824 0.34786826 0.1866016 0. 0. 0.
0. 0. 0. 0. 0.0690291 0.24313682
0.4835 0.29699975 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.38429254 0.34924868
0.28955418 0. 0. 0. 0. 0.
0. 0.18365276 0.3422693 0.3728098 0.31082144 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0.379824 0.34786826 0.32043996 0.22592013
0.0791702 0.04703054 0.13569966 0.29210487 0.37910873 0.40758026
0.3206977 0.24608393 0.10744444 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0.379824 0.34786826 0.32430068 0.38994434 0.37770784 0.34867468
0.4023096 0.41354173 0.4246018 0.31575388 0.18695381 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.12511851 0.2747055
0.32430068 0.38994434 0.41729295 0.40867916 0.4023096 0.382362
0.24431452 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0.03451074 0.16472416 0.38994434
0.41729295 0.40867916 0.2251018 0.06071843 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ], shape=(784,), dtype=float32)
0
# One-hot encode the integer class labels into vectors of length 10
# (one slot per class), printing one encoded row from each split as a check.
y_train_onehot = tf.one_hot(indices=y_train, depth=10, axis=1)
print(y_train_onehot[1])
y_test_onehot = tf.one_hot(indices=y_test, depth=10, axis=1)
print(y_test_onehot[1])
tf.Tensor([1. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(10,), dtype=float32)
tf.Tensor([0. 0. 1. 0. 0. 0. 0. 0. 0. 0.], shape=(10,), dtype=float32)
# Parameters of a three-layer fully connected network: 784 -> 128 -> 130 -> 10.
# Weights are tf.Variable objects drawn from a truncated normal (stddev 0.2);
# tf.random.truncated_normal already produces float32, so no cast is needed.
# Biases are plain zero tensors (not tf.Variable), i.e. constants.

# Layer 1: 784 input features -> 128 units.
w1 = tf.Variable(tf.random.truncated_normal([784, 128], stddev=0.2))
b1 = tf.zeros([128])

# Layer 2: 128 -> 130 units.
w2 = tf.Variable(tf.random.truncated_normal([128, 130], stddev=0.2))
b2 = tf.zeros([130])

# Layer 3 (output logits): 130 -> 10 classes.
w3 = tf.Variable(tf.random.truncated_normal([130, 10], stddev=0.2))
b3 = tf.zeros([10])
# Full-batch gradient descent: every iteration runs the whole of x_train
# through the network, computes the mean softmax cross-entropy, and takes one
# step of size lr on the three weight matrices. The biases b1/b2/b3 are not
# part of the gradient computation and therefore stay at their initial value.
loss_list = []  # mean training loss recorded once per iteration
lr = 0.01  # learning rate; constant, so defined once outside the loop
for i in range(800):
    print(i)
    with tf.GradientTape() as tape:
        # Make sure the tape traces the weights we differentiate against.
        tape.watch([w1, w2, w3])
        # b1 has shape [128] and is broadcast over the batch dimension.
        o1 = tf.matmul(x_train, w1) + b1
        # Author's note: the activation choice matters a lot here;
        # tanh or leaky_relu are the recommended options.
        o1 = tf.nn.leaky_relu(o1)
        o2 = tf.matmul(o1, w2) + b2
        o2 = tf.nn.leaky_relu(o2)
        # o3 holds raw logits; the softmax is applied inside the loss.
        o3 = tf.matmul(o2, w3) + b3
        loss = tf.nn.softmax_cross_entropy_with_logits(
            logits=o3, labels=y_train_onehot
        )
        loss = tf.reduce_mean(loss)
    grad = tape.gradient(loss, [w1, w2, w3])
    # assign_sub updates each variable in place, so no rebinding is needed.
    w1.assign_sub(lr * grad[0])
    w2.assign_sub(lr * grad[1])
    w3.assign_sub(lr * grad[2])
    # loss is already a scalar after reduce_mean above.
    loss_list.append(float(loss))
print(w3)
# Plot the recorded loss against the iteration index; the x-axis length is
# taken from loss_list itself so it always matches the number of points.
pt.scatter(range(len(loss_list)), loss_list)
<matplotlib.collections.PathCollection at 0x1e3a03e9348>
# Confusion matrix on the training set.
# o3 still holds the logits from the last training iteration; reduce them to
# the predicted class index per sample.
o3 = tf.argmax(o3, axis=1)
# sklearn's signature is confusion_matrix(y_true, y_pred): rows are the true
# labels and columns are the predictions, matching the test-set call below.
cm = confusion_matrix(y_train, o3)
cm
对角线为预测正确数
array([[5463, 3, 108, 34, 18, 105, 72, 62, 59, 73],
[ 3, 6422, 81, 87, 47, 37, 59, 120, 225, 79],
[ 52, 79, 4775, 257, 95, 149, 172, 75, 118, 30],
[ 22, 36, 243, 4920, 6, 445, 12, 50, 344, 72],
[ 16, 12, 138, 19, 4980, 73, 75, 130, 84, 418],
[ 158, 24, 30, 309, 31, 3893, 116, 13, 245, 51],
[ 100, 21, 171, 46, 144, 127, 5348, 12, 80, 35],
[ 24, 16, 94, 115, 58, 52, 5, 5326, 89, 413],
[ 65, 112, 253, 290, 73, 402, 47, 88, 4309, 167],
[ 20, 17, 65, 54, 390, 138, 12, 389, 298, 4611]],
dtype=int64)
实际值
# Ground-truth label of training sample 4 (echoed as the notebook cell result).
y_train[4]
9
预测值
# Predicted class for training sample 4 (o3 holds argmax class indices here).
o3[4]
<tf.Tensor: shape=(), dtype=int64, numpy=9>
# Confusion matrix on the test set.
# The forward pass must use the same activation the weights were trained with
# (leaky_relu); b1 has shape [128] and is broadcast over the batch dimension.
t1 = tf.matmul(x_test, w1) + b1
t1 = tf.nn.leaky_relu(t1)
t2 = tf.matmul(t1, w2) + b2
t2 = tf.nn.leaky_relu(t2)
t3 = tf.matmul(t2, w3) + b3
t3 = tf.nn.softmax(t3)

# Mean squared error between the predicted probabilities and the one-hot
# labels; echoed as a notebook cell result.
loss = tf.square(t3 - y_test_onehot)
loss = tf.reduce_mean(loss)
loss

# Reduce probabilities to class indices; rows of cm are true labels and
# columns are predictions.
t3 = tf.argmax(t3, axis=1)
cm = confusion_matrix(y_test, t3)
cm
array([[ 904, 0, 22, 3, 0, 25, 21, 4, 1, 0],
[ 0, 1108, 13, 4, 0, 2, 3, 1, 4, 0],
[ 19, 17, 881, 31, 8, 2, 28, 17, 24, 5],
[ 3, 22, 79, 770, 0, 76, 7, 22, 26, 5],
[ 1, 23, 42, 1, 739, 13, 53, 23, 12, 75],
[ 11, 9, 37, 57, 1, 684, 27, 16, 37, 13],
[ 13, 8, 45, 2, 3, 20, 864, 0, 3, 0],
[ 8, 44, 36, 3, 10, 4, 5, 851, 11, 56],
[ 18, 60, 42, 46, 5, 65, 36, 25, 642, 35],
[ 16, 24, 16, 4, 44, 15, 21, 79, 10, 780]],
dtype=int64)