# -*- coding: utf-8 -*-
"""Deep-Q-Network agent trained against the project-local ``env.GameState``
environment (2 state features, 2 discrete actions)."""
import random
from collections import deque

import numpy as np

import env
import gym  # noqa: F401 -- kept from original; may matter for env registration
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

EPISODES = 1000  # number of training episodes


class DQNAgent:
    """Minimal DQN agent: epsilon-greedy policy over a small MLP Q-network,
    trained by experience replay from a bounded transition memory."""

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)  # replay buffer; oldest transitions drop off
        # NOTE(review): gamma=0 makes the agent purely myopic (it learns only
        # the immediate reward). The conventional 0.95 was left commented out
        # by the author, presumably as an experiment -- confirm before restoring.
        # self.gamma = 0.95 # discount rate
        self.gamma = 0              # discount rate
        self.epsilon = 0.5          # exploration rate (prob. of a random action)
        self.epsilon_min = 0.01     # floor below which epsilon stops decaying
        self.epsilon_decay = 0.985  # multiplicative decay applied once per replay()
        self.learning_rate = 0.001
        self.model = self._build_model()

    def _build_model(self):
        """Build and compile the Q-network mapping state -> Q-value per action."""
        model = Sequential()
        model.add(Dense(2, input_dim=self.state_size, activation='relu'))
        model.add(Dense(2, activation='relu'))
        model.add(Dense(2, activation='relu'))
        # Output must be linear: Q-values are unbounded regression targets,
        # and the original 'relu' output clamped them to >= 0, silently
        # breaking learning whenever returns are negative.
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer=Adam(lr=self.learning_rate))
        return model

    def remember(self, state, action, reward, next_state, done):
        """Append one (s, a, r, s', done) transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Epsilon-greedy action for ``state`` (expected shape [1, state_size])."""
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        return np.argmax(self.model.predict(state)[0])

    def replay(self, batch_size):
        """Fit the network on a random minibatch from memory, then decay epsilon.

        Debug prints from the original per-sample loop were removed: they ran
        once per sample (plus an extra model.predict call each) and dominated
        training time.
        """
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            # Bellman target; terminal transitions use the raw reward only.
            target = reward
            if not done:
                target = reward + self.gamma * np.amax(
                    self.model.predict(next_state)[0])
            # Train only the chosen action's output: the other outputs keep
            # the network's current predictions, so their MSE error is zero.
            target_f = self.model.predict(state)
            target_f[0][action] = target
            self.model.fit(state, target_f, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        """Load network weights from file ``name``."""
        self.model.load_weights(name)

    def save(self, name):
        """Save network weights to file ``name``."""
        self.model.save_weights(name)


if __name__ == "__main__":
    state_size = 2
    action_size = 2
    # timeRange = 1440
    timeRange = 998  # steps per episode
    agent = DQNAgent(state_size, action_size)
    # agent.load("./save/cartpole-master.h5")
    done = False
    batch_size = 500
    totalScore = 0

    for e in range(EPISODES):
        env_state = env.GameState()
        env_state.reset()
        state = env_state.step(0)[1]
        # Keep the state shaped [1, state_size] so model.predict() accepts it.
        # The original left this commented out, so the very first act() call
        # received an un-reshaped state, inconsistent with every later step.
        state = np.reshape(state, [1, state_size])
        totalreward = 0
        for time in range(timeRange):
            # Mark the final step terminal. The original had this commented
            # out, leaving `done` permanently False: the episode summary,
            # totalScore accumulation, and break below were dead code, and no
            # transition was ever remembered as terminal.
            if time == timeRange - 1:
                done = True
            action = agent.act(state)
            reward, next_state = env_state.step(action)
            totalreward += reward
            next_state = np.reshape(next_state, [1, state_size])
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                print("episode: {}/{}, score: {}, e: {:.2}".format(
                    e, EPISODES,
                    env_state.money + env_state.shares * env_state.sharePrice,
                    agent.epsilon))
                totalScore += env_state.money + env_state.shares * env_state.sharePrice
                done = False
                break
            if len(agent.memory) > batch_size:
                print("replaying")
                print("Total reward: " + str(totalreward))
                totalreward = 0
                # print(agent.memory)
                agent.replay(batch_size)
    # if e % 10 == 0:
    #     agent.save("./save/cartpole.h5")