python类多层继承的关系

项目场景:

用 Python 编写多个类、且这些类之间存在多层继承关系时，如何正确地实例化子类。


问题描述

问题来源于强化学习的经典问题——多臂老虎机。先编写一个 bandit 类（代码中为 multiband），把它的实例作为参数传给另一个类 solver（注意这是组合，不是继承）；solver 又作为父类被 EpsilonGreedy 继承，EpsilonGreedy 再作为父类被 decay_epsilon 继承。结果在创建 decay_epsilon 实例时出现了错误。

import numpy as np
import matplotlib.pyplot as plt

class multiband:
    """K-armed Bernoulli bandit.

    Each arm ``i`` has a hidden success probability ``probablity[i]`` drawn
    uniformly from [0, 1) when the bandit is built; pulling that arm yields
    reward 1 with that probability and 0 otherwise.
    """

    def __init__(self, k):
        super().__init__()
        self.k = k  # number of arms
        # Attribute name kept as-is: the solvers read ``bandit.probablity``.
        self.probablity = np.random.uniform(size=k)
        # Index of the arm with the highest success probability.
        self.best_idx = np.argmax(self.probablity)

    def step(self, l):
        """Pull arm ``l`` once and return the sampled reward (0 or 1)."""
        u = np.random.uniform()
        return 0 if u > self.probablity[l] else 1


class solver:
    """Base class for strategies that play a multi-armed bandit.

    The bandit is held by composition (``self.bandit``), not inherited.
    Subclasses implement ``choose_strategy``; this class records the chosen
    arms, how often each arm was pulled, and the cumulative regret.
    """

    def __init__(self, bandit):
        super().__init__()
        self.bandit = bandit
        self.count = np.zeros(self.bandit.k)  # number of pulls per arm
        self.regrets = []  # cumulative regret after each step
        self.strategy = []  # arm chosen at each step
        self.regret = 0  # running cumulative regret

    def choose_strategy(self):
        """Select and pull an arm, returning its index; subclasses must override."""
        raise NotImplementedError

    def update_regret(self, k):
        """Record arm ``k`` and add its regret relative to the bandit's best arm."""
        self.strategy.append(k)
        probs = self.bandit.probablity
        self.regret += probs[self.bandit.best_idx] - probs[k]
        self.regrets.append(self.regret)

    def run(self, num_epoch):
        """Play ``num_epoch`` steps, updating regret and pull counts after each."""
        self.num_epoch = num_epoch
        for _ in range(num_epoch):
            arm = self.choose_strategy()
            self.update_regret(arm)
            self.count[arm] += 1

class EpsilonGreedy(solver):
    """Epsilon-greedy strategy for the multi-armed bandit.

    With probability ``epsilon`` a uniformly random arm is explored; otherwise
    the arm with the highest *estimated* reward so far is exploited.

    Args:
        bandit: the bandit to play; must expose ``k`` and ``step(arm)``.
        epsilon: exploration probability.
        init_pro: initial reward estimate for every arm.
    """

    def __init__(self, bandit, epsilon=0.01, init_pro=0):
        super().__init__(bandit)
        self.epsilon = epsilon
        # Force a float array: with the default integer ``init_pro`` an int
        # array would silently truncate every incremental-mean update below.
        self.estimate = np.full(self.bandit.k, init_pro, dtype=float)

    def choose_strategy(self):
        """Pick an arm, pull it once, update its estimate and return its index."""
        if np.random.rand() < self.epsilon:
            # Explore: every arm 0..k-1 is eligible (randint's upper bound is exclusive).
            k = np.random.randint(0, self.bandit.k)
        else:
            # Exploit the learned estimates, not the bandit's hidden best arm.
            k = int(np.argmax(self.estimate))
        r = self.bandit.step(k)
        # Incremental mean; solver.run increments count[k] only after this returns,
        # so count[k] + 1 is the number of pulls including this one.
        self.estimate[k] += 1.0 / (self.count[k] + 1) * (r - self.estimate[k])
        return k

class decay_epsilon(EpsilonGreedy):
    """Epsilon-greedy whose exploration rate decays as ``1 / t``.

    ``t`` counts calls to ``choose_strategy``, so the first pull always
    explores and exploration becomes rarer as the estimates improve.

    Args:
        EpsilonGreedy: an existing solver whose ``bandit`` is reused, or a
            bandit object itself (anything exposing ``k`` and ``step``). The
            name shadows the parent class inside ``__init__``; it is kept only
            so existing callers keep working.
    """

    def __init__(self, EpsilonGreedy):
        # The parent constructor requires the bandit; calling it with no
        # arguments (as before) raises TypeError.
        bandit = getattr(EpsilonGreedy, "bandit", EpsilonGreedy)
        super().__init__(bandit)
        # Ensure float estimates so incremental-mean updates are not truncated.
        self.estimate = np.asarray(self.estimate, dtype=float)
        self.total_count = 0

    def choose_strategy(self):
        """Decay epsilon, pick and pull an arm, update its estimate, return its index."""
        self.total_count += 1
        # 1/t decay. The previous 1/exp(t) was ~0 after a few steps, and
        # np.exp overflows (RuntimeWarning) once t exceeds ~709.
        self.epsilon = 1.0 / self.total_count
        if np.random.rand() < self.epsilon:
            # Explore: every arm 0..k-1 is eligible (randint's upper bound is exclusive).
            k = np.random.randint(0, self.bandit.k)
        else:
            # Exploit the learned estimates, not the bandit's hidden best arm.
            k = int(np.argmax(self.estimate))
        r = self.bandit.step(k)
        # Incremental mean; solver.run increments count[k] after this returns.
        self.estimate[k] += 1.0 / (self.count[k] + 1) * (r - self.estimate[k])
        return k


def plot_results(solvers, solver_names):
    """Plot each strategy's cumulative regret against the time step.

    Args:
        solvers: list of solver instances that have already been run; each
            exposes ``regrets`` and ``bandit.k``.
        solver_names: legend labels, indexed in the same order as ``solvers``.
    """
    for position, strategy in enumerate(solvers):
        history = strategy.regrets
        plt.plot(range(len(history)), history, label=solver_names[position])
    plt.xlabel('Time steps')
    plt.ylabel('Cumulative regrets')
    # The arm count is taken from the first solver's bandit.
    plt.title('%d-armed bandit' % solvers[0].bandit.k)
    plt.legend()
    plt.show()
        

# Experiment 1: fixed-epsilon greedy solvers on one shared 10-armed bandit.
np.random.seed(0)
# ``multi`` was used below without ever being created (NameError).
multi = multiband(10)
epsilon = [0.01, 0.05, 0.1, 0.2, 0.5]
epsilon_greedy_solver = [EpsilonGreedy(multi, e) for e in epsilon]
for s in epsilon_greedy_solver:
    s.run(5000)
plot_results(epsilon_greedy_solver, [f"epsilon_greddy{e}" for e in epsilon])

# Experiment 2: decaying-epsilon solvers, each constructed from a fresh
# EpsilonGreedy solver on the same bandit, with a different seed.
np.random.seed(2)
epsilon = [0.01, 0.05, 0.1, 0.2, 0.5]
epsilon_greedy_solver = [EpsilonGreedy(multi, e) for e in epsilon]
decay_epsilons = [decay_epsilon(e) for e in epsilon_greedy_solver]
for s in decay_epsilons:
    s.run(5000)

plot_results(decay_epsilons, [f"decay_epsilon:{e}" for e in epsilon])

上述代码中出问题的部分是：

class decay_epsilon(EpsilonGreedy):
    """Epsilon-greedy that shrinks epsilon after every pull (failing version, kept as shown in the article)."""
    def __init__(self, EpsilonGreedy):
        # NOTE: the parameter name shadows the parent class EpsilonGreedy inside this method.
        super(decay_epsilon, self).__init__() # BUG: EpsilonGreedy.__init__ requires a `bandit` argument, so this raises TypeError
        self.total_count=0
    
    def choose_strategy(self):
        """Pick an arm, pull it, update its estimate, decay epsilon, and return the arm index."""
        self.total_count+=1
        # Explore with probability epsilon.
        # NOTE(review): np.arange(0, k-1) never includes the last arm k-1.
        if np.random.rand()>1-self.epsilon:
            k=np.random.choice(np.arange(0,self.bandit.k-1))
        else:
            # NOTE(review): exploits the bandit's true best arm rather than argmax(self.estimate).
            k=self.bandit.best_idx

        # epsilon = e^(-t): nearly zero after a few steps; np.exp overflows to inf
        # (with a RuntimeWarning) once t exceeds ~709.
        self.epsilon=1.0/np.exp(self.total_count)

        r=self.bandit.step(k)
        # Incremental mean update of arm k's estimated reward.
        self.estimate[k]+=1.0/(self.count[k]+1)*(r-self.estimate[k])
        return k


原因分析:

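错误原因：decay_epsilon 已经在类定义 `class decay_epsilon(EpsilonGreedy)` 处继承了 EpsilonGreedy，继承关系本身没有问题。构造函数的形参名 EpsilonGreedy 只是一个普通参数，并不会建立任何继承关系，而且它会在 __init__ 内部遮蔽同名的父类。真正的问题在于：`super(decay_epsilon, self).__init__()` 调用父类构造函数时没有传入任何参数，而 `EpsilonGreedy.__init__(self, bandit, epsilon=0.01, init_pro=0)` 要求必须提供 bandit，因此实例化时会抛出 TypeError（缺少必需的位置参数 bandit）。正确的做法是把 bandit 传给父类构造函数。既然传入的是一个 EpsilonGreedy 实例，就使用它的 bandit 属性。下面是修正后的代码示例：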


解决方案

class decay_epsilon(EpsilonGreedy):
    """Epsilon-greedy whose exploration rate decays as ``1 / t``.

    ``t`` counts calls to ``choose_strategy``, so the first pull always
    explores and exploration becomes rarer as the estimates improve.

    Args:
        EpsilonGreedy: an existing solver whose ``bandit`` is reused, or a
            bandit object itself (anything exposing ``k`` and ``step``). The
            name shadows the parent class inside ``__init__``; it is kept only
            so existing callers keep working.
    """

    def __init__(self, EpsilonGreedy):
        # The parent constructor needs the bandit: take it from the passed-in
        # solver, or use the argument directly if it is already a bandit.
        bandit = getattr(EpsilonGreedy, "bandit", EpsilonGreedy)
        super().__init__(bandit)
        # Ensure float estimates so incremental-mean updates are not truncated.
        self.estimate = np.asarray(self.estimate, dtype=float)
        self.total_count = 0

    def choose_strategy(self):
        """Decay epsilon, pick and pull an arm, update its estimate, return its index."""
        self.total_count += 1
        # Decay before choosing, so step t explores with probability 1/t.
        self.epsilon = 1.0 / self.total_count
        if np.random.rand() < self.epsilon:
            # Explore: every arm 0..k-1 is eligible (randint's upper bound is exclusive).
            k = np.random.randint(0, self.bandit.k)
        else:
            # Exploit the learned estimates, not the bandit's hidden best arm.
            k = int(np.argmax(self.estimate))
        r = self.bandit.step(k)
        # Incremental mean; solver.run increments count[k] after this returns.
        self.estimate[k] += 1.0 / (self.count[k] + 1) * (r - self.estimate[k])
        return k

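修改后的代码仍然把一个 EpsilonGreedy 对象作为参数传入。由于 EpsilonGreedy 正是 decay_epsilon 的父类，super().__init__ 调用的是父类的构造函数，而父类构造函数需要一个 bandit 参数。这个参数正好就是传入的 EpsilonGreedy 对象所持有的 bandit，因此这里写成 EpsilonGreedy.bandit。注意：这里的 EpsilonGreedy 是形参名，指传入的实例，而不是父类本身。如果直接传入 bandit 对象也可以，此时构造函数写成如下形式（其余部分，如 self.total_count=0，保持不变）：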

    def __init__(self, bandit):
        """Build the decaying-epsilon solver directly from a bandit object."""
        super(decay_epsilon, self).__init__(bandit)
        # choose_strategy increments this counter, so it must be initialized
        # here too; without it the first pull raises AttributeError.
        self.total_count = 0

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值