Tensorflow的模型加载

当模型训练后,可以将模型保存,会生成得到以下文件:

要把这些文件中所代表的参数加载进来,本文所介绍的方法是:
在加载文件中所定义的变量要和训练模型中的变量相同,若不相同,在加载模型文件时就会报错

下面的代码是我练习时加载的文件,定义的权值和偏置都和训练时完全相同,然后再加载文件。

import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn import preprocessing

data = np.loadtxt("total.txt", dtype=np.float32)
test = np.loadtxt("test.txt", dtype=np.float32)
input = data[:, 0:8]
output = data[:, 8:9]

train_input, test_input, train_output, test_output = train_test_split(input, output, train_size=0.99, random_state=33)

mimmaxTransform = preprocessing.MinMaxScaler(feature_range=(0, 1))
train_input = mimmaxTransform.fit_transform(train_input)
test = mimmaxTransform.transform(test)

x = tf.placeholder(tf.float32, [None, 8])
y = tf.placeholder(tf.float32, [None, 1])
keep_prob = 1

#构建第一层网络
W1 = tf.Variable(tf.truncated_normal([8,25], stddev=0.1))
b1 = tf.Variable(tf.zeros([1, 25]) + 0.1)
L1 = tf.nn.relu(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

#构建第二层网络
W2 = tf.Variable(tf.truncated_normal([25, 30], stddev=0.1))
b2 = tf.Variable(tf.zeros([1, 30]) + 0.1)
L2 = tf.nn.relu(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([30, 24], stddev=0.1))
b3 = tf.Variable(tf.zeros([1, 24]) + 0.1)
L3 = tf.nn.relu(tf.matmul(L2_drop, W3) + b3)
L3_drop = tf.nn.dropout(L3, keep_prob)

W4 = tf.Variable(tf.truncated_normal([24, 16], stddev=0.1))
b4 = tf.Variable(tf.zeros([1, 16]) + 0.1)
L4 = tf.nn.relu(tf.matmul(L3_drop, W4) + b4)
L4_drop = tf.nn.dropout(L4, keep_prob)

W5 = tf.Variable(tf.truncated_normal([16, 12], stddev=0.1))
b5 = tf.Variable(tf.zeros([1, 12]) + 0.1)
L5 = tf.nn.relu(tf.matmul(L4_drop, W5) + b5)
L5_drop = tf.nn.dropout(L5, keep_prob)

W6 = tf.Variable(tf.truncated_normal([12, 8], stddev=0.1))
b6 = tf.Variable(tf.zeros([1, 8]) + 0.1)
L6 = tf.nn.relu(tf.matmul(L5_drop, W6) + b6)
L6_drop = tf.nn.dropout(L6, keep_prob)

#预测层
W7 = tf.Variable(tf.truncated_normal([8, 1], stddev=0.1))
b7 = tf.Variable(tf.zeros([1,1]) + 0.1)
prediction = tf.matmul(L6_drop, W7) + b7

test_result = []
saver = tf.train.Saver()


with tf.Session() as sess:
    saver.restore(sess, 'my_model-4000')
    p = sess.run(prediction, feed_dict={x:test})

加载成功后就可以用已训练的模型预测啦!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值