tensorflow 2.0.0-alpha0版本tf.keras搭建分类模型–识别fashion_mnist代码
首先导入必要的库。
注意:
tensorflow
的版本是否为2.0.0-alpha0
numpy
的版本不要太高,
import tensorflow as tf
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd
from tensorflow import keras
print(tf.__version__)
for module in mpl, np, pd, tf, keras:
print(module.__name__, module.__version__)
第二部导入数据集。
keras
里收集了一些简单的数据提供于练习。
导入fashion_mnist
数据集,将fashion_mnist
数据集拆分成训练集*_train_all
和测试集*_test
再将训练集*_train_all
拆分成训练集*_train
和验证集*_valid
。
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]