# Pure-Python LSTM inference for 2-D point tracking. Loads weights exported
# from a TensorFlow/Keras-trained model; a C++ version is described in a
# companion blog post.
import numpy as np
from pandas import read_excel
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
# global h_tm_c, h_tm_f, h_tm_i, h_tm_o, c_tm
# global h_tm_c2, h_tm_f2, h_tm_i2, h_tm_o2, c_tm2
def hard_sigmoid(x):
    """Keras-style hard sigmoid: clip(0.2 * x + 0.5, 0, 1).

    Fix over the original: the manual masking (``x[x < 0] = 0``) raised
    TypeError for plain Python scalars; ``np.clip`` computes the identical
    piecewise-linear result and accepts scalars and arrays alike, without
    mutating any intermediate in place.
    """
    return np.clip(0.2 * x + 0.5, 0.0, 1.0)
def _hard_sigmoid(x):
    # Keras "hard_sigmoid" recurrent activation: clip(0.2 * x + 0.5, 0, 1).
    return np.clip(0.2 * x + 0.5, 0.0, 1.0)


def _lstm_weights():
    """Build (once) and cache the hard-coded weights exported from Keras.

    The original code re-created every ``np.array`` literal (~530 floats) on
    every timestep; caching them on the function object makes each call much
    cheaper while leaving every numerical result unchanged.
    """
    w = getattr(_lstm_weights, "_cache", None)
    if w is not None:
        return w
    w = {
        # Column 1 (x-coordinate) gate input kernels, 10 units each.
        "kernel_i": np.array([
            -0.34339943528175354, 1.766756296157837, 0.7450317740440369,
            -0.829913854598999, 0.5050585269927979, 0.5951665043830872,
            -0.12462729960680008, -0.057644475251436234, -0.195032998919487,
            0.9801527857780457]),
        "kernel_f": np.array([
            -0.2959004342556, 0.1878744661808014, 0.1974334567785263,
            -0.42833104729652405, 0.655572772026062, 0.5754011273384094,
            0.20625288784503937, 0.368974894285202, -0.32602018117904663,
            0.276435434818267]),
        "kernel_c": np.array([
            0.09515336155891418, -0.5731109976768494, -0.4054526388645172,
            -0.314892441034317, 0.3126268684864044, -0.5319279432296753,
            0.653582751750946, 0.49748483300209045, 0.20359081029891968,
            0.5867465138435364]),
        "kernel_o": np.array([
            0.20324330031871796, 1.5852749347686768, 0.6022287011146545,
            -1.1601568460464478, 0.7637105584144592, 0.7074271440505981,
            0.6657699942588806, -0.3173244595527649, -0.3717532157897949,
            1.0292516946792603]),
        # Column 2 (y-coordinate) gate input kernels.
        "kernel_i2": np.array([
            0.29539164900779724, -0.5408189296722412, 0.31895196437835693,
            -0.3692389726638794, 0.5499736070632935, 0.0356333963572979,
            0.6757100820541382, 0.930941641330719, 0.4551529884338379,
            0.07838062196969986]),
        "kernel_f2": np.array([
            0.1885366439819336, -0.5314029455184937, -0.25439608097076416,
            -0.32862451672554016, -0.026900850236415863, -0.6118189096450806,
            -0.8291504979133606, -0.10796643793582916, 0.3633717894554138,
            -0.5390552878379822]),
        "kernel_c2": np.array([
            -0.042067818343639374, 0.7635902166366577, -0.1499466896057129,
            0.7719531059265137, 0.4475833773612976, -0.006783284712582827,
            -0.1789802610874176, -0.5963394045829773, -0.5599965453147888,
            -0.37145528197288513]),
        "kernel_o2": np.array([
            -0.18046832084655762, -0.17623309791088104, 0.20015233755111694,
            -0.6863264441490173, 0.0444098562002182, 0.42466920614242554,
            0.31134355068206787, 0.5258762240409851, 0.6071751117706299,
            -0.7286037802696228]),
        # Recurrent kernels (10x10), shared by both columns in the original
        # export -- see NOTE in lstm_keras_verify.
        "rec_i": np.array([
            [-0.41451895236968994, 2.3002400398254395, -0.2820678651332855,
             -0.45961064100265503, 0.505663275718689, 0.7919554114341736,
             0.20269405841827393, 0.8496864438056946, 0.08628702908754349,
             0.7238653302192688],
            [-0.9789332151412964, -1.5582021474838257, -1.3366812467575073,
             -0.27708300948143005, 0.4506531357765198, -1.467345952987671,
             -0.5549973249435425, 1.5742740631103516, 0.2451542168855667,
             -0.940199077129364],
            [0.4405961036682129, -1.5445444583892822, 0.05049801990389824,
             -0.35451313853263855, -0.269658625125885, -0.638664722442627,
             -0.789127767086029, -0.7010314464569092, -0.13462814688682556,
             -0.6025670766830444],
            [-0.887526273727417, 0.4103408753871918, -0.4497901499271393,
             -0.2633923292160034, -0.22475440800189972, 1.236531376838684,
             0.6960248947143555, 0.33017680048942566, -0.5554746985435486,
             1.7138985395431519],
            [0.12573093175888062, 0.15226393938064575, 0.35927459597587585,
             -0.2610260248184204, 0.1653788834810257, 0.8279508352279663,
             1.0212059020996094, 0.820050060749054, 0.5119935870170593,
             0.3915463089942932],
            [-0.001828429289162159, -3.9691922664642334, 0.37881338596343994,
             0.10328131914138794, -0.296207994222641, -1.0051754713058472,
             -0.8751426339149475, -0.4393799901008606, -0.5824257135391235,
             -0.9112383723258972],
            [0.14419855177402496, 0.07607676088809967, 0.5626516938209534,
             -1.290927529335022, -0.0973503440618515, 0.5256897807121277,
             0.30444854497909546, -0.9554680585861206, -0.2767266035079956,
             0.2846880555152893],
            [0.20677947998046875, 2.7582507133483887, -0.2818203270435333,
             0.9349485635757446, -0.985670268535614, 0.8378878831863403,
             -0.7977424263954163, -0.6354966163635254, -0.5764774084091187,
             0.5060479044914246],
            [-0.14697682857513428, 1.3863033056259155, -0.14390596747398376,
             0.8854043483734131, -0.9272205829620361, -0.7263381481170654,
             -1.2763465642929077, -1.5187703371047974, -0.8858398795127869,
             0.05383571982383728],
            [0.5134013891220093, 2.266584634780884, 0.4238872528076172,
             -0.03581789508461952, 0.1086997240781784, 0.645852267742157,
             1.1289607286453247, 0.10689905285835266, 0.2679310739040375,
             0.67705899477005]]),
        "rec_f": np.array([
            [-0.16183863580226898, 0.7472390532493591, -0.334164023399353,
             -0.1287859082221985, 0.32973605394363403, -0.1532825231552124,
             -0.18556086719036102, 1.174559235572815, 0.315372109413147,
             -0.14612999558448792],
            [-1.0360535383224487, 0.19038058817386627, -0.9956535696983337,
             -0.28531670570373535, 0.7210497260093689, -0.02345925383269787,
             -1.0118887424468994, 0.5574779510498047, 0.3704586923122406,
             -0.6031738519668579],
            [0.3754199743270874, -0.014456510543823242, -0.06675058603286743,
             0.17784777283668518, -0.16492195427417755, 0.2980993986129761,
             0.25578248500823975, -0.7040256261825562, -0.5165166258811951,
             0.1266956329345703],
            [-0.6861728429794312, -0.16793103516101837, -0.271343469619751,
             -0.27307748794555664, -0.19215670228004456, 0.27285122871398926,
             0.025628695264458656, -0.2498171329498291, -0.571029543876648,
             0.3341529667377472],
            [-0.2412928342819214, 0.13404031097888947, 0.027914635837078094,
             -0.3998990058898926, 0.14130471646785736, -0.23286490142345428,
             -0.5493660569190979, 0.20289260149002075, 0.6077360510826111,
             -0.41169750690460205],
            [0.10601703077554703, -0.4763187766075134, 0.0049462090246379375,
             0.5886510014533997, -0.2985691428184509, 0.17912009358406067,
             0.346873015165329, -0.650505542755127, -0.385824054479599,
             0.09396341443061829],
            [0.5001235008239746, 1.339636206626892, 0.6780133247375488,
             -1.1757398843765259, -0.20317663252353668, 0.5733743906021118,
             -0.506499171257019, -0.3316236138343811, -0.23696570098400116,
             0.017315169796347618],
            [-0.008928589522838593, 0.7925891280174255, 0.7797060012817383,
             0.8483789563179016, -0.7849624156951904, 1.0222151279449463,
             0.3750150799751282, 0.6228334307670593, -1.3550268411636353,
             0.3627028465270996],
            [-0.18932075798511505, 0.5054185390472412, 0.21752338111400604,
             0.5696843862533569, -1.0764824151992798, 0.8123449087142944,
             0.3687470555305481, -0.10713357478380203, -0.8887754082679749,
             -0.02434460259974003],
            [-0.17165663838386536, 0.7209696769714355, 0.6251261830329895,
             0.013831086456775665, 0.06497514992952347, 0.4622625410556793,
             0.09293515235185623, 0.20599240064620972, -0.0917106419801712,
             -0.08447670191526413]]),
        "rec_c": np.array([
            [-0.6100238561630249, -0.12698210775852203, 0.4954182505607605,
             0.15199264883995056, 0.35257288813591003, -0.12464402616024017,
             -0.44923147559165955, 0.040395427495241165, 0.7752857208251953,
             -0.05500243604183197],
            [0.4752281606197357, -0.2517543435096741, 0.16683556139469147,
             0.2313675582408905, 0.7824580669403076, -0.07187685370445251,
             0.19583779573440552, -0.15934820473194122, -0.09073260426521301,
             0.1494336724281311],
            [-0.07115137577056885, -0.27783897519111633, -0.1414336860179901,
             -0.5170375108718872, -0.054542649537324905, -0.320780485868454,
             0.44443681836128235, 0.16018202900886536, 0.16488415002822876,
             0.26717546582221985],
            [0.23076167702674866, -0.558992326259613, -0.28554731607437134,
             0.40353721380233765, 0.07054921239614487, -0.28317591547966003,
             0.17094585299491882, 0.252930223941803, 0.3688906729221344,
             0.2711980938911438],
            [0.14932553470134735, 0.056401338428258896, -0.25233572721481323,
             0.4610925614833832, 0.22263102233409882, 0.02338065765798092,
             -0.040763139724731445, -0.29332226514816284, -0.2517569959163666,
             -0.07778314501047134],
            [-0.07450161874294281, 0.17594310641288757, -0.09124835580587387,
             -0.9320527911186218, 0.06922359019517899, -0.1493844985961914,
             0.593464732170105, 0.19774796068668365, 0.058623287826776505,
             0.11784616112709045],
            [0.3252342641353607, -0.05548950284719467, -0.08142340183258057,
             -0.12902313470840454, -0.9895947575569153, -0.4129822254180908,
             -0.3162073493003845, 0.5147839188575745, 0.3044321835041046,
             -0.10074173659086227],
            [0.1463569700717926, 0.006250813137739897, 0.39591124653816223,
             0.5136163830757141, -0.5497452616691589, -0.16999530792236328,
             -0.43260541558265686, 0.07086426019668579, -0.40135759115219116,
             -0.4523780345916748],
            [0.2282596379518509, -0.142124742269516, 0.22567713260650635,
             0.13072924315929413, -0.5392906069755554, -0.18509604036808014,
             -0.02913133054971695, 0.206122487783432, -0.11300422251224518,
             -0.2983923852443695],
            [0.27763694524765015, -0.08438488095998764, -0.2187069207429886,
             0.2885969281196594, -0.20823447406291962, -0.020448746159672737,
             -0.40587231516838074, 0.271494060754776, 0.09722976386547089,
             -0.4638107120990753]]),
        "rec_o": np.array([
            [0.08256174623966217, 1.0839580297470093, -0.3347654640674591,
             -0.17489440739154816, -0.03872399777173996, 0.1976941078901291,
             0.3668980002403259, 0.6379971504211426, 0.025410614907741547,
             0.03770061582326889],
            [-0.8977024555206299, -0.510123610496521, -1.548133373260498,
             -0.21091605722904205, -0.12713493406772614, -1.005964756011963,
             -0.3703922629356384, 1.4118517637252808, 0.20945164561271667,
             -1.00154447555542],
            [-0.005945556331425905, -0.7805576920509338, 0.016335049644112587,
             -0.011125431396067142, 0.2580377757549286, -0.290608286857605,
             -0.8096177577972412, -0.5068511366844177, -0.08132117241621017,
             -0.3691478967666626],
            [-1.0412406921386719, 1.1914699077606201, -0.8255971670150757,
             -0.15991421043872833, 0.04529210552573204, 0.21786856651306152,
             0.797028660774231, 0.5845527648925781, -0.6345934867858887,
             0.8253769278526306],
            [0.3824268877506256, 0.367237389087677, -0.03931231051683426,
             -0.2734659016132355, 0.06321419030427933, 0.7792273759841919,
             0.872334361076355, 0.8914680480957031, 0.5063027739524841,
             -0.02294449880719185],
            [0.17051495611667633, -3.346440315246582, -0.1220833957195282,
             0.16170988976955414, 0.23057465255260468, -0.5420761108398438,
             -0.3005651533603668, -0.20853359997272491, -0.27472618222236633,
             -0.6494842767715454],
            [0.3111186623573303, -0.24375362694263458, 1.1926289796829224,
             -1.8478845357894897, 0.37887048721313477, 0.7220742106437683,
             0.51273512840271, -0.3570004403591156, -1.0175435543060303,
             0.3988095819950104],
            [-0.0040112873539328575, 1.8080674409866333, 0.5294325947761536,
             0.5869970917701721, -0.2983153164386749, 0.87757807970047,
             -0.2792051136493683, -1.231807827949524, -1.0860964059829712,
             1.0274618864059448],
            [-0.06216314807534218, 0.7327183485031128, 0.11892963945865631,
             0.6781867742538452, -0.6871398687362671, 0.319871187210083,
             -1.1046788692474365, -1.0522805452346802, -1.1571115255355835,
             0.4382123649120331],
            [-0.16790947318077087, 1.1358662843704224, 0.5037592649459839,
             0.175055593252182, -0.013465126045048237, 0.46570730209350586,
             0.8169551491737366, 0.3778005540370941, 0.11658983677625656,
             0.3484398424625397]]),
        # Gate biases (shared by both columns in the original export).
        "bias_i": np.array([
            -0.446609765291214, -0.11702976375818253, -0.3065964877605438,
            -0.025546252727508545, -0.049140945076942444, -0.001070813275873661,
            0.612964928150177, 0.16031958162784576, -0.0778871700167656,
            0.15741103887557983]),
        "bias_f": np.array([
            0.31866440176963806, 0.4428155720233917, 0.5675423741340637,
            0.8959832191467285, 0.9183741807937622, 0.499959260225296,
            0.6026825904846191, 0.5548291206359863, 0.5360163450241089,
            0.5315966010093689]),
        "bias_c": np.array([
            -0.1425662487745285, 0.05167638882994652, 0.1590413600206375,
            -0.1351909041404724, -0.006993392016738653, 0.26650679111480713,
            -0.0693528801202774, 0.00884726271033287, 0.17367465794086456,
            -0.0636242926120758]),
        "bias_o": np.array([
            -0.47558993101119995, 0.41542863845825195, -0.28937411308288574,
            -0.12115256488323212, -0.10851451754570007, -0.03444031625986099,
            0.7349303960800171, 0.2082076072692871, -0.13617178797721863,
            0.07746882736682892]),
        # Dense heads: one scalar output per 10-unit column.
        "dense_w": np.array([
            0.25596046447753906, -0.6001706123352051, -0.7026602625846863,
            -0.6046009659767151, 0.7513906359672546, -0.6343705654144287,
            1.1254067420959473, 0.6913528442382812, 0.3435429632663727,
            0.7804898023605347]).reshape(10, 1),
        "dense_b": np.array([0.11426012963056564]),
        "dense_w2": np.array([
            0.476622074842453, -0.05293620750308037, -0.31866875290870667,
            0.7019630670547485, 0.3929026424884796, -0.11945560574531555,
            0.2096286118030548, -0.14515486359596252, -0.7117556929588318,
            -0.36956626176834106]).reshape(10, 1),
        "dense_b2": np.array([0.22602976858615875]),
    }
    _lstm_weights._cache = w
    return w


def lstm_keras_verify(inputs, ten2, ten3):
    """Run one LSTM timestep for both tracked coordinates and return the
    dense-layer predictions.

    Parameters
    ----------
    inputs : sequence of 2 floats
        Current (x, y) sample; ``inputs[0]`` feeds column 1, ``inputs[1]``
        feeds column 2. Presumably already MinMax-scaled to [0, 1] like the
        training data -- confirm against the caller.
    ten2 : int
        Timestep index within the current window; ``ten2 == 0`` resets the
        recurrent state carried in module-level globals.
    ten3 : int
        Window index; unused, kept only for interface compatibility.

    Returns
    -------
    np.ndarray of shape (2,)
        Predicted (x, y).

    Side effects: hidden/cell state is kept in the module-level globals
    ``h_tm_* / c_tm*`` between calls, exactly like the original code.
    """
    global h_tm_c, h_tm_f, h_tm_i, h_tm_o, c_tm
    global h_tm_c2, h_tm_f2, h_tm_i2, h_tm_o2, c_tm2
    if ten2 == 0:
        # New sequence: drop all carried-over state.
        h_tm_c = h_tm_f = h_tm_i = h_tm_o = c_tm = None
        h_tm_c2 = h_tm_f2 = h_tm_i2 = h_tm_o2 = c_tm2 = None
    w = _lstm_weights()
    # step 1 + 2: W * x + b for every gate of both columns.
    x_i = inputs[0] * w["kernel_i"] + w["bias_i"]
    x_f = inputs[0] * w["kernel_f"] + w["bias_f"]
    x_c = inputs[0] * w["kernel_c"] + w["bias_c"]
    x_o = inputs[0] * w["kernel_o"] + w["bias_o"]
    # NOTE(review): column 2 reuses column 1's biases here and column 1's
    # recurrent kernels below (only its input kernels and dense head differ).
    # This reproduces the original code exactly, but it looks like the
    # layer-2 recurrent weights/biases were never pasted in from the export
    # -- confirm against the trained Keras model before relying on column 2.
    x_i2 = inputs[1] * w["kernel_i2"] + w["bias_i"]
    x_f2 = inputs[1] * w["kernel_f2"] + w["bias_f"]
    x_c2 = inputs[1] * w["kernel_c2"] + w["bias_c"]
    x_o2 = inputs[1] * w["kernel_o2"] + w["bias_o"]
    # step 3: recurrent update; lazily (re)initialise state to zeros after a
    # reset (the globals are None then, not ndarrays).
    if not isinstance(h_tm_i, np.ndarray):
        h_tm_i = np.zeros((1, 10))
        h_tm_o = np.zeros((1, 10))
        h_tm_f = np.zeros((1, 10))
        h_tm_c = np.zeros((1, 10))
        c_tm = np.zeros((1, 10))
        h_tm_i2 = np.zeros((1, 10))
        h_tm_o2 = np.zeros((1, 10))
        h_tm_f2 = np.zeros((1, 10))
        h_tm_c2 = np.zeros((1, 10))
        c_tm2 = np.zeros((1, 10))
    # Standard Keras LSTM gate equations with hard_sigmoid recurrent
    # activation and tanh cell activation.
    i = _hard_sigmoid(x_i + np.dot(h_tm_i, w["rec_i"]))
    f = _hard_sigmoid(x_f + np.dot(h_tm_f, w["rec_f"]))
    c = f * c_tm + i * np.tanh(x_c + np.dot(h_tm_c, w["rec_c"]))
    o = _hard_sigmoid(x_o + np.dot(h_tm_o, w["rec_o"]))
    i2 = _hard_sigmoid(x_i2 + np.dot(h_tm_i2, w["rec_i"]))
    f2 = _hard_sigmoid(x_f2 + np.dot(h_tm_f2, w["rec_f"]))
    c2 = f2 * c_tm2 + i2 * np.tanh(x_c2 + np.dot(h_tm_c2, w["rec_c"]))
    o2 = _hard_sigmoid(x_o2 + np.dot(h_tm_o2, w["rec_o"]))
    h = o * np.tanh(c)
    h2 = o2 * np.tanh(c2)
    # One shared h per LSTM cell, mirrored across the four per-gate slots
    # (the original kept four aliases; all receive the same array).
    h_tm_c = h_tm_f = h_tm_o = h_tm_i = h
    c_tm = c
    h_tm_c2 = h_tm_f2 = h_tm_o2 = h_tm_i2 = h2
    c_tm2 = c2
    # Dense heads: (1,10) @ (10,1) + bias -> scalar per column.
    y = np.squeeze(np.dot(h, w["dense_w"]) + w["dense_b"])
    y2 = np.squeeze(np.dot(h2, w["dense_w2"]) + w["dense_b2"])
    return np.array([y, y2])
def create_dataset(dataset, look_back=1, predict_len=5):
    """Build sliding-window samples for multi-step-ahead prediction.

    Parameters
    ----------
    dataset : np.ndarray of shape (N, >=2)
        Time series; only the first two columns are used.
    look_back : int
        Number of consecutive timesteps per input window.
    predict_len : int
        Extra horizon beyond the window: the target is the sample
        ``look_back + predict_len`` steps after the window start. Was a
        hard-coded 5 inside the function; now a parameter with the same
        default, so existing callers are unchanged.

    Returns
    -------
    (dataX, dataY) : np.ndarray, np.ndarray
        ``dataX`` has shape (M, look_back, 2), ``dataY`` shape (M, 2).
        Both are empty when the series is too short for one window.

    NOTE(review): the loop bound keeps the original extra ``- 1`` (one fewer
    window than strictly necessary) -- looks like a leftover from a
    single-step version; preserved to keep output identical.
    """
    dataX, dataY = [], []
    for i in range(len(dataset) - look_back - 1 - predict_len):
        dataX.append(dataset[i:(i + look_back), 0:2])
        dataY.append(dataset[i + look_back + predict_len, 0:2])
    return np.array(dataX), np.array(dataY)
if __name__ == "__main__":
    # Load the two tracked coordinates (columns 0 and 1) from the Excel file.
    dataframe = read_excel(r'C:\Users\gj7520\Desktop\pythob_files\file_select\data4\data_select.xls', usecols=[0, 1], skipfooter=3)
    dataset = dataframe.values
    # Convert the integer array to float32.
    dataset = dataset.astype('float32')
    # fix random seed for reproducibility
    np.random.seed(7)
    # Scale both coordinates into [0, 1] -- presumably matching the scaling
    # used when the LSTM weights were trained; TODO confirm.
    scaler = MinMaxScaler(feature_range=(0, 1))
    dataset = scaler.fit_transform(dataset)
    # split into train and test sets
    #train_size = int(len(dataset) * 0.8)
    #test_size = len(dataset) - train_size
    # Train split disabled: the whole series is used as the test set.
    train_size = 0
    test_size = len(dataset)
    train, test = dataset[0:train_size, :2], dataset[train_size:len(dataset), :2]
    # use this function to prepare the train and test datasets for modeling
    look_back = 5
    trainX, trainY = create_dataset(train, look_back)
    testX, testY = create_dataset(test, look_back)
    arr = []        # final-timestep prediction for each window
    arr_train = []  # unused here (see commented-out training-eval block below)
    arr2 = []       # unused (x-coordinate error collection is commented out)
    arr3 = []       # relative error of the y-coordinate prediction
    # Feed each window one timestep at a time; i == 0 resets the LSTM state
    # inside lstm_keras_verify. The j range stops 10 early so testX[j+5]
    # below stays in bounds.
    for j in range(len(testX) - 10):
        for i in range(0, len(testX[j])):
            y = lstm_keras_verify(testX[j][i], i,j)
            if i == len(testX[j]) - 1:
                # Keep only the prediction after the last timestep of the window.
                arr.append(y)
                # Relative error against the window 5 steps ahead; skip
                # near-zero targets to avoid blowing up the ratio.
                if( testX[j+5][0][1] > 0.02):
                    #arr2.append((y[0] - testX[j + 5][0][0])/testX[j+5][0][0])
                    arr3.append((y[1] - testX[j + 5][0][1])/testX[j+5][0][1])
                    # Dead assignment -- presumably a hook for setting a
                    # debugger breakpoint on outliers; no other effect.
                    if ((y[1] - testX[j + 5][0][1])/testX[j+5][0][1] > 30):
                        a = 1;
    arr = np.array(arr)
    #arr = scaler.inverse_transform(arr)
    #arr2 = np.array(arr2)
    arr3 = np.array(arr3)
    #dataset_inve = scaler.inverse_transform(dataset)
    #testY = scaler.inverse_transform(testY)
    # shift train predictions for plotting
    # #
    # for j in range(len(trainX)):
    #     for i in range(0, len(trainX[j])):
    #         y2 = lstm_keras_verify(trainX[j][i], i,j)
    #         if i == len(trainX[j]) - 1:
    #             arr_train.append(y2)
    # arr_train = np.array(arr_train)
    #
    # arr_train = scaler.inverse_transform(arr_train)
    # trainY = scaler.inverse_transform(trainY)
    # predict_len = 5
    #
    # # shift test predictions for plotting
    # plt.plot(arr_train)
    # plt.plot(trainY)
    # plt.plot(testY)
    # plt.plot(arr)
    # Plot the relative-error curve for the y coordinate.
    plt.plot(arr3)
    # plt.plot(dataset)
    # for i in range(0, len(testX[1])):
    #     y3 = lstm_keras_verify(testX[1][i], i,0)
    #     if(i == len(testX[1]) - 1):
    #         print(y3)
    # print(trainY[1])
    # print(testX[1])
    plt.show()