MNIST是一个手写体数字识别的神经网络。具体的介绍请自行百度或者别的搜索引擎。
下面主要讲如何用tensorflow训练出一个神经网络,然后pytorch用训练好的网络参数来做预测。
1. 修改数据源位置
首先把MNIST源修改为本地数据源,否则可能会卡在下载数据源这一步。
#origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
#path = get_file(
# path,
# origin=origin_folder + 'mnist.npz',
# file_hash=
# '731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1')
path='G:/mnist/data/mnist.npz' #这里是本地数据源的位置
2. 训练集和测试集定义,用于接下来的训练
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
3. 定义网络结构(tensorflow)
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(64, activation='tanh'),
keras.layers.Dropout(.2),
keras.layers.Dense(128, activation='sigmoid'),
keras.layers.Dropout(.2),
keras.layers.Dense(10, activation='softmax')
])
这是一个非常简单的