项目场景:
python编写多个类,这些类之间有相互继承关系时如何实例化
问题描述
来源于强化学习经典问题——多臂老虎机。先编写一个bandit类(代码中名为multiband),把它的实例作为参数传递给另一个类solver(注意不是继承);solver作为父类被EpsilonGreedy继承,EpsilonGreedy又作为父类被decay_epsilon继承。在创建decay_epsilon实例时遇到了困难。
import numpy as np
import matplotlib.pyplot as plt
class multiband(object):
    """Bernoulli multi-armed bandit with ``k`` arms.

    The reward probability of every arm is drawn once, uniformly from
    [0, 1), when the bandit is created.

    Attributes:
        k: Number of arms.
        probablity: Array of length ``k``; entry ``i`` is the chance that
            pulling arm ``i`` pays 1. (The spelling is kept because the
            solver classes read this attribute by name.)
        best_idx: Index of the arm with the highest reward probability.
    """

    def __init__(self, k):
        super(multiband, self).__init__()
        self.k = k
        self.probablity = np.random.uniform(size=k)
        self.best_idx = np.argmax(self.probablity)

    def step(self, l):
        """Pull arm ``l`` once.

        Args:
            l: Index of the arm to pull.

        Returns:
            1 if the pull pays out, otherwise 0.
        """
        # Draw a scalar rather than a size-1 array so the comparison yields
        # a plain boolean instead of relying on array truthiness.
        if np.random.uniform() > self.probablity[l]:
            return 0
        return 1
class solver(object):
    """Base class for bandit strategies; tracks pulls and cumulative regret.

    Subclasses implement ``choose_strategy`` to pick (and pull) one arm.

    Attributes:
        bandit: The bandit being played. It must expose ``k``,
            ``probablity`` and ``best_idx``.
        count: Array of length ``bandit.k`` with the number of times each
            arm has been chosen so far.
        regret: Cumulative regret accumulated so far.
        regrets: Cumulative regret recorded after every step.
        strategy: Arm index chosen at every step.
    """

    def __init__(self, bandit):
        super(solver, self).__init__()
        self.bandit = bandit
        self.count = np.zeros(self.bandit.k)
        self.regrets = []
        self.strategy = []
        self.regret = 0

    def choose_strategy(self):
        """Return the index of the arm to play; must be overridden."""
        raise NotImplementedError

    def update_regret(self, k):
        """Record that arm ``k`` was played and add its regret to the total.

        The regret of one step is the gap between the best arm's reward
        probability and that of the arm actually chosen.
        """
        self.strategy.append(k)
        best_probability = self.bandit.probablity[self.bandit.best_idx]
        self.regret += best_probability - self.bandit.probablity[k]
        self.regrets.append(self.regret)

    def run(self, num_epoch):
        """Play ``num_epoch`` steps, updating regret and pull counts."""
        self.num_epoch = num_epoch
        for _ in range(num_epoch):
            k = self.choose_strategy()
            self.update_regret(k)
            # The count is bumped only after choose_strategy has run, so a
            # subclass sees the number of *previous* pulls of arm k.
            self.count[k] += 1
class EpsilonGreedy(solver):
    """Epsilon-greedy strategy for the multi-armed bandit.

    With probability ``epsilon`` a uniformly random arm is explored;
    otherwise the arm with the highest estimated reward is exploited.

    Args:
        bandit: The bandit to play.
        epsilon: Exploration probability.
        init_pro: Initial reward estimate assigned to every arm.
    """

    def __init__(self, bandit, epsilon=0.01, init_pro=0):
        super(EpsilonGreedy, self).__init__(bandit)
        self.epsilon = epsilon
        # Force a float array: with the integer default init_pro=0 an int
        # array would truncate every incremental update back to 0.
        self.estimate = np.array([init_pro] * self.bandit.k, dtype=float)

    def choose_strategy(self):
        """Pick an arm, pull it, update its estimate and return its index."""
        if np.random.rand() > 1 - self.epsilon:
            # Explore: every arm, including the last one, must be eligible.
            k = np.random.randint(self.bandit.k)
        else:
            # Exploit the solver's own estimates; the bandit's true best arm
            # (bandit.best_idx) is hidden information used only for regret.
            k = int(np.argmax(self.estimate))
        r = self.bandit.step(k)
        # Incremental mean; self.count[k] is the number of earlier pulls.
        self.estimate[k] += 1.0 / (self.count[k] + 1) * (r - self.estimate[k])
        return k
class decay_epsilon(EpsilonGreedy):
    """Epsilon-greedy solver whose exploration rate decays as 1/t.

    Arm selection, pulling and estimate updates are delegated to
    ``EpsilonGreedy.choose_strategy``; this class only adjusts ``epsilon``
    after every step.

    Args:
        EpsilonGreedy: An existing solver *instance* whose ``bandit`` is
            reused. Only the bandit is taken from it, so ``epsilon`` starts
            at the parent's default and is overwritten after the first
            step. The parameter name shadows the parent class inside
            ``__init__``; it is kept for backward compatibility.
    """

    def __init__(self, EpsilonGreedy):
        # The parent constructor requires the bandit; calling it without
        # arguments raised TypeError.
        super(decay_epsilon, self).__init__(EpsilonGreedy.bandit)
        self.total_count = 0

    def choose_strategy(self):
        """Play one step with the current epsilon, then decay epsilon."""
        self.total_count += 1
        k = super(decay_epsilon, self).choose_strategy()
        # 1/t schedule; the previous 1/exp(t) overflowed np.exp after about
        # 709 steps and stopped exploration almost immediately.
        self.epsilon = 1.0 / self.total_count
        return k
def plot_results(solvers, solver_names):
    """Plot cumulative regret over time for several strategies.

    Args:
        solvers: List of solver objects; each must expose ``regrets`` (the
            cumulative regret after every step) and ``bandit.k``.
        solver_names: List of legend labels, one per entry in ``solvers``
            and in the same order.
    """
    # The loop variable is deliberately not called ``solver`` so that it
    # does not shadow the class of that name.
    for idx, strategy in enumerate(solvers):
        time_list = range(len(strategy.regrets))
        plt.plot(time_list, strategy.regrets, label=solver_names[idx])
    plt.xlabel('Time steps')
    plt.ylabel('Cumulative regrets')
    plt.title('%d-armed bandit' % solvers[0].bandit.k)
    plt.legend()
    plt.show()
# Experiment 1: fixed-epsilon greedy solvers.
np.random.seed(0)
# `multi` was used below without ever being defined (NameError). The number
# of arms is an assumed value - adjust as needed.
multi = multiband(10)
epsilon = [0.01, 0.05, 0.1, 0.2, 0.5]
epsilon_greedy_solver = [EpsilonGreedy(multi, e) for e in epsilon]
for greedy in epsilon_greedy_solver:
    greedy.run(5000)
plot_results(epsilon_greedy_solver, [f"epsilon_greddy{e}" for e in epsilon])

# Experiment 2: decaying-epsilon solvers on the same bandit.
np.random.seed(2)
epsilon = [0.01, 0.05, 0.1, 0.2, 0.5]
epsilon_greedy_solver = [EpsilonGreedy(multi, e) for e in epsilon]
# decay_epsilon takes an existing solver instance and reuses its bandit.
decay_epsilons = [decay_epsilon(e) for e in epsilon_greedy_solver]
for decaying in decay_epsilons:
    decaying.run(5000)
plot_results(decay_epsilons, [f"decay_epsilon:{e}" for e in epsilon])
上述有问题的代码在于:
class decay_epsilon(EpsilonGreedy):
"""Faulty version of the decaying-epsilon solver, quoted to show the bug."""
# The parameter below has the same name as the parent class, so inside
# __init__ the name EpsilonGreedy refers to the argument, not the class.
def __init__(self, EpsilonGreedy):
super(decay_epsilon, self).__init__() # inheritance error: EpsilonGreedy.__init__ requires a bandit argument, none is passed
self.total_count=0
def choose_strategy(self):
# Number of calls made so far, including this one.
self.total_count+=1
# Explore when the random draw exceeds 1 - epsilon.
if np.random.rand()>1-self.epsilon:
# arange(0, k-1) stops at k-2, so the last arm index is never drawn here.
k=np.random.choice(np.arange(0,self.bandit.k-1))
else:
# Otherwise use the bandit's own best arm index.
k=self.bandit.best_idx
# Epsilon becomes exp(-total_count) for the next call.
self.epsilon=1.0/np.exp(self.total_count)
r=self.bandit.step(k)
# Incremental mean update of the estimate for arm k.
self.estimate[k]+=1.0/(self.count[k]+1)*(r-self.estimate[k])
return k
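我在decay_epsilon类的构造函数中把EpsilonGreedy实例作为参数,试图以此实现继承,这是不正确的。继承是通过类之间的关系建立的(即class decay_epsilon(EpsilonGreedy)),而不是通过把实例作为参数传递来实现的。真正导致报错的是:super(decay_epsilon, self).__init__()没有传入父类构造函数所需的bandit参数,因此会抛出TypeError。正确的做法是让decay_epsilon直接继承自EpsilonGreedy,并在调用父类构造函数时传入bandit,而不是通过传递实例来实现继承。下面是修正后的代码示例: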
解决方案
class decay_epsilon(EpsilonGreedy):
    """Epsilon-greedy solver whose exploration rate decays as 1/t.

    Arm selection, pulling and estimate updates are delegated to
    ``EpsilonGreedy.choose_strategy``; this class only adjusts ``epsilon``
    after every step.

    Args:
        EpsilonGreedy: An existing solver *instance* whose ``bandit`` is
            reused. The parameter name shadows the parent class inside
            ``__init__``; it is kept for backward compatibility.
    """

    def __init__(self, EpsilonGreedy):
        # The parent constructor needs the bandit, which is read from the
        # solver instance that was passed in.
        super(decay_epsilon, self).__init__(EpsilonGreedy.bandit)
        self.total_count = 0

    def choose_strategy(self):
        """Play one step with the current epsilon, then decay epsilon."""
        self.total_count += 1
        # Reuse the parent's selection/update logic instead of copying it.
        k = super(decay_epsilon, self).choose_strategy()
        self.epsilon = 1.0 / self.total_count
        return k
修改后的代码仍然将EpsilonGreedy对象作为参数。由于EpsilonGreedy正好也是decay_epsilon的父类,使用super时要调用父类的构造函数,而父类的构造函数需要一个bandit参数,它正是传入的EpsilonGreedy对象的一个属性,因此需要使用EpsilonGreedy.bandit。如果传入的直接是bandit对象也可以,此时构造函数就是:
def __init__(self, bandit):
    """Alternative constructor that takes the bandit directly."""
    super(decay_epsilon, self).__init__(bandit)
    # choose_strategy increments this counter, so it must exist up front.
    self.total_count = 0