CIFAR-10数据集由10个类的60000个32x32彩色图像组成,每个类有6000个图像,有50000个训练图像和10000个测试图像。
以下代码实现基于tensorflow对此数据集的训练和验证,采用的是自定义的网络,并把相应的参数(权重)保存起来
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
import os
# 避免出现一些不必要的警告
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def preprocess(x, y):
# [0,255] => [-1 ,1]
x = 2 *tf.cast(x, dtype=tf.float32) / 255. - 1.
y = tf.cast(y, dtype=tf.int32)
return x, y
batchsz = 128
# [32, 32, 3], [10k, 1]
(x, y), (x_val, y_val) = datasets.cifar10.load_data()
y = tf.squeeze(y) # [10k]
y_val = tf.squeeze(y_val)
y = tf.one_hot(y, depth=10) # [50k,10]
y_val = tf.one_hot(y_val, depth=10)