1. Code
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,datasets
(x_train,y_train),(x_test,y_test) = datasets.mnist.load_data()
x_train = x_train.reshape(60000, 28, 28, 1)  # add a trailing channel dimension for Conv2D
x_test = x_test.reshape(10000, 28, 28, 1)
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.  # scale pixels to [0, 1]
    y = tf.cast(y, dtype=tf.int32)
    return x, y
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(1000).map(preprocess).batch(128)  # shuffle, normalize, batch
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(128)  # no shuffling needed for evaluation
sample = next(iter(train_db))
print('sample:', sample[0].shape, sample[1].shape,
tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))
import numpy as np
print(np.max(sample[0][0].numpy()))  # max pixel of the first image; should be <= 1.0 after preprocessing
class MyConv(layers.Layer):
    def __init__(self):
        super(MyConv, self).__init__()
        self.conv1 = layers.Conv2D(64, (5, 5), strides=2, padding='SAME')  # [None, 14, 14, 64]
        self.relu1 = layers.Activation('relu')
        self.pool1 = layers.MaxPool2D(pool_size=(2, 2), strides=2)  # [None, 7, 7, 64]
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(32, (3, 3), strides=2, padding='SAME')  # [None, 4, 4, 32]
        self.relu2 = layers.Activation('relu')
        self.pool2 = layers.MaxPool2D(pool_size=(2, 2), strides=1)  # [None, 3, 3, 32]
        self.bn2 = layers.BatchNormalization()

    def call(self, inputs):
        out = self.conv1(inputs)
        out = self.relu1(out)
        out = self.pool1(out)
        out = self.bn1(out)
        out = self.conv2(out)
        out = self.relu2(out)
        out = self.pool2(out)
        out = self.bn2(out)
        return out
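
# Optional sanity check (not in the original post): trace one dummy image
# through MyConv to confirm the shape comments above.
print(MyConv()(tf.zeros([1, 28, 28, 1])).shape)  # expected: (1, 3, 3, 32)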
class MyModel(keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.Conv = MyConv()
        #self.dense1 = layers.Dense(1024)
        #self.dense2 = layers.Dense(128)
        self.dense3 = layers.Dense(10)  # 10 output logits, one per digit class

    def call(self, inputs):
        x = self.Conv(inputs)
        x = tf.reshape(x, [-1, 3 * 3 * 32])  # flatten the [None, 3, 3, 32] feature maps
        #x = self.dense1(x)
        #x = self.dense2(x)
        x = self.dense3(x)
        return x
model = MyModel()
model.build(input_shape=(None, 28, 28, 1))
model.summary()
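# For reference (hand-computed, not from the original output): summary() should
# report 23,402 total params — conv1 1,664 + bn1 256 + conv2 18,464 + bn2 128 +
# dense3 2,890 — of which 192 (the BatchNorm moving statistics) are non-trainable.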
optimizer = optimizers.Adam(learning_rate=1e-3)  # 'lr' is deprecated in TF 2.x
# Debug scratch: run one batch through the model to check the output shape
#for step, (x, y) in enumerate(train_db):
#    break
#x = tf.cast(x, dtype=tf.float32)
#with tf.device('/cpu:0'):
#    a = model(x)
#a.shape
with tf.device('/cpu:0'):
    for epoch in range(50):
        for step, (x, y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                logits = model(x)
                y_onehot = tf.one_hot(y, depth=10)
                # pass raw logits with from_logits=True; this is numerically
                # more stable than applying softmax before the loss
                loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss = tf.reduce_mean(loss)
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            if step % 50 == 0:
                print(epoch, step, 'loss:', float(loss))

        # evaluate on the test set after every epoch
        total_num = 0
        total_correct = 0
        for x, y in test_db:
            logits = model(x)
            logits = tf.reshape(logits, [-1, 10])
            prob = tf.nn.softmax(logits, axis=1)
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)
            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
            correct = tf.reduce_sum(correct)
            total_num += x.shape[0]
            total_correct += int(correct)
        acc = total_correct / total_num
        print(epoch, 'acc:', acc)
2. Notes
After two or three epochs the test-set accuracy reaches roughly 98%, but the code is messy; a tidier equivalent is sketched below.
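
A minimal sketch of that tidier setup, assuming the same MyModel, train_db, and test_db defined above: compile/fit with a sparse loss consumes the integer labels directly, so the manual one-hot encoding, gradient tape, and accuracy loop all disappear.

model = MyModel()
model.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(train_db, epochs=3, validation_data=test_db)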