tf.keras是tensorflow目前最新的一个高阶API,目前tf的官网都是以此为指南来教大家入门tensorflow
这个API的使用和此前需要session不同,用这个API搭建de两种方式:
- tf.keras.Sequential:一层一层得顺序搭建一个模型
- tf.keras.Model:函数API,可以创建多输入、多输出模型
这里的笔记以tf.keras.Sequential为示例
初始化模型
#导入相应的库
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
#以mnist数据集为例
import keras.datasets.minist as mnist
(train_image,train_lable),(test_image,test_lable) = mnist.load_data()
#查看数据的shape
train_image.shape
import matplotlib.pyplot as plt
#将图片显示在当前页面
%matplotlib inline
plt.imshow(train_image[3])
#把图片展平,因为这里我们用tf.keras.Sequential这种方式来建立模型
train_image = train_image.reshape(-1,28*28)
#正式开始建立模型
model = keras.Sequential()
添加层,构建网络
#网络的第一层要给定输入数据的shape
model.add(layers.Conv2D(64,(3,3),strides=(1,1),activation='relu',input_shape=(784,)))
model.add(layers.MaxPooling2D((2,2)))