在目前的学习中 DQN是比较牛的
专家也说 强化学习接近通用人工智能
Q_learning 列出了可能的action 和state对应的Q数值,然后在某个状态得到后,只需按照该状态下的最大Q值一直动作下去,就能得到最大的收益
但是如果q_lable不足以写出那么多状态和动作
也就是没有一个表格可以存储足够的Q值
那么就可以和深度学习结合在一起了,所谓的深度学习,就是根据输入的状态(在愤怒的小鸟中就是输入图像)然后经过卷积全连接得到了动作对应的Q数值,那么网络可以得到Q数值
因此可以形成( st at reword st+1 是否结束)等数组就可以通过深度学习得到了
经过训练,一个神经网络负责更新真实Q数值 r+a(Q(s* a*))
另外一哥神经网络负责得到逼近的Q值 Q(s a)
那么两个的差距就是LOSS目标
然后梯度清零 单项迭代可以使得loss数值不断的降低 部分的结果如下所示
然后保存训练的参数
进而当测试的时候 可以调用其中的参数 进行测试 得到动作数值 执行动作 进而不断的得到reword的参数数值
图 首次的愤怒小鸟图 和标准化的结果图
部分代码注释
game_state = FlappyBird()##随机生成小鸟的图像
plt.figure()
image, reward, terminal = game_state.next_frame(0)##得到了下一帧的图像
plt.subplot(121)
plt.imshow(image)
image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size, opt.image_size)##进行预处理
image = torch.from_numpy(image)
data=image
#plt.imshow(data.numpy()[:, :, 0].squeeze(), cmap=plt.cm.gray_r)
plt.subplot(122)
plt.imshow(image.T)
plt.title('原始的')
plt.show()
if torch.cuda.is_available():
model.cuda()
image = image.cuda()
state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]
#print('statesize={}'.format(state.shape)) 1 4 84 84
replay_memory = []
iter = 0
while iter < opt.num_iters:
prediction = model(state)[0]##获得得到的第一个数值 也就是Q表格当中的q数值 也就是在state采取action的q数值
# Exploration or exploitation
epsilon = opt.final_epsilon + (
(opt.num_iters - iter) * (opt.initial_epsilon - opt.final_epsilon) / opt.num_iters)
u = random()
random_action = u <= epsilon
if random_action:
print("Perform a random action")
action = randint(0, 1)
else:
action = torch.argmax(prediction).item()
#print('randomeacton={},//epsilon={},//prediction={},//action={}'.format(random_action,epsilon,prediction,action ))
next_image, reward, terminal = game_state.next_frame(action)
next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size,
opt.image_size)
next_image = torch.from_numpy(next_image)
if torch.cuda.is_available():
next_image = next_image.cuda()
next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]
replay_memory.append([state, action, reward, next_state, terminal])
if len(replay_memory) > opt.replay_memory_size:## default=50000 如果超过这个长度的化 那么就不接下来继续存储下去了
del replay_memory[0]
batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))##s随机选择batchsize大小的快,如果剩下的快不足够batchsize
#的大小 那么就剩下的样本都取出来