LSTM程序

4 篇文章 0 订阅
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 18 13:44:21 2020

@author: hongyangneng
"""

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import StandardScaler
from sklearn import preprocessing
from sklearn import svm
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

import random
import numpy as np
import pandas as pd
import tensorflow as tf
tf.reset_default_graph() 
#
## 2.创建特征列表。
#column_names = ['1', '2', '3','4','5','6','7','8','9','10','11','12','13','14','15','16','17','18','19','20','label1','label2','label3','label4','label5']
#datatest = pd.read_csv(r'D:\tensorflow\masterdegree\batteryfault_data\utest.csv', names = column_names )
#column_names = ['1', '2', '3','4','5','6','7','8','9','10','11','12','13','14','15','16','17','18','19','20','label1','label2','label3']
#datatest = pd.read_csv(r'D:\tensorflow\masterdegree\batteryfault_data\u2.csv', names = column_names )

column_names=np.arange(1, 406, 1)
#使用SVM分类创建特征列表
#column_names = ['1', '2', '3','4','5','6','7','8','9','10','11','12','13','14','15','16','17','18','19','20','label','label2']

##3.使用pandas.read_csv函数读取指定数据。使用SVM分类
datatest = pd.read_csv(r'D:\tensorflow\masterdegree\batteryfault_data\u200.csv', names = column_names )


##print(data.shape)
#
#data=datatest[column_names[0:200]]
#target=datatest[column_names[200:202]]



##print(data.shape)
data=datatest[column_names[0:400]]
target=datatest[column_names[400:405]]
#print(target)
#4.使用sklearn.cross_valiation里的train_test_split模块用于分割数据。

from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
# 随机采样25%的数据用于测试,剩下的75%用于构建训练集合。
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.25, random_state=33)
print(y_train)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
#print(X_test)
#X_train.shape
total_train_batch, total_test_batch = X_train.shape[0], X_test.shape[0]
print(total_train_batch,)
print(total_test_batch,)

print(X_train.shape, y_train.shape,y_test.shape,)
print(type(X_train), type(y_train))
y_train=y_train.values.reshape(2625,5)
print(type(X_train), type(y_train),type(y_test))
 


n_batch = 100
n_step = 2
n_input = 200
n_output = 5
n_cell = 100
lr1 = 0.01
lr2 = 0.0001
lr3 = tf.Variable(0.01, dtype=tf.float32)

n_train = 100

lstm = LSTM(n_batch, n_step, n_input, n_output, n_cell,lr3)


fig_accuracy1 = np.zeros([1500])
fig_accuracy2 = np.zeros([1500])
fig_accuracy3 = np.zeros([100])
saver = tf.train.Saver()



with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(n_train):
        term = random.sample(range(total_train_batch), n_batch)
        #不再训练直接导入模型
#        sess.run(tf.assign(lr3, 0.01 * (0.995 ** i)))
#        sess.run(lstm.train_op, {lstm.x: X_train[term][
#                 :, :, np.newaxis].reshape(n_batch, n_step, n_input), lstm.y: y_train[term]})
        saver.restore(sess,'net/battery_net.ckpt')
        if i % 1 == 0:
            acc = np.ones((int(total_test_batch / n_batch), n_batch))
            for j in range(int(total_test_batch / n_batch)):
                acc[j] = (sess.run(lstm.acc,
                                   {lstm.x: X_test[j * n_batch: (j + 1) * n_batch][
                                       :, :, np.newaxis].reshape(n_batch, n_step, n_input),
                                    lstm.y: y_test[j * n_batch: (j + 1) * n_batch]}))
            print(i, acc.mean(), ...)
           
            fig_accuracy2[i]=acc.mean()
#        saver.save(sess,'net/battery_net.ckpt')
            
#
#
## 绘制曲线
#fig, ax1 = plt.subplots()
#
##lns = ax1.plot(np.arange(1500), fig_accuracy1, 'r', label="Accuracy")
#lns = ax1.plot(np.arange(100), fig_accuracy2, 'r', label="Accuracy")
##lns = ax1.plot(np.arange(1500), fig_accuracy3, 'r', label="Accuracy")
#
##label = [ "Accuracy1"]
#label = [ "Accuracy2"]
##label = [ "Accuracy3"]
#
#ax1.set_xlabel('iteration')
#ax1.set_ylabel('training accuracy')
#plt.legend(lns, label, loc=0)
#plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值