文章目录
CIFAR10自定义网络实战
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
# Data preprocessing
def preprocess(x, y):
    """Rescale an image and cast its label for the input pipeline.

    Args:
        x: Image values; cast to float32 and mapped linearly so that
            0 becomes -1 and 255 becomes 1.
        y: Label values; cast to int32.

    Returns:
        A ``(scaled_image, label)`` tuple.
    """
    pixels = tf.cast(x, dtype=tf.float32)
    # Same arithmetic order as before: (2 * pixels) / 255 - 1.
    scaled = 2 * pixels / 255. - 1
    labels = tf.cast(y, dtype=tf.int32)
    return scaled, labels
batchsz = 128

# Load the dataset.
# x: [b, 32, 32, 3], y: [b, 1]
(x, y), (x_val, y_val) = datasets.cifar10.load_data()

# Remove only the trailing length-1 label axis: [b, 1] -> [b].
# Naming the axis keeps the batch dimension intact even when b == 1;
# a bare tf.squeeze would drop every length-1 dimension.
y = tf.squeeze(y, axis=1)
y_val = tf.squeeze(y_val, axis=1)

# One-hot encode the labels to match the CategoricalCrossentropy loss.
y = tf.one_hot(y, depth=10)
y_val = tf.one_hot(y_val, depth=10)
print('datasets:', x.shape, y.shape, x.min(), x.max())
# datasets: (50000, 32, 32, 3) (50000, 10) 0 255

# Build the two input pipelines. Only the training data is shuffled.
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batchsz)

# Pull one batch to sanity-check the shapes.
sample = next(iter(train_db))
print('batch:', sample[0].shape, sample[1].shape)
# Custom layer
# Replaces the standard layers.Dense
class MyDense(layers.Layer):
    """Bias-free fully connected layer computing ``input @ kernel``.

    Args:
        inp_dim: Size of the last axis of the incoming tensor.
        outp_dim: Size of the last axis of the produced tensor.
    """

    def __init__(self, inp_dim, outp_dim):
        super(MyDense, self).__init__()
        # add_weight replaces the deprecated add_variable alias. Keyword
        # arguments are used because the positional order of add_weight
        # is not the same across Keras versions.
        self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])
        # No bias term is created; the layer is a pure matrix multiply.

    # Forward pass
    def call(self, input, training=None):
        """Return ``input @ self.kernel``.

        ``training`` is accepted for Keras API compatibility and is unused.
        The parameter name ``input`` is kept so keyword callers still work.
        """
        return input @ self.kernel
# Custom network (5 layers)
class MyNetwork(keras.Model):
    """Five-layer fully connected classifier built from ``MyDense`` layers.

    Layer widths: 32*32*3 -> 256 -> 128 -> 64 -> 32 -> 10.
    """

    def __init__(self):
        super(MyNetwork, self).__init__()
        # Tuning note: wider layers add parameters, which makes
        # overfitting more likely.
        self.fc1 = MyDense(32 * 32 * 3, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        """Run the forward pass.

        :param inputs: [b, 32, 32, 3]
        :param training: accepted for Keras API compatibility; unused here
        :return: [b, 10] raw outputs (no softmax is applied)
        """
        # Flatten each image into a single vector.
        hidden = tf.reshape(inputs, [-1, 32 * 32 * 3])
        # Every hidden layer is followed by a ReLU.
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = tf.nn.relu(dense(hidden))
        # [b, 32] -> [b, 10]; the final layer has no activation.
        return self.fc5(hidden)
def _build_compiled_network():
    """Create a fresh MyNetwork compiled with the training configuration."""
    model = MyNetwork()
    model.compile(
        # `learning_rate` is the supported argument name; `lr` is a
        # deprecated alias.
        optimizer=optimizers.Adam(learning_rate=1e-3),
        # The network returns raw logits, so the loss applies the softmax.
        loss=tf.losses.CategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    return model


network = _build_compiled_network()
network.fit(train_db, epochs=15, validation_data=test_db, validation_freq=1)

# Evaluate, then save the model weights.
# NOTE(review): newer Keras releases may require a `.weights.h5` suffix for
# save_weights -- confirm against the installed version.
network.evaluate(test_db)
network.save_weights('ckpt/weights.ckpt')
del network
print('saved to ckpt/weights.ckpt')

# Rebuild the model from scratch and load the saved weights. The second
# evaluate lets the restored metrics be compared with the ones above.
network = _build_compiled_network()
network.load_weights('ckpt/weights.ckpt')
print('load weights from file')
network.evaluate(test_db)
Epoch 14/15
1/391 [..............................] - ETA: 2:59 - loss: 0.6248 - accuracy: 0.8047
8/391 [..............................] - ETA: 24s - loss: 0.6025 - accuracy: 0.7744
14/391 [>.............................] - ETA: 15s - loss: 0.5613 - accuracy: 0.7952
20/391 [>.............................] - ETA: 11s - loss: 0.5669 - accuracy: 0.7969
26/391 [>.............................] - ETA: 9s - loss: 0.5580 - accuracy: 0.8029
32/391 [=>............................] - ETA: 8s - loss: 0.5757 - accuracy: 0.7932
38/391 [=>............................] - ETA: 7s - loss: 0.5719 - accuracy: 0.7926
44/391 [==>...........................] - ETA: 6s - loss: 0.5721 - accuracy: 0.7933
50/391 [==>...........................] - ETA: 5s - loss: 0.5669 - accuracy: 0.7962
56/391 [===>..........................] - ETA: 5s - loss: 0.5710 - accuracy: 0.7939
62/391 [===>..........................] - ETA: 5s - loss: 0.5740 - accuracy: 0.7941
68/391 [====>.........................] - ETA: 4s - loss: 0.5731 - accuracy: 0.7945
75/391 [====>.........................] - ETA: 4s - loss: 0.5753 - accuracy: 0.7922
81/391 [=====>........................] - ETA: 4s - loss: 0.5745 - accuracy: 0.7936
88/391 [=====>........................] - ETA: 4s - loss: 0.5727 - accuracy: 0.7936
94/391 [======>.......................] - ETA: 3s - loss: 0.5742 - accuracy: 0.7927
101/391 [======>.......................] - ETA: 3s - loss: 0.5736 - accuracy: 0.7932
107/391 [=======>......................] - ETA: 3s - loss: 0.5724 - accuracy: 0.7934
114/391 [=======>......................] - ETA: 3s - loss: 0.5749 - accuracy: 0.7926
120/391 [========>.....................] - ETA: 3s - loss: 0.5757 - accuracy: 0.7934
126/391 [========>.....................] - ETA: 3s - loss: 0.5722 - accuracy: 0.7951
133/391 [=========>....................] - ETA: 3s - loss: 0.5721 - accuracy: 0.7955
139/391 [=========>....................] - ETA: 2s - loss: 0.5717 - accuracy: 0.7955
146/391 [==========>...................] - ETA: 2s - loss: 0.5715 - accuracy: 0.7954
152/391 [==========>...................] - ETA: 2s - loss: 0.5694 - accuracy: 0.7959
159/391 [===========>..................] - ETA: 2s - loss: 0.5688 - accuracy: 0.7957
166/391 [===========>..................] - ETA: 2s - loss: 0.5699 - accuracy: 0.7948
173/391 [============>.................] - ETA: 2s - loss: 0.5699 - accuracy: 0.7953
180/391 [============>.................] - ETA: 2s - loss: 0.5691 - accuracy: 0.7954
187/391 [=============>................] - ETA: 2s - loss: 0.5686 - accuracy: 0.7957
193/391 [=============>................] - ETA: 2s - loss: 0.5687 - accuracy: 0.7956
200/391 [==============>...............] - ETA: 2s - loss: 0.5694 - accuracy: 0.7952
207/391 [==============>...............] - ETA: 1s - loss: 0.5688 - accuracy: 0.7954
214/391 [===============>..............] - ETA: 1s - loss: 0.5673 - accuracy: 0.7951
221/391 [===============>..............] - ETA: 1s - loss: 0.5672 - accuracy: 0.7953
228/391 [================>.............] - ETA: 1s - loss: 0.5661 - accuracy: 0.7958
234/391 [================>.............] - ETA: 1s - loss: 0.5651 - accuracy: 0.7959
240/391 [=================>............] - ETA: 1s - loss: 0.5638 - accuracy: 0.7964
247/391 [=================>............] - ETA: 1s - loss: 0.5638 - accuracy: 0.7962
254/391 [==================>...........] - ETA: 1s - loss: 0.5627 - accuracy: 0.7971
261/391 [===================>..........] - ETA: 1s - loss: 0.5635 - accuracy: 0.7968
268/391 [===================>..........] - ETA: 1s - loss: 0.5642 - accuracy: 0.7966
275/391 [====================>.........] - ETA: 1s - loss: 0.5638 - accuracy: 0.7969
282/391 [====================>.........] - ETA: 1s - loss: 0.5633 - accuracy: 0.7972
289/391 [=====================>........] - ETA: 1s - loss: 0.5626 - accuracy: 0.7973
296/391 [=====================>........] - ETA: 0s - loss: 0.5625 - accuracy: 0.7973
302/391 [======================>.......] - ETA: 0s - loss: 0.5629 - accuracy: 0.7968
309/391 [======================>.......] - ETA: 0s - loss: 0.5641 - accuracy: 0.7967
318/391 [=======================>......] - ETA: 0s - loss: 0.5652 - accuracy: 0.7964
332/391 [========================>.....] - ETA: 0s - loss: 0.5661 - accuracy: 0.7960
347/391 [=========================>....] - ETA: 0s - loss: 0.5674 - accuracy: 0.7959
362/391 [==========================>...] - ETA: 0s - loss: 0.5676 - accuracy: 0.7957
376/391 [===========================>..] - ETA: 0s - loss: 0.5684 - accuracy: 0.7957
389/391 [============================>.] - ETA: 0s - loss: 0.5698 - accuracy: 0.7956
391/391 [==============================] - 4s 10ms/step - loss: 0.5697 - accuracy: 0.7956 - val_loss: 1.9200 - val_accuracy: 0.5195
Epoch 15/15
1/391 [..............................] - ETA: 2:55 - loss: 0.6455 - accuracy: 0.7812
8/391 [..............................] - ETA: 24s - loss: 0.5190 - accuracy: 0.8135
15/391 [>.............................] - ETA: 14s - loss: 0.5051 - accuracy: 0.8161
22/391 [>.............................] - ETA: 10s - loss: 0.4930 - accuracy: 0.8224
29/391 [=>............................] - ETA: 8s - loss: 0.4935 - accuracy: 0.8217
36/391 [=>............................] - ETA: 7s - loss: 0.4941 - accuracy: 0.8238
43/391 [==>...........................] - ETA: 6s - loss: 0.4999 - accuracy: 0.8212
50/391 [==>...........................] - ETA: 5s - loss: 0.5044 - accuracy: 0.8181
57/391 [===>..........................] - ETA: 5s - loss: 0.5097 - accuracy: 0.8177
64/391 [===>..........................] - ETA: 4s - loss: 0.5112 - accuracy: 0.8174
71/391 [====>.........................] - ETA: 4s - loss: 0.5097 - accuracy: 0.8168
78/391 [====>.........................] - ETA: 4s - loss: 0.5115 - accuracy: 0.8172
85/391 [=====>........................] - ETA: 4s - loss: 0.5161 - accuracy: 0.8148
92/391 [======>.......................] - ETA: 3s - loss: 0.5176 - accuracy: 0.8145
99/391 [======>.......................] - ETA: 3s - loss: 0.5187 - accuracy: 0.8149
106/391 [=======>......................] - ETA: 3s - loss: 0.5168 - accuracy: 0.8155
113/391 [=======>......................] - ETA: 3s - loss: 0.5177 - accuracy: 0.8148
119/391 [========>.....................] - ETA: 3s - loss: 0.5190 - accuracy: 0.8147
125/391 [========>.....................] - ETA: 3s - loss: 0.5164 - accuracy: 0.8159
132/391 [=========>....................] - ETA: 2s - loss: 0.5162 - accuracy: 0.8159
139/391 [=========>....................] - ETA: 2s - loss: 0.5149 - accuracy: 0.8156
146/391 [==========>...................] - ETA: 2s - loss: 0.5149 - accuracy: 0.8157
153/391 [==========>...................] - ETA: 2s - loss: 0.5139 - accuracy: 0.8161
159/391 [===========>..................] - ETA: 2s - loss: 0.5161 - accuracy: 0.8150
165/391 [===========>..................] - ETA: 2s - loss: 0.5156 - accuracy: 0.8154
171/391 [============>.................] - ETA: 2s - loss: 0.5135 - accuracy: 0.8162
177/391 [============>.................] - ETA: 2s - loss: 0.5148 - accuracy: 0.8158
183/391 [=============>................] - ETA: 2s - loss: 0.5155 - accuracy: 0.8155
189/391 [=============>................] - ETA: 2s - loss: 0.5171 - accuracy: 0.8147
195/391 [=============>................] - ETA: 2s - loss: 0.5189 - accuracy: 0.8140
201/391 [==============>...............] - ETA: 1s - loss: 0.5175 - accuracy: 0.8144
208/391 [==============>...............] - ETA: 1s - loss: 0.5165 - accuracy: 0.8144
214/391 [===============>..............] - ETA: 1s - loss: 0.5185 - accuracy: 0.8137
221/391 [===============>..............] - ETA: 1s - loss: 0.5182 - accuracy: 0.8140
228/391 [================>.............] - ETA: 1s - loss: 0.5175 - accuracy: 0.8143
234/391 [================>.............] - ETA: 1s - loss: 0.5170 - accuracy: 0.8144
240/391 [=================>............] - ETA: 1s - loss: 0.5161 - accuracy: 0.8150
246/391 [=================>............] - ETA: 1s - loss: 0.5168 - accuracy: 0.8142
253/391 [==================>...........] - ETA: 1s - loss: 0.5165 - accuracy: 0.8140
259/391 [==================>...........] - ETA: 1s - loss: 0.5169 - accuracy: 0.8136
265/391 [===================>..........] - ETA: 1s - loss: 0.5164 - accuracy: 0.8138
271/391 [===================>..........] - ETA: 1s - loss: 0.5161 - accuracy: 0.8139
278/391 [====================>.........] - ETA: 1s - loss: 0.5155 - accuracy: 0.8142
284/391 [====================>.........] - ETA: 1s - loss: 0.5156 - accuracy: 0.8140
291/391 [=====================>........] - ETA: 0s - loss: 0.5143 - accuracy: 0.8148
298/391 [=====================>........] - ETA: 0s - loss: 0.5146 - accuracy: 0.8146
305/391 [======================>.......] - ETA: 0s - loss: 0.5148 - accuracy: 0.8146
312/391 [======================>.......] - ETA: 0s - loss: 0.5151 - accuracy: 0.8144
325/391 [=======================>......] - ETA: 0s - loss: 0.5148 - accuracy: 0.8142
339/391 [=========================>....] - ETA: 0s - loss: 0.5161 - accuracy: 0.8137
354/391 [==========================>...] - ETA: 0s - loss: 0.5169 - accuracy: 0.8136
369/391 [===========================>..] - ETA: 0s - loss: 0.5193 - accuracy: 0.8127
383/391 [============================>.] - ETA: 0s - loss: 0.5197 - accuracy: 0.8126
391/391 [==============================] - 4s 10ms/step - loss: 0.5200 - accuracy: 0.8126 - val_loss: 2.0124 - val_accuracy: 0.5189
1/79 [..............................] - ETA: 0s - loss: 1.6155 - accuracy: 0.5625
10/79 [==>...........................] - ETA: 0s - loss: 1.8749 - accuracy: 0.5273
19/79 [======>.......................] - ETA: 0s - loss: 1.9776 - accuracy: 0.5169
27/79 [=========>....................] - ETA: 0s - loss: 1.9817 - accuracy: 0.5194
36/79 [============>.................] - ETA: 0s - loss: 1.9576 - accuracy: 0.5252
45/79 [================>.............] - ETA: 0s - loss: 1.9520 - accuracy: 0.5274
54/79 [===================>..........] - ETA: 0s - loss: 1.9581 - accuracy: 0.5268
63/79 [======================>.......] - ETA: 0s - loss: 1.9572 - accuracy: 0.5255
72/79 [==========================>...] - ETA: 0s - loss: 1.9786 - accuracy: 0.5215
79/79 [==============================] - 0s 6ms/step - loss: 2.0124 - accuracy: 0.5189
saved to ckpt/weights.ckpt
load weights from file
1/79 [..............................] - ETA: 5s - loss: 1.6155 - accuracy: 0.5625
10/79 [==>...........................] - ETA: 0s - loss: 1.8749 - accuracy: 0.5273
19/79 [======>.......................] - ETA: 0s - loss: 1.9776 - accuracy: 0.5169
28/79 [=========>....................] - ETA: 0s - loss: 1.9824 - accuracy: 0.5176
37/79 [=============>................] - ETA: 0s - loss: 1.9426 - accuracy: 0.5283
46/79 [================>.............] - ETA: 0s - loss: 1.9576 - accuracy: 0.5275
55/79 [===================>..........] - ETA: 0s - loss: 1.9660 - accuracy: 0.5259
64/79 [=======================>......] - ETA: 0s - loss: 1.9604 - accuracy: 0.5242
73/79 [==========================>...] - ETA: 0s - loss: 1.9814 - accuracy: 0.5205
79/79 [==============================] - 1s 7ms/step - loss: 2.0124 - accuracy: 0.5189
在不使用卷积神经网络的情况下,仅靠全连接网络的效果有限:第 15 个 epoch 训练集准确率约 81%,而验证集准确率只有约 52%,存在明显的过拟合。