# Personal study notes. Source:
# https://github.com/machinelearningmindset/TensorFlow-Course#why-use-tensorflow
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import urllib #操作URL
import tempfile #临时文件和目录的处理
import pandas as pd #解决数据分析任务
from tensorflow.examples.tutorials.mnist import input_data
# ---- Task / model size ----
num_classes = 2  # number of output classes (depth of the one-hot labels)
max_num_checkpoint = 10  # not referenced in the visible part of this script

# ---- Optimisation schedule ----
num_epochs = 10
batch_size = 512
initial_learning_rate = 0.001  # starting learning rate
learning_rate_decay_factor = 0.95  # multiplicative decay factor
num_epochs_per_decay = 1  # epochs between two learning-rate decays

# ---- Run-mode switches ----
is_training = fine_tuning = False
online_test = True

# ---- Session device options (forwarded to tf.ConfigProto) ----
allow_soft_placement, log_device_placement = True, False
# Read the MNIST data set from (or into) the local "MNIST_data/" directory,
# asking for integer labels rather than one-hot vectors (one_hot=False).
mnist = input_data.read_data_sets("MNIST_data/", reshape=True, one_hot=False)

# Collect the train/test images and labels under slash-separated keys.
data = {
    'train/image': mnist.train.images,
    'train/label': mnist.train.labels,
    'test/image': mnist.test.images,
    'test/label': mnist.test.labels,
}
def extract_samples_Fn(data, classes=(0, 1)):
    """Return the positions of the samples whose label is in *classes*.

    Args:
        data: 1-D array-like of class labels, one entry per sample.
        classes: Label values to keep. Defaults to ``(0, 1)``, i.e. the
            binary "digit 0 vs. digit 1" subset.

    Returns:
        A list of plain ``int`` indices into ``data``, in ascending order.
        The list is empty when no label matches (or ``data`` is empty).
    """
    labels = np.asarray(data)
    # Vectorised membership test instead of a Python-level loop per sample.
    return np.flatnonzero(np.isin(labels, classes)).tolist()
# Positions of the samples to keep in each split (labels 0 and 1).
index_list_train = extract_samples_Fn(data['train/label'])
index_list_test = extract_samples_Fn(data['test/label'])

# Overwrite the stored arrays with the selected subset of each split.
data.update({
    'train/image': mnist.train.images[index_list_train],
    'train/label': mnist.train.labels[index_list_train],
    'test/image': mnist.test.images[index_list_test],
    'test/label': mnist.test.labels[index_list_test],
})

# Shape of the filtered training images: first axis is the sample count,
# second axis is the per-sample feature count.
dimensionality_train = data['train/image'].shape
num_train_samples, num_features = dimensionality_train[0], dimensionality_train[1]
# Build the model in its own graph, then train and evaluate it in a session.
graph = tf.Graph()
with graph.as_default():
    # Non-trainable step counter; incremented by the optimizer and used to
    # drive the learning-rate decay.
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # Number of training steps between two learning-rate decays.
    decay_steps = int(num_train_samples / batch_size * num_epochs_per_decay)
    learning_rate = tf.train.exponential_decay(
        initial_learning_rate,
        global_step,
        decay_steps,
        learning_rate_decay_factor,
        staircase=True,
        name='exponential_decay_learning_rate')

    # Inputs: a batch of feature vectors and their integer class labels.
    image_place = tf.placeholder(tf.float32, shape=([None, num_features]), name='image')
    label_place = tf.placeholder(tf.int32, shape=([None,]), name='gt')
    label_one_hot = tf.one_hot(label_place, depth=num_classes, axis=-1)
    # Fed on every run below, but not connected to any layer in this graph.
    dropout_param = tf.placeholder(tf.float32)

    # Single fully connected layer producing one score per class.
    # NOTE(review): no activation_fn is passed, so the layer's default
    # activation applies to the logits -- confirm this is intended.
    logits = tf.contrib.layers.fully_connected(
        inputs=image_place, num_outputs=num_classes, scope='fc')

    with tf.name_scope('loss'):
        # Mean softmax cross-entropy over the batch.
        loss_tensor = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=label_one_hot))

    # A sample is correct when the arg-max of the logits matches the label.
    prediction_correct = tf.equal(tf.argmax(logits, 1), tf.argmax(label_one_hot, 1))
    accuracy = tf.reduce_mean(tf.cast(prediction_correct, tf.float32))

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    with tf.name_scope('train_op'):
        # Compute gradients of the loss, then apply them; applying also
        # increments global_step.
        gradients_and_variables = optimizer.compute_gradients(loss_tensor)
        train_op = optimizer.apply_gradients(gradients_and_variables, global_step=global_step)

    session_conf = tf.ConfigProto(
        allow_soft_placement=allow_soft_placement,
        log_device_placement=log_device_placement)

    # Using the session as a context manager installs it as the default
    # session and closes it on exit, releasing the resources it owns.
    with tf.Session(graph=graph, config=session_conf) as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())

        checkpoint_prefix = 'model'
        if fine_tuning:
            # Directory expected to hold a previously saved checkpoint.
            # NOTE(review): this script never writes a checkpoint, so this
            # directory name is a placeholder -- point it at the real one.
            checkpoint_path = 'checkpoints'
            saver.restore(sess, os.path.join(checkpoint_path, checkpoint_prefix))
            print("Model restored for fine-tuning...")

        test_accuracy = 0

        # Number of full batches per epoch; a trailing partial batch is
        # skipped. The value does not change between epochs.
        total_batch_training = data['train/image'].shape[0] // batch_size

        for epoch in range(num_epochs):
            for batch_num in range(total_batch_training):
                start_idx = batch_num * batch_size
                end_idx = (batch_num + 1) * batch_size
                train_batch_data = data['train/image'][start_idx:end_idx]
                train_batch_label = data['train/label'][start_idx:end_idx]

                # One optimisation step on the current batch.
                batch_loss, _, training_step = sess.run(
                    [loss_tensor, train_op, global_step],
                    feed_dict={image_place: train_batch_data,
                               label_place: train_batch_label,
                               dropout_param: 0.5})

            # Report the loss of the last batch of this epoch.
            print("Epoch " + str(epoch + 1) + ", Training Loss= " + "{:.5f}".format(batch_loss))

        # Evaluate on the whole (filtered) test split, as a percentage.
        test_accuracy = 100 * sess.run(
            accuracy,
            feed_dict={image_place: data['test/image'],
                       label_place: data['test/label'],
                       dropout_param: 1.})
        print("Final Test Accuracy is %% %.2f" % test_accuracy)