以前学的caffe,caffe不够灵活,所以开始转学TensorFlow了,然后看着官网的教程,跑了把入门级别的mnist ,10分类。
发现以前学的caffe还是有帮助的,caffe有很多可视化的工具,帮助你看到整个net的详细结果,训练过程还有digis
可视话整个训练过程,所以选择caffe入门确实很好。等有了点深层学习的基础理论,选择再转TensorFlow,可能接受能力
会强一些。毕竟它们具体实现不同,但是都是围绕深层学习理论来做的的框架,它们的整体模块都是相似的,只是具体细节
表述不同而已,如果一样还谈什么2种框架嘛。
整体上都是:
1 。先对训练数据处理成框架能够输入的格式。
2. 定义一个net结构。net结构要绑定数据源文件。
3. 整个net的定义参数,也要绑定net结构文件。包括学习率啊,训练多少批,采用什么梯度下降,GPU还是CPU模式等。
4. 训练模型,训练可以保存训练模型。
5 。测试模型性能,可以读取训练好的模型。
下面开始讲mnist的训练,整个过程并不真正的lenet,而是一个简单的2层神经网络来训练和测试。
整个模块分2部分,input_data.py和mnist_test.py。
input_data.py下载用于训练和测试的MNIST数据集的源码,并且转为训练需要的二进制数据。
mnist_test.py 定义了网络结构,训练,测试的代码。
input_data.py代码
# ==============================================================================
"""Functions for downloading and reading MNIST data."""
#在 input_data.py 文件中,下载用于训练和测试的MNIST数据集的源码
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import tensorflow.python.platform
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
#===========================#===========================
#下载数据到 指定的本地文件夹中,返回该文件的完整路径
# MNIST_test.py 文件的顶部由一个标记变量指定,你可以根据自己的需要进行修改。
#传入filename图片名字:1。 补充SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'下载网址用, 2。 指定保存文件的具体名字。work_directory将保存图片的路径
def maybe_download(filename, work_directory):
"""Download the data from Yann's website, unless it's already here."""
if not os.path.exists(work_directory):
os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
'''
直接将远程数据下载到本地。
urllib.urlretrieve(url[, filename[, reporthook[, data]]])
参数说明:
url:外部或者本地url
filename:指定了保存到本地的路径(如果未指定该参数,urllib会生成一个临时文件来保存数据);
reporthook:是一个回调函数,当连接上服务器、以及相应的数据块传输完毕的时候会触发该回调。我们可以利用这个回调函数来显示当前的下载进度。
data:指post到服务器的数据。该方法返回一个包含两个元素的元组(filename, headers),filename表示保存到本地的路径,header表示服务器的响应头。
'''
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
'''
os.stat() 方法用于在给定的路径上执行一个系统 stat 的调用。
stat()方法语法格式如下:
os.stat(path)
参数
path -- 指定路径
返回值:st_mode: inode 保护模式
st_ino: inode 节点号。
st_dev: inode 驻留的设备。
st_nlink: inode 的链接数。
st_uid: 所有者的用户ID。
st_gid: 所有者的组ID。
st_size: 普通文件以字节为单位的大小;包含等待某些特殊文件的数据。
st_atime: 上次访问的时间。
st_mtime: 最后一次修改的时间。
st_ctime: 由操作系统报告的"ctime"。在某些系统上(如Unix)是最新的元数据更改的时间,
例子:
# 显示文件 "a2.py" 信息
statinfo = os.stat('a2.py')
print statinfo
执行以上程序输出结果为:
posix.stat_result(st_mode=33188, st_ino=3940649674337682L, st_dev=277923425L, st
_nlink=1, st_uid=400, st_gid=401, st_size=335L, st_atime=1330498089, st_mtime=13
30498089, st_ctime=1330498089)
'''
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
return filepath
#=======================