tensorflow 神兵 contrib.learn自定义模型

实验环境

win10 64 bit python 3.5.3+tensorflow0.12.1+pandas0.19.2

代码

采用tensorflow官方tutorials程序abalone.py.代码中英文解释比较清楚了,不再赘述,有疑惑的童鞋可以留言。

#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""DNNRegressor with custom estimator for abalone dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tempfile
from six.moves import urllib

import numpy as np
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "train_data",
    "",
    "Path to the training data.")
flags.DEFINE_string(
    "test_data",
    "",
    "Path to the test data.")
flags.DEFINE_string(
    "predict_data",
    "",
    "Path to the prediction data.")

tf.logging.set_verbosity(tf.logging.INFO)

# Learning rate for the model
LEARNING_RATE = 0.001


def maybe_download():
  """Maybe downloads training data and returns train and test file names."""
  if FLAGS.train_data:
    train_file_name = FLAGS.train_data
  else:
    train_file = tempfile.NamedTemporaryFile(delete=False)
    urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_train.csv", train_file.name)  # pylint: disable=line-too-long
    train_file_name = train_file.name
    train_file.close()
    print("Training data is downloaded to %s" % train_file_name)

  if FLAGS.test_data:
    test_file_name = FLAGS.test_data
  else:
    test_file = tempfile.NamedTemporaryFile(delete=False)
    urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_test.csv", test_file.name)  # pylint: disable=line-too-long
    test_file_name = test_file.name
    test_file.close()
    print("Test data is downloaded to %s" % test_file_name)

  if FLAGS.predict_data:
    predict_file_name = FLAGS.predict_data
  else:
    predict_file = tempfile.NamedTemporaryFile(delete=False)
    urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_predict.csv", predict_file.name)  # pylint: disable=line-too-long
    predict_file_name = predict_file.name
    predict_file.close()
    print("Prediction data is downloaded to %s" % predict_file_name)

  return train_file_name, test_file_name, predict_file_name


# pylint: disable=unused-argument
def model_fn(features, targets, mode, params):
  """Model function for Estimator."""

  # Connect the first hidden layer to input layer
  # (features) with relu activation
  first_hidden_layer = tf.contrib.layers.relu(features, 10)

  # Connect the second hidden layer to first hidden layer with relu
  second_hidden_layer = tf.contrib.layers.relu(first_hidden_layer, 10)

  # Connect the output layer to second hidden layer (no activation fn)
  output_layer = tf.contrib.layers.linear(second_hidden_layer, 1)

  # Reshape output layer to 1-dim Tensor to return predictions
  predictions = tf.reshape(output_layer, [-1])
  predictions_dict = {"ages": predictions}

  # Calculate loss using mean squared error
  loss = tf.contrib.losses.mean_squared_error(predictions, targets)

  train_op = tf.contrib.layers.optimize_loss(
      loss=loss,
      global_step=tf.contrib.framework.get_global_step(),
      learning_rate=params["learning_rate"],
      optimizer="SGD")

  return predictions_dict, loss, train_op


def main(unused_argv):
  # Load datasets
  abalone_train, abalone_test, abalone_predict = maybe_download()

  # Training examples
  training_set = tf.contrib.learn.datasets.base.load_csv_without_header(
      filename=abalone_train,
      target_dtype=np.int,
      features_dtype=np.float64)

  # Test examples
  test_set = tf.contrib.learn.datasets.base.load_csv_without_header(
      filename=abalone_test,
      target_dtype=np.int,
      features_dtype=np.float64)

  # Set of 7 examples for which to predict abalone ages
  prediction_set = tf.contrib.learn.datasets.base.load_csv_without_header(
      filename=abalone_predict,
      target_dtype=np.int,
      features_dtype=np.float64)

  # Set model params
  model_params = {"learning_rate": LEARNING_RATE}

  # Build 2 layer fully connected DNN with 10, 10 units respectively.
  nn = tf.contrib.learn.Estimator(
      model_fn=model_fn, params=model_params)

  # Fit
  nn.fit(x=training_set.data, y=training_set.target, steps=5000)

  # Score accuracy
  ev = nn.evaluate(x=test_set.data, y=test_set.target, steps=1)
  loss_score = ev["loss"]
  print("Loss: %s" % loss_score)

  # Print out predictions
  predictions = nn.predict(x=prediction_set.data,
                           as_iterable=True)
  for i, p in enumerate(predictions):
    print("Prediction %s: %s" % (i + 1, p["ages"]))


if __name__ == "__main__":
  tf.app.run()

转载于:https://my.oschina.net/u/2276931/blog/830355

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
# 工程内容 这个程序是基于tensorflow的tflearn库实现部分RCNN功能。 # 开发环境 windows10 + python3.5 + tensorflow1.2 + tflearn + cv2 + scikit-learn # 数据集 采用17flowers据集, 官网下载:http://www.robots.ox.ac.uk/~vgg/data/flowers/17/ # 程序说明 1、setup.py---初始化路径 2、config.py---配置 3、tools.py---进度条和显示带框图像工具 4、train_alexnet.py---大数据集预训练Alexnet网络,140个epoch左右,bitch_size为64 5、preprocessing_RCNN.py---图像的处理(选择性搜索、数据存取等) 6、selectivesearch.py---选择性搜索源码 7、fine_tune_RCNN.py---小数据集微调Alexnet 8、RCNN_output.py---训练SVM并测试RCNN(测试的时候测试图片选择第7、16类中没有参与训练的,单朵的花效果好,因为训练用的都是单朵的) # 文件说明 1、train_list.txt---预训练数据,数据在17flowers文件夹中 2、fine_tune_list.txt---微调数据2flowers文件夹中 3、1.png---直接用选择性搜索的区域划分 4、2.png---通过RCNN后的区域划分 # 程序问题 1、由于数据集小的原因,在微调时候并没有像论文一样按一个bitch32个正样本,128个负样本输入,感觉正样本过少; 2、还没有懂最后是怎么给区域打分的,所有非极大值抑制集合canny算子没有进行,待续; 3、对选择的区域是直接进行缩放的; 4、由于数据集合论文采用不一样,但是微调和训练SVM时采用的IOU阈值一样,有待调参。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值