# Google Dopamine 搜索框架算法 - py 语言 - 立哥开发
# (Google Dopamine framework algorithm, Python; developed by 立哥.)

# Copyright 2020 Jacky Zong. All rights reserved.
#coding=utf-8

"""Tests for dopamine.agents.rainbow.rainbow_agent.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from dopamine.agents.dqn import dqn_agent
from dopamine.agents.rainbow import rainbow_agent
from dopamine.discrete_domains import atari_lib
from dopamine.utils import test_utils
import numpy as np
import tensorflow as tf


class ProjectDistributionTest(tf.test.TestCase):
  """Input-validation tests for `rainbow_agent.project_distribution`.

  Every inconsistency is exercised twice:
    * with `tf.constant` inputs, whose static shapes let the call itself
      fail while the graph is being built (a `ValueError` is expected);
    * with shape-less placeholders and `validate_args=True`, where the
      problem can only surface when the graph is run
      (`tf.errors.InvalidArgumentError` is expected from `sess.run`).
  """

  # Two rows of five support atoms, shared by every test below.
  _SUPPORTS = [[0, 2, 4, 6, 8], [3, 4, 5, 6, 7]]
  # Weights whose shape matches _SUPPORTS.
  _WEIGHTS = [[0.1, 0.2, 0.3, 0.2, 0.2], [0.1, 0.2, 0.3, 0.2, 0.2]]
  # Weights with one column fewer than _SUPPORTS (shape mismatch).
  _SHORT_WEIGHTS = [[0.1, 0.2, 0.3, 0.2], [0.1, 0.2, 0.3, 0.2]]

  def _assert_fails_statically(self, supports, weights, target_support,
                               expected_regexp):
    """Builds the projection from constants and expects a `ValueError`.

    Args:
      supports: Python nested list converted to a float32 constant.
      weights: Python nested list converted to a float32 constant.
      target_support: Python value converted to a float32 constant.
      expected_regexp: pattern the `ValueError` message must match.
    """
    supports = tf.constant(supports, dtype=tf.float32)
    weights = tf.constant(weights, dtype=tf.float32)
    target_support = tf.constant(target_support, dtype=tf.float32)
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which no longer exists in Python 3.12+.
    with self.assertRaisesRegex(ValueError, expected_regexp):
      rainbow_agent.project_distribution(supports, weights, target_support)

  def _assert_fails_at_run_time(self, supports, weights, target_support,
                                expected_regexp=None):
    """Feeds the inputs through placeholders and expects the run to fail.

    Args:
      supports: Python nested list fed to the `supports` placeholder.
      weights: Python nested list fed to the `weights` placeholder.
      target_support: Python value fed to the `target_support` placeholder.
      expected_regexp: optional pattern the `InvalidArgumentError` message
        must match; when None only the exception type is checked.
    """
    # Placeholders declared with shape None carry no static shape, so the
    # inconsistency can only be caught by the runtime checks that
    # validate_args=True requests.
    supports_ph = tf.compat.v1.placeholder(tf.float32, None)
    weights_ph = tf.compat.v1.placeholder(tf.float32, None)
    target_support_ph = tf.compat.v1.placeholder(tf.float32, None)
    projection = rainbow_agent.project_distribution(
        supports_ph, weights_ph, target_support_ph, validate_args=True)
    feed_dict = {
        supports_ph: supports,
        weights_ph: weights,
        target_support_ph: target_support
    }
    # cached_session() supersedes the deprecated test_session().
    with self.cached_session() as sess:
      tf.compat.v1.global_variables_initializer().run()
      if expected_regexp is None:
        with self.assertRaises(tf.errors.InvalidArgumentError):
          sess.run(projection, feed_dict=feed_dict)
      else:
        with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                    expected_regexp):
          sess.run(projection, feed_dict=feed_dict)

  def testInconsistentSupportsAndWeightsParameters(self):
    self._assert_fails_statically(self._SUPPORTS, self._SHORT_WEIGHTS,
                                  [4, 5, 6, 7, 8], 'are incompatible')

  def testInconsistentSupportsAndWeightsWithPlaceholders(self):
    self._assert_fails_at_run_time(self._SUPPORTS, self._SHORT_WEIGHTS,
                                   [4, 5, 6, 7, 8],
                                   expected_regexp='assertion failed')

  def testInconsistentSupportsAndTargetSupportParameters(self):
    self._assert_fails_statically(self._SUPPORTS, self._WEIGHTS, [4, 5, 6],
                                  'are incompatible')

  def testInconsistentSupportsAndTargetSupportWithPlaceholders(self):
    self._assert_fails_at_run_time(self._SUPPORTS, self._WEIGHTS, [4, 5, 6],
                                   expected_regexp='assertion failed')

  def testZeroDimensionalTargetSupport(self):
    self._assert_fails_statically(self._SUPPORTS, self._WEIGHTS, 3,
                                  'Index out of range')

  def testZeroDimensionalTargetSupportWithPlaceholders(self):
    self._assert_fails_at_run_time(self._SUPPORTS, self._WEIGHTS, 3)

  def testMultiDimensionalTargetSupport(self):
    self._assert_fails_statically(self._SUPPORTS, self._WEIGHTS, [[3]],
                                  'out of bounds')

  def testMultiDimensionalTargetSupportWithPlaceholders(self):
    self._assert_fails_at_run_time(self._SUPPORTS, self._WEIGHTS, [[3]])

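# --- Blog page residue below (not part of the source code) ---
# Likes: 0 · Favorites: 0 ("Found it useful? Save it in one click") · Comments: 0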

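# Feedback widget: "Were the related recommendations helpful to you?"
# Options: Very unhelpful / Unhelpful / Neutral / Helpful / Very helpful — Submit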
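# Comment box with "Add a red packet" dialog:
#   "Please enter a greeting or title for the red packet";
#   minimum number of red packets: 10; minimum amount: 5 yuan;
#   current balance 3.43 (go top up >); amount due: 10.00;
#   "Empowering a hundred million technologists!";
#   "After claiming, you automatically follow the blogger and the sender" (rules);
#   red packet sent by hope_wisdom; amount paid; pay with balance;
#   click to refresh; scan to pay; wallet balance 0.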

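# Deduction notes:
#   1. The balance is virtual currency topped up into the wallet and offsets
#      payments at a 1:1 ratio.
#   2. The balance cannot be used to buy downloads directly; it can buy VIP
#      membership, paid columns and courses.
# Top up balance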