import tensorflow as tf
# GPU usage policy: turn on memory growth for every visible GPU so TensorFlow
# allocates device memory on demand instead of mapping most of it at startup.
# This must happen before any GPU is initialized (i.e. before building models).
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)
## Image preprocessing
def preprocess(data):
    """Convert raw grayscale images into float32 model input scaled by 1/255.

    A trailing channel axis of size 1 is appended so the images match the
    model's ``input_shape=(None, None, 1)``.

    Args:
        data: Array-like batch of images. Presumably ``(N, 28, 28)`` uint8
            MNIST pixels from ``tf.keras.datasets.mnist.load_data()``; any
            numeric dtype is accepted.

    Returns:
        A ``tf.float32`` tensor shaped like ``data`` plus a trailing axis of 1.
    """
    data = tf.expand_dims(data, -1)
    # Cast *before* dividing so the scaling always runs in float32. Dividing an
    # integer tensor first lets tf.math.truediv pick the dtype (int32/int64 are
    # promoted to float64), which doubles memory for a whole dataset only to be
    # truncated back to float32 afterwards.
    data = tf.cast(data, tf.float32)
    return data / 255.0
## Load MNIST, preprocess it, and wrap it in tf.data pipelines.
(train_image, train_label), (test_image, test_label) = tf.keras.datasets.mnist.load_data()
train_image, test_image = preprocess(train_image), preprocess(test_image)

# The shuffle buffer covers the entire training set, giving a full shuffle
# (tf.data reshuffles on every iteration by default).
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_image, train_label))
    .shuffle(len(train_image))
    .batch(32)
)
# Evaluation data is batched in a fixed order.
test_ds = tf.data.Dataset.from_tensor_slices((test_image, test_label)).batch(32)
## Build the network: two 3x3 conv layers -> global average pooling -> 10-way softmax.
model = tf.keras.Sequential([
    # Height and width are None, so any spatial size of single-channel image
    # is accepted; 'same' padding keeps the spatial size unchanged.
    tf.keras.layers.Conv2D(32, (3, 3), input_shape=(None, None, 1), padding='same', activation='relu'),
    tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
    # Averages over the spatial axes, so the classifier head sees a fixed-size
    # vector regardless of input height/width.
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax'),
])
## Optimizer and loss function (library defaults spelled out explicitly).
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# The model ends in a softmax layer, so the loss receives probabilities, not logits.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

## Running loss/accuracy accumulators for the training and test passes.
train_loss, train_accuracy = tf.keras.metrics.Mean(), tf.keras.metrics.SparseCategoricalAccuracy()
test_loss, test_accuracy = tf.keras.metrics.Mean(), tf.keras.metrics.SparseCategoricalAccuracy()
## One training step: forward pass, backward pass, and parameter update.
# Compiled into a graph with tf.function; this is typically much faster than
# executing the step eagerly for every batch.
@tf.function
def train_step(model, image, label):
    """Run one optimization step on a batch and accumulate training metrics.

    Uses the module-level ``loss``, ``optimizer``, ``train_loss`` and
    ``train_accuracy`` objects.

    Args:
        model: Keras model being trained.
        image: Input batch fed to ``model``.
        label: Integer class labels for the batch (sparse labels, as required
            by the sparse categorical loss/accuracy).
    """
    with tf.GradientTape() as tape:  # records the forward pass for differentiation
        # training=True so layers with train/inference modes (dropout, batch
        # norm) behave correctly if they are ever added to the model.
        pred = model(image, training=True)
        loss_step = loss(label, pred)
    grads = tape.gradient(loss_step, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss(loss_step)
    train_accuracy(label, pred)
@tf.function
def test_step(model, image, label):
    """Evaluate the model on one batch and accumulate test metrics.

    No gradients are computed and no weights change. Uses the module-level
    ``loss``, ``test_loss`` and ``test_accuracy`` objects.

    Args:
        model: Keras model being evaluated.
        image: Input batch fed to ``model``.
        label: Integer class labels for the batch.
    """
    # training=False selects inference behaviour for mode-dependent layers.
    pred = model(image, training=False)
    loss_step = loss(label, pred)
    test_loss(loss_step)
    test_accuracy(label, pred)
def _reset_metric(metric):
    """Clear a Keras metric's accumulated state across TF/Keras versions."""
    # reset_state() is the current API; reset_states() is the older,
    # deprecated alias (not available in Keras 3). Fall back only when needed.
    reset = getattr(metric, 'reset_state', None)
    if reset is None:
        reset = metric.reset_states
    reset()


## Train epoch by epoch, evaluating on the test set after each epoch.
def train(epochs=10):
    """Run the custom training loop and print per-epoch metrics.

    Each epoch makes one pass over ``train_ds`` with ``train_step``, then one
    pass over ``test_ds`` with ``test_step``. It prints the epoch's averaged
    metrics and resets them so the next epoch starts fresh.

    Args:
        epochs: Number of epochs to run. Defaults to 10.
    """
    for epoch in range(epochs):
        for image, label in train_ds:
            train_step(model, image, label)
        for image, label in test_ds:
            test_step(model, image, label)
        # float() turns the scalar result tensors into plain numbers instead of
        # printing "tf.Tensor(..., shape=(), dtype=float32)".
        print('epoch={}, loss={}, accuracy={}, val_loss={}, val_accuracy={}'.format(
            epoch,
            float(train_loss.result()),
            float(train_accuracy.result()),
            float(test_loss.result()),
            float(test_accuracy.result()),
        ))
        for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
            _reset_metric(metric)