#coding=utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def get_bact(num, x_train, y_train): #打乱顺序,抽取batch
examples_num = np.arange(len(x_train))
np.random.shuffle(examples_num)
examples_num = examples_num[0:num]
x_data = []
y_lable = []
for i in examples_num:
x_data.append(x_train[i])
y_lable.append(y_train[i])
return np.asarray(x_data),np.asarray(y_lable)
if __name__ == "__main__":
x_train = np.random.random((100, 1))
y_train = 0.3*x_train+0.1
w = tf.Variable(tf.zeros(1), name="W")
basi = tf.Variable(tf.zeros(1), name="basi")
x = tf.placeholder(tf.float32, shape=[None, 1],name="input_data")
y_ = tf.placeholder(tf.float32, shape=[None, 1])
y = tf.add(tf.multiply(w, x), basi,name= "y") #预测
cross_entry = tf.reduce_mean(tf.square(y - y_)) #平方误差
optimize = tf.train.GradientDescentOptimizer(0.001).minimize(cross_entry)#最小化平方误差
validation_feed = {x: x_validation, y_: y_validation}
init_op = tf.global_variables_initializer() #初始化所有变量
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y, 1)) # 正确值的一个返回
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) # 求平均正确率
Save = tf.train.Saver()#定义Save
with tf.Session() as sess:
sess.run(init_op)
x_plot = []
y_plot = []
for i in range(40000):
xx, yy = get_bact(90, x_train, y_train)
_, ww, bbasi, xxx = sess.run([optimize, w, basi, cross_entry], feed_dict={x: xx, y_: yy})
print("ww:%f,bais:%floss:%f" % (ww, bbasi, xxx))
if i % 1000== 0:
x_plot.append(i)
y_plot.append(xxx)
if xxx <= 0.00006:
Save.save(sess, "model/model.ckpt")#保存模型
break
plt.xlabel("epoch")
plt.ylabel("loss")
plt.plot(x_plot, y_plot)
plt.show()
对应的使用保存的模型进行测试:
#利用训练好的模型进行测试
import tensorflow as tf
saver = tf.train.import_meta_graph('./model/model.ckpt.meta') #加载图
sess = tf.Session()
saver.restore(sess,tf.train.latest_checkpoint('./model/'))#获得最新的数据,输入是模型的路径
print(sess.run("W:0")) #打印出权重值
print(sess.run("basi:0")) #偏置
graph = tf.get_default_graph() #加载默认图
x = graph.get_tensor_by_name("input_data:0") #通过名字,获得张量
y = graph.get_tensor_by_name("y:0")
feed = {x:[[2.0]]} #很重要,当输入为[None,1]的时候,不是单纯的x:2.0或者
print(sess.run(y,feed_dict=feed)) #测试,获取结果
注意:在训练的时候最好把每个张量的名字带上,以免后期保存模型后无法找到对应的张量