dqn系列梳理_莫烦python强化学习系列-DQN学习(代码)

import numpy as np

import pandas as pd

import tensorflow as tf

# Seed NumPy's global random generator and TensorFlow's graph-level random
# seed with fixed values (presumably so repeated runs are reproducible).
np.random.seed(1)

tf.set_random_seed(1)

# Deep Q Network off-policy

class DeepQNetwork:

def __init__(
    self,
    n_actions,
    n_features,
    learning_rate=0.01,
    reward_decay=0.9,
    e_greedy=0.9,
    replace_target_iter=300,
    memory_size=500,
    batch_size=32,
    e_greedy_increment=None,
    output_graph=False,
):
    """Record the hyper-parameters, build the graph and open a session.

    Args:
        n_actions: Stored as ``self.n_actions`` (presumably the number of
            discrete actions -- confirm against ``_build_net``).
        n_features: Stored as ``self.n_features``; also fixes the width of
            each replay-memory row (``n_features * 2 + 2``).
        learning_rate: Stored as ``self.lr``.
        reward_decay: Stored as ``self.gamma``.
        e_greedy: Stored as ``self.epsilon_max``; it is also the starting
            value of ``self.epsilon`` when ``e_greedy_increment`` is None.
        replace_target_iter: Stored as ``self.replace_target_iter``.
        memory_size: Number of rows allocated for ``self.memory``.
        batch_size: Stored as ``self.batch_size``.
        e_greedy_increment: Stored as ``self.epsilon_increment``. When it
            is not None, ``self.epsilon`` starts at 0 instead of
            ``self.epsilon_max``.
        output_graph: When truthy, the session graph is written to the
            ``logs/`` directory through ``tf.summary.FileWriter``.
    """
    self.n_actions = n_actions
    self.n_features = n_features
    self.lr = learning_rate
    self.gamma = reward_decay
    self.epsilon_max = e_greedy
    self.replace_target_iter = replace_target_iter
    self.memory_size = memory_size
    self.batch_size = batch_size
    self.epsilon_increment = e_greedy_increment

    # Without an increment epsilon sits at its maximum from the start;
    # with one it begins at zero.
    if e_greedy_increment is None:
        self.epsilon = self.epsilon_max
    else:
        self.epsilon = 0

    # Counts the learning steps taken so far.
    self.learn_step_counter = 0

    # Zero-filled replay memory, one row per transition [s, a, r, s_].
    row_width = n_features * 2 + 2
    self.memory = np.zeros((self.memory_size, row_width))

    # Builds both networks: [target_net, evaluate_net].
    self._build_net()

    # tf.get_collection(key, scope=None) returns, as a list, every element
    # stored in the collection named `key`.
    target_vars = tf.get_collection('target_net_params')
    eval_vars = tf.get_collection('eval_net_params')

    # One assign op per (target, eval) variable pair.
    copy_ops = []
    for target_var, eval_var in zip(target_vars, eval_vars):
        copy_ops.append(tf.assign(target_var, eval_var))
    self.replace_target_op = copy_ops

    session = tf.Session()
    self.sess = session

    if output_graph:
        # View with: $ tensorboard --logdir=logs
        # tf.train.SummaryWriter is being deprecated; use FileWriter instead.
        tf.summary.FileWriter("logs/", session.graph)

    initializer = tf.global_variables_initializer()
    session.run(initializer)

    self.cost_his = []

def _build_net(self):

# ---------------

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值