如果你在本地无法完成模型训练,采用云主机完成是一个不错的选择。这里推荐 AWS,也就是亚马逊云服务。你可以登陆 https://aws.amazon.com
完成注册。然后新建一个 EC2
弹性云计算主机。
使用 AWS 的一个重要原因是它提供了许多优质的镜像,例如我选择的 ami-599a7721
社区镜像已经配置好了 Keras 的开发环境。除此之外,ami-638c1eo3
镜像也是不错的选择。启动之后,就无需再自行配置环境。
AWS 大部分主机都在国外,对于访问 GitHub 等外网资源和一些外网服务有着天然优势。
实例类型方面,推荐选择g2.2xlarge
。此款实例类型配置了1 个 NVIDIA GRID GPU (Kepler GK104)
以及来自 Intel Xeon E5-2670 的 8 个
硬件超线程提供支持。对于本次训练的规模来讲,已经足够了。标准配置的 g2.2xlarge
类型实例价格为 $0.65 (4.3 元人民币)每小时
。GPU 实例价格昂贵,按需使用完之后,务必注意要终止实例。
配置好实例之后,我们就可以开始训练模型了。g2.2xlarge
使用的 NVIDIA GRID GPU (Kepler GK104)
拥有 4 GB
显存,配置一般,比较经济。为了保险起见,我们可以验证一下作为后端的 tensorflow 是否已经成功切换到到 GPU。
# -*- coding:utf8 -*-
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.models import load_model, Sequential
from keras.applications.inception_v3 import preprocess_input
# 载入模型
model = load_model('inceptionv3-tl.model')
# 预测函数
def predict(model, img):
img = img.resize((299, 299))
# 提取特征
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
# 预测
preds = model.predict(x)
return preds[0]
# 绘制图像
def plot_preds(image, preds):
plt.figure()
plt.subplot(211)
plt.imshow(image)
plt.axis('off')
plt.subplot(212)
labels = ("cat", "dog")
plt.barh([0, 1], preds, alpha=0.5)
plt.yticks([0, 1], labels)
plt.xlabel('Probability')
plt.xlim(0, 1.01)
plt.tight_layout()
plt.show()
# 运行
if __name__ == "__main__":
# 这里注意改到你放置测试图片的位置
img = Image.open('dog.jpg')
preds = predict(model, img)
print(preds)
plot_preds(img, preds)