A simple fully-connected network on MNIST

The activation function matters a lot; tanh and leaky_relu are recommended here (an important takeaway).
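As a quick check of this claim (a minimal sketch with made-up sample values, separate from the training code below): relu zeroes out every negative input, while leaky_relu and tanh still pass a signal on the negative side, which keeps gradients alive when activations go negative.

import tensorflow as tf

x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])
print(tf.nn.relu(x).numpy())        # [0.  0.  0.  0.5 2. ]  -- negatives are clipped to zero
print(tf.nn.leaky_relu(x).numpy())  # [-0.4 -0.1  0.   0.5  2. ]  -- default slope alpha=0.2 on the negative side
print(tf.nn.tanh(x).numpy())        # zero-centered, bounded in (-1, 1)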

from tensorflow.keras import layers
from matplotlib import pyplot as pt
from sklearn.metrics import confusion_matrix
import tensorflow as tf  # import TensorFlow
mnist=tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()  # load the MNIST data
pt.matshow(x_train[1])
pt.matshow(x_train[2])
[Figures output_1_1.png and output_1_2.png: x_train[1] and x_train[2] rendered with pt.matshow]

x_train = tf.keras.utils.normalize(x_train,axis=1)  # L2-normalize along axis 1 (see the note after the printed tensor below)
x_test=tf.keras.utils.normalize(x_test,axis=1)
x_train = x_train.reshape([-1,28*28])  # flatten each 28x28 image into a 784-dim vector
x_test = x_test.reshape([-1,28*28])
x_train = tf.cast(x_train,dtype = tf.float32)
x_test = tf.cast(x_test,dtype = tf.float32)
print(x_train[1])
print(y_train[1])

tf.Tensor(
[0.         0.         0.         0.         0.         0.
 ...
 0.         0.08216044 0.2286589  0.3728098  0.30506548 0.08583808
 ...
 0.         0.         0.         0.        ], shape=(784,), dtype=float32)
0
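A note on the normalization above: tf.keras.utils.normalize(x, axis=1) divides each slice by its L2 norm along the given axis rather than rescaling by 255, which is likely why the non-zero values printed above top out around 0.48. A rough NumPy equivalent (re-loading the raw images just for this comparison) might look like:

import numpy as np

(x_raw, _), _ = mnist.load_data()           # raw uint8 images, shape (60000, 28, 28)
x_raw = x_raw.astype("float32")
norm = np.linalg.norm(x_raw, ord=2, axis=1, keepdims=True)
x_l2 = x_raw / np.maximum(norm, 1e-12)      # roughly what normalize(axis=1) produces
x_01 = x_raw / 255.0                        # the more common plain [0, 1] rescaling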
y_train_onehot = tf.one_hot(y_train,10,axis=1)  # one-hot encode, 10 classes
print(y_train_onehot[1])
y_test_onehot = tf.one_hot(y_test,10,axis=1)    # one-hot encode, 10 classes
print(y_test_onehot[1])
tf.Tensor([1. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(10,), dtype=float32)
tf.Tensor([0. 0. 1. 0. 0. 0. 0. 0. 0. 0.], shape=(10,), dtype=float32)
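A side note (a sketch, not part of the pipeline used here): the one-hot step can be skipped by using the sparse variant of the loss, which takes integer class ids directly. Below, `logits` is a hypothetical placeholder for the raw (batch, 10) output of the last layer, i.e. what is called o3 further down.

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=tf.cast(y_train, tf.int32),   # integer labels in [0, 9]
    logits=logits)                       # hypothetical placeholder for the (batch, 10) logits
loss = tf.reduce_mean(loss)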

w1 = tf.Variable(tf.random.truncated_normal([784,128],stddev=0.2))  # layer-1 weights (truncated_normal already returns float32; casting a Variable would turn it into a plain tensor and break assign_sub below)
b1 = tf.zeros([128])   # biases are left as constant zeros and are not trained here
w2 = tf.Variable(tf.random.truncated_normal([128,130],stddev=0.2))  # layer-2 weights
b2 = tf.zeros([130])
w3 = tf.Variable(tf.random.truncated_normal([130,10],stddev=0.2))   # layer-3 (output) weights
b3 = tf.zeros([10])
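If the biases should be learned as well, a small extension (a sketch, not what was run above) is to make them Variables and include them in the gradient:

b1 = tf.Variable(tf.zeros([128]))
b2 = tf.Variable(tf.zeros([130]))
b3 = tf.Variable(tf.zeros([10]))
# inside the training loop:
#   grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
#   for var, g in zip([w1, b1, w2, b2, w3, b3], grads):
#       var.assign_sub(lr * g)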

loss_list = []
lr = 0.01

for i in range(800):
    with tf.GradientTape() as tape:
        tape.watch([w1,w2,w3])
        o1 = tf.matmul(x_train,w1) + b1  # b1 is broadcast over the batch; tf.broadcast_to(b1, [x_train.shape[0], 128]) would be equivalent
        o1 = tf.nn.leaky_relu(o1)        # the activation matters a lot: if the inputs contain negative values, avoid activations that only pass positives; tanh or leaky_relu are recommended
        o2 = tf.matmul(o1,w2) + b2
        o2 = tf.nn.leaky_relu(o2)
        o3 = tf.matmul(o2,w3) + b3
        loss = tf.nn.softmax_cross_entropy_with_logits(logits = o3,labels = y_train_onehot)
        loss = tf.reduce_mean(loss)
    # compute gradients after leaving the tape context and apply a plain SGD step
    grad = tape.gradient(loss,[w1,w2,w3])
    w1.assign_sub(lr * grad[0])
    w2.assign_sub(lr * grad[1])
    w3.assign_sub(lr * grad[2])
    loss_list.append(float(loss))
    if i % 100 == 0:
        print(i, float(loss))
pt.scatter([i for i in range(800)],loss_list)
[Figure output_5_2.png: scatter plot of the training loss over the 800 iterations]
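For comparison, the same architecture written with the high-level Keras API (a sketch mirroring the hyperparameters used above; it was not run here) would look like:

model = tf.keras.Sequential([
    layers.Dense(128, activation=tf.nn.leaky_relu, input_shape=(784,)),
    layers.Dense(130, activation=tf.nn.leaky_relu),
    layers.Dense(10),                                   # raw logits, as in the manual version
])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])
# model.fit(x_train, y_train_onehot, epochs=5, batch_size=128)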

# Confusion matrix on the training set
o3 = tf.argmax(o3,axis = 1)   # predicted class = index of the largest logit
#o3 = tf.one_hot(o3,10,axis=1)
cm = confusion_matrix(o3,y_train)   # note the argument order: rows are predictions, columns are true labels
cm

The diagonal entries are the counts of correct predictions.

array([[5463,    3,  108,   34,   18,  105,   72,   62,   59,   73],
       [   3, 6422,   81,   87,   47,   37,   59,  120,  225,   79],
       [  52,   79, 4775,  257,   95,  149,  172,   75,  118,   30],
       [  22,   36,  243, 4920,    6,  445,   12,   50,  344,   72],
       [  16,   12,  138,   19, 4980,   73,   75,  130,   84,  418],
       [ 158,   24,   30,  309,   31, 3893,  116,   13,  245,   51],
       [ 100,   21,  171,   46,  144,  127, 5348,   12,   80,   35],
       [  24,   16,   94,  115,   58,   52,    5, 5326,   89,  413],
       [  65,  112,  253,  290,   73,  402,   47,   88, 4309,  167],
       [  20,   17,   65,   54,  390,  138,   12,  389,  298, 4611]],
      dtype=int64)
Actual label
y_train[4]
9
Predicted label
o3[4]
<tf.Tensor: shape=(), dtype=int64, numpy=9>
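The overall training accuracy can be read off the confusion matrix as the diagonal sum divided by the total count (a small sketch; the import is only needed if numpy is not already available):

import numpy as np

train_acc = np.trace(cm) / np.sum(cm)   # correct predictions / all predictions
print(train_acc)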



# Confusion matrix on the test set
t1 = tf.matmul(x_test,w1) + b1  # b1 is broadcast over the batch, same as during training
t1 = tf.nn.leaky_relu(t1)       # use the same activations as during training
t2 = tf.matmul(t1,w2) + b2
t2 = tf.nn.leaky_relu(t2)
t3 = tf.matmul(t2,w3) + b3
t3 = tf.nn.softmax(t3)
loss = tf.square(t3-y_test_onehot)  # mean squared error between the softmax outputs and the one-hot labels
loss = tf.reduce_mean(loss)
loss
t3 = tf.argmax(t3,axis = 1)
cm = confusion_matrix(y_test,t3)    # here the order is (y_true, y_pred), the usual sklearn convention
cm
array([[ 904,    0,   22,    3,    0,   25,   21,    4,    1,    0],
       [   0, 1108,   13,    4,    0,    2,    3,    1,    4,    0],
       [  19,   17,  881,   31,    8,    2,   28,   17,   24,    5],
       [   3,   22,   79,  770,    0,   76,    7,   22,   26,    5],
       [   1,   23,   42,    1,  739,   13,   53,   23,   12,   75],
       [  11,    9,   37,   57,    1,  684,   27,   16,   37,   13],
       [  13,    8,   45,    2,    3,   20,  864,    0,    3,    0],
       [   8,   44,   36,    3,   10,    4,    5,  851,   11,   56],
       [  18,   60,   42,   46,    5,   65,   36,   25,  642,   35],
       [  16,   24,   16,    4,   44,   15,   21,   79,   10,  780]],
      dtype=int64)
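From this test-set confusion matrix (rows are the true labels), the overall and per-class accuracy can be computed in a few lines (a sketch; numpy imported as np is assumed):

test_acc = np.trace(cm) / np.sum(cm)       # overall test accuracy
per_class = np.diag(cm) / cm.sum(axis=1)   # recall for each digit 0-9
print(test_acc, per_class)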