网络模型的交叉验证
数据集的内容一般分为:trainset、valset和testset,将获取的mnist数据集的(x,y)部分拆分为[x_train, y_train](训练数据)和[x_val, y_val](评测数据)
#把[x, y]拆分成[x_train, y_train]和[x_val, y_val]
idx = tf.range(60000)
idx = tf.random.shuffle(idx)
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000]) #[0, 50000]
x_val, y_val = tf.gather(x, idx[-10000:]) , tf.gather(y, idx[-10000:]) #[50000,60000]
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
用这两部分数据集进行边训练,边测评,能够查看当前训练的指标效果。可以查看每次训练后评价指标,找出最好的参数。
#一共training6次,每2次评测一次
# db_train:training数据集, epochs:数据集的training的次数, validation_data:用于评测的数据集, validation_freq:每n个epoch做一次评测
network.fit(db_train, epochs=6, validation_data=db_val, validation_freq=2)
用(x_test, y_test)数据集对模型进行进行测试。
network.evaluate(db_test)
完整代码:
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
def preprocess(x, y):
"""
x is a simple image, not a batch
"""
x = tf.cast(x, dtype=tf.float32) / 255.
x = tf.reshape(x, [28*28])
y = tf.cast(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
return x,y
batchsz = 128
(x, y), (x_test, y_test) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())
#把[x, y]拆分成[x_train, y_train]和[x_val, y_val]
idx = tf.range(60000)
idx = tf.random.shuffle(idx)
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000]) #[0, 50000]
x_val, y_val = tf.gather(x, idx[-10000:]) , tf.gather(y, idx[-10000:]) #[50000,60000]
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
db_train = tf.data.Dataset.from_tensor_slices((x_train,y_train))
db_train = db_train.map(preprocess).shuffle(50000).batch(batchsz)
db_val = tf.data.Dataset.from_tensor_slices((x_val,y_val))
db_val = db_val.map(preprocess).shuffle(10000).batch(batchsz)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.map(preprocess).batch(batchsz)
sample = next(iter(db_train))
print(sample[0].shape, sample[1].shape)
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
#一共training6次,每2次评测一次
# db_train:training数据集, epochs:数据集的training的次数, validation_data:用于评测的数据集, validation_freq:每n个epoch做一次评测
network.fit(db_train, epochs=6, validation_data=db_val, validation_freq=2)
#测试的数据集
print('Test performance:')
network.evaluate(db_test)
sample = next(iter(db_test))
x = sample[0]
y = sample[1] # one-hot
pred = network.predict(x) # [b, 10]
# convert back to number
y = tf.argmax(y, axis=1)
pred = tf.argmax(pred, axis=1)
print(pred)
print(y)