TensorFlow2训练数据集的两种方式

方式一:

def pre_process(x, y):
    x = 2. * tf.cast(x, dtype=tf.float32) / 255. - 1.
    y = tf.cast(y, dtype=tf.int32)
    return x, y


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
x_train, y_train = pre_process(x_train, y_train)
x_test, y_test = pre_process(x_test, y_test)
print(x_train.shape, y_train.shape)

history = net.fit(x_train, y_train,
                  batch_size=512,
                  epochs=100,
                  validation_split=0.2)

test_scores = net.evaluate(x_test, y_test, verbose=2)

训练方式二:

def pre_process(x, y):
    # [0,255] => [-1,1] ,[-1,1]可能是一个最适合神经网络计算的范围
    x = 2. * tf.cast(x, dtype=tf.float32) / 255. - 1
    y = tf.squeeze(y)  # 从张量形状中移除大小为1的维度.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


batch_size = 128
(x, y), (x_val, y_val) = datasets.cifar10.load_data()
print('datasets:', x.shape, y.shape, x.min(), y.min())

train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(pre_process).shuffle(1000).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(pre_process).shuffle(1000).batch(batch_size)

sample = next(iter(train_db))
print('batch:', sample[0].shape, sample[1].shape)

network = MyNetwork()  # MYNetwork是Keras.Model的一个子类
network.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
network.fit(train_db, epochs=50, validation_data=test_db, validation_freq=1)
network.evaluate(test_db)
network.save_weights('./ckpt/cifar10_weights.ckpt') # b将模型保存到磁盘文件

参考链接:

1.李沐大神《动手深度学习》TensorFlow实现,GitHub链接:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0,参考了其中的CNN5.9GoogleNet部分代码

2龙良曲.深度学习与TensorFlow入门实战,项目GitHub链接:https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book,摘自其中的Lesson40--CIFAR与VGG实战

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值