TensorFlow训练mnist分析

以前学的caffe,caffe不够灵活,所以开始转学TensorFlow了,然后看着官网的教程,跑了把入门级别的mnist ,10分类。

 

发现以前学的caffe还是有帮助的,caffe有很多可视化的工具,帮助你看到整个net的详细结果,训练过程还有digis

可视话整个训练过程,所以选择caffe入门确实很好。等有了点深层学习的基础理论,选择再转TensorFlow,可能接受能力

会强一些。毕竟它们具体实现不同,但是都是围绕深层学习理论来做的的框架,它们的整体模块都是相似的,只是具体细节

表述不同而已,如果一样还谈什么2种框架嘛。

 

整体上都是:

1 。先对训练数据处理成框架能够输入的格式。

2.  定义一个net结构。net结构要绑定数据源文件。

3.  整个net的定义参数,也要绑定net结构文件。包括学习率啊,训练多少批,采用什么梯度下降,GPU还是CPU模式等。

4.  训练模型,训练可以保存训练模型。

5  。测试模型性能,可以读取训练好的模型。

 

下面开始讲mnist的训练,整个过程并不真正的lenet,而是一个简单的2层神经网络来训练和测试。

整个模块分2部分,input_data.py和mnist_test.py。

input_data.py下载用于训练和测试的MNIST数据集的源码,并且转为训练需要的二进制数据。

mnist_test.py 定义了网络结构,训练,测试的代码。

 

 

input_data.py代码

# ==============================================================================
"""Functions for downloading and reading MNIST data."""
#在 input_data.py 文件中,下载用于训练和测试的MNIST数据集的源码
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import tensorflow.python.platform
import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'

#===========================#===========================

#下载数据到 指定的本地文件夹中,返回该文件的完整路径
# MNIST_test.py 文件的顶部由一个标记变量指定,你可以根据自己的需要进行修改。
#传入filename图片名字:1。 补充SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'下载网址用, 2。 指定保存文件的具体名字。work_directory将保存图片的路径
def maybe_download(filename, work_directory):
  """Download the data from Yann's website, unless it's already here."""
  if not os.path.exists(work_directory):
    os.mkdir(work_directory)
    filepath = os.path.join(work_directory, filename)
    
  if not os.path.exists(filepath):
     '''
      直接将远程数据下载到本地。 
urllib.urlretrieve(url[, filename[, reporthook[, data]]]) 
参数说明: 
url:外部或者本地url 
filename:指定了保存到本地的路径(如果未指定该参数,urllib会生成一个临时文件来保存数据); 
reporthook:是一个回调函数,当连接上服务器、以及相应的数据块传输完毕的时候会触发该回调。我们可以利用这个回调函数来显示当前的下载进度。 
data:指post到服务器的数据。该方法返回一个包含两个元素的元组(filename, headers),filename表示保存到本地的路径,header表示服务器的响应头。 
     '''
    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
    '''
    os.stat() 方法用于在给定的路径上执行一个系统 stat 的调用。
stat()方法语法格式如下:
os.stat(path)
参数
path -- 指定路径
返回值:st_mode: inode 保护模式
st_ino: inode 节点号。
st_dev: inode 驻留的设备。
st_nlink: inode 的链接数。
st_uid: 所有者的用户ID。
st_gid: 所有者的组ID。
st_size: 普通文件以字节为单位的大小;包含等待某些特殊文件的数据。
st_atime: 上次访问的时间。
st_mtime: 最后一次修改的时间。
st_ctime: 由操作系统报告的"ctime"。在某些系统上(如Unix)是最新的元数据更改的时间,

例子:
# 显示文件 "a2.py" 信息
statinfo = os.stat('a2.py')
print statinfo
执行以上程序输出结果为:
posix.stat_result(st_mode=33188, st_ino=3940649674337682L, st_dev=277923425L, st
_nlink=1, st_uid=400, st_gid=401, st_size=335L, st_atime=1330498089, st_mtime=13
30498089, st_ctime=1330498089)
    '''
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
 
    return filepath

#=======================
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值