Tensorflow学习笔记6——全连接网络实践

第六讲 全连接网络实践

断点续训问题

接上一讲:如何实现断点续训,即怎样使反向传播每次从上一次结束时开始训练呢?

采用在mnist_backward.py的sess中添加判断ckpt是否存在的语句,即

ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH) #加载ckpt
if ckpt and ckpt.model_checkpoint_path: #如果ckpt存在
    saver.restore(sess,ckpt.model_checkpoint_path) #就将ckpt恢复到当前会话。

这三句话会实现给所有w和b赋保存在ckpt中的值,实现断点续训。

其他问题:

如何对输入的真实图片,输出预测结果?

在这里插入图片描述
可以将任务分成两个函数解决

def application():
testNum = input(“input the number of test pictures:”)
for I in range(testNum):
    testPic = raw_input(“the path of test picture:”)
    testPicArr = pre_pic(testPic) #先对手写数字图片做预处理,符合输入要求后
    preValue = restore_model(testPicArr) #未给复现的神经网络模型,输出预测值
    print”The prediction number is:”,preValue

这里使用的前向传播、反向传播和测试程序与前一个项目完全相同,多加一个mnist_app.py

几个特殊函数:

from PIL import Image
#打开图片
img=Image.open('path')
#显示图片
img.show()
#改变图片尺寸,单位是像素,Image.ANTIALIAS表示用消除锯齿的方法resize
out = img.resize((width, height),Image.ANTIALIAS) 
#convert函数将图片模式“RGB”转换为其他不同模式。模式”1”为二值图像,非黑即白。但是它每个像素用8个bit表示,0表示黑,255表示白。模式“L”表示为灰色图像,它的每个像素用8个bit表示,0表示黑,255表示白,其他数字表示不同的灰度。
out.convert(‘L’)  
#图片的保存 
img.save(‘path') 

np.multiply(x1,x2),作用:逐元素相乘,若x1和x2均为标量,则返回标量

如何制作数据集,实现特定应用?

制作数据集可以使用二进制文件tfrecords
tfrecords文是TensorFlow自带的一种二进制文件,可以先将图片和标签制作成该格式的文件。使用tfrecords进行数据读取,会提高内存利用率。tf.train.Example协议存储训练数据。训练数据的特征用键值对的形式表示。
如:’img_raw’:值 ‘label’:值 值的类型是Byteslist/FloatList/Int64List(字符串、实数列表、整数列表)
用SerializeToString()把数据序列化成字符串存储。

几个重要函数:

  • tf.train.BytesList()等:格式化原始数据可以使用tf.train.BytesList、tf.train.Int64List、tf.train.FloatList三个类。这三个类都有实例属性value用于我们将值传进去,一般tf.train.Int64List、tf.train.FloatList对应处理整数和浮点数,tf.train.BytesList用于处理其他类型的数据。

  • tf.train.Features():从名字来看,我们应该能猜出tf.train.Features是tf.train.Feature的复数,事实上tf.train.Features有属性为feature,这个属性的一般设置方法是传入一个字典,字典的key是字符串(feature名),而值是tf.train.Feature对象。
    例如:
    feature_dict = {
    “data_id”: tf.train.Feature(int64_list=data_id),
    “data”: tf.train.Feature(bytes_list=data)
    }
    features = tf.train.Features(feature=feature_dict)

  • tf.train.Example():tf.train.Example有一个属性为features,我们只需要将上一步得到的结果再次当做参数传进来即可。另外,tf.train.Example还有一个方法SerializeToString()需要说一下,这个方法的作用是把tf.train.Example对象序列化为字符串,因为我们写入文件的时候不能直接处理对象,需要将其转化为字符串才能处理。
    例如: example = tf.train.Example(features=features)
    example_str = example.SerializeToString()

  • string_input_producer():输出字符串到一个输入管道队列。

  • parse_single_example(
    serialized,
    features,
    name=None,
    example_names=None ) serialized:一个标量字符串张量,单个序列化的例子。 features:一个 dict,映射功能键到 FixedLenFeature 或 VarLenFeature值。 name:此操作的名称(可选)。
    example_names:(可选)标量字符串张量,关联的名称 用于解序列化。

  • tf.FixedLenFeature这种方法解析的结果为一个 Tensor,tf .VarLenFeature这种方法得到的解析结果为SparseTensor ,用于处理稀疏数据。

代码

生成tfrecords文件:

writer = tf.python_io.TFRecordWriter(tfRecordName)#新建一个writer
for 循环遍历每张图和标签:
    example = tf.train.Example(features = tf.train.Features(feature={
        'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw])),
        'label':tf.train.Feature(int64_list = tf.train.Int64List(value=labels))
    }))#把每张图片和标签封装到example中
    writer.write(example.SerializeToString())#把example进行序列化
writer.close()

解析tfrecords文件:

def read_tfRecord(tfRecord_path):
    filename_queue = tf.train.string_input_producer([tfRecord_path])
    reader = tf.TFRecordReader() #新建一个reader
    _,serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,features={
        'img_raw':tf.FixedLenFeature([],tf.string),
        'label':tf.FixedLenFeature([10],tf.int64)
    })
    img = tf.decode_raw(features['img_raw'],tf.uint8)
    img.set_shape([784])
    img = tf.cast(img,tf.float32)*(1./225)
    label = tf.cast(features['label'],tf.float32)
    return img,label

以上学习内容来自中国MOOC网课程:https://www.icourse163.org/course/PKU-1002536002

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值