1 keras安装
主要参考keras中文文档中keras安装和配置指南http://keras-cn.readthedocs.io/en/latest/for_beginners/keras_linux/
keras安装好后backend默认使用的是TensorFlow,要切换为theano的话,先要使用sudo pip install -U --pre theano
下载theano,然后修改.keras文件中的keras.json配置文件。
方法:
1. 在ubuntu终端输入find / -name .keras
,找到.keras文件夹所在位置(我的是在/root
目录下和当前普通用户主目录下都有)并切换到该目录下。
2. 输入gedit keras.json
,将最后一行的tensorflow替换成theano,保存并退出。
3. 重新在ubuntu终端输入python命令,接着输入import keras
后会显示using theano backend
。
2 mnist测试
- 下载mnist数据集http://yann.lecun.com/exdb/mnist/。
- 建立一个mnist文件夹,将下载的mnist数据集复制到mnist文件夹下。
- 在与mnist文件夹同一级目录下利用
gedit mnist_test.py
新建一个mnist_test.py,将下面的代码复制进去并保存。 - 直接用python运行该文件即可
python mnist_test.py
(在theano的backend下会出现警告,不过不影响运行)。
#mnist_test.py
import numpy as np
import gzip
import struct
import keras as ks
import logging
from keras.layers import Dense, Activation, Flatten, Convolution2D
from keras.utils import np_utils
def read_data(label_url,image_url):
with gzip.open(label_url) as flbl:
magic, num = struct.unpack(">II",flbl.read(8))
label = np.fromstring(flbl.read(),dtype=np.int8)
with gzip.open(image_url,'rb') as fimg:
magic, num, rows, cols = struct.unpack(">IIII",fimg.read(16))
image = np.fromstring(fimg.read(),dtype=np.uint8).reshape(len(label),rows,cols)
return (label, image)
(train_lbl, train_img) = read_data('mnist/train-labels-idx1-ubyte.gz','mnist/train-images-idx3-ubyte.gz')
(val_lbl, val_img) = read_data('mnist/t10k-labels-idx1-ubyte.gz','mnist/t10k-images-idx3-ubyte.gz')
def to4d(img):
return img.reshape(img.shape[0],784).astype(np.float32)/255
train_img = to4d(train_img)
val_img = to4d(val_img)
train_LBL = np_utils.to_categorical(train_lbl)
val_LBL = np_utils.to_categorical(val_lbl)
model = ks.models.Sequential()
model.add(Dense(128,input_dim=784))
model.add(Activation('relu'))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(10))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy',optimizer='adadelta',metrics=['accuracy'])
model.fit(x=train_img,y=train_LBL,batch_size=100,epochs=10,verbose=1,validation_data=(val_img,val_LBL))
成功测试mnist后的结果如下: