import numpy as np
from scipy import stats
import random
import matplotlib.pyplot as plt
# Number of arms (slot machines) in the bandit problem.
n = 10
# Per-arm success probability, drawn uniformly from [0, 1).  These values are
# handed to reward() and compared against random.random(), so they must be
# valid probabilities; a standard-normal draw (randn) would produce negative
# and greater-than-one values.
arms = np.random.rand(n)
# Exploration rate: a play is greedy when random.random() > eps, otherwise a
# random arm is tried.
eps = 0.1
def reward(prob, trials=10):
    """Return the number of successes in ``trials`` independent draws.

    Each draw succeeds when ``random.random()`` falls strictly below
    ``prob``, so the result is an integer between 0 and ``trials``.

    Args:
        prob: Success probability of a single draw.  Values <= 0 never
            succeed and values >= 1 always succeed, because
            ``random.random()`` lies in [0, 1).
        trials: Number of draws to make.  Defaults to 10, the count that was
            previously hard-coded.

    Returns:
        The count of successful draws as a plain ``int``.
    """
    # Count with ``1 for ... if`` rather than summing the comparison results,
    # so the return value stays a built-in int even when ``prob`` is a NumPy
    # scalar (whose comparisons yield NumPy booleans).
    return sum(1 for _ in range(trials) if random.random() < prob)
# Action-value memory: one row per play, columns are [arm index, reward].
# It is seeded with a single randomly chosen arm and a reward of 0.
# randint's upper bound is exclusive, so it must be n (not n + 1) for the
# seeded index to be a valid position in ``arms``.
av = np.array([np.random.randint(0, n), 0]).reshape(1, 2)
def bestArm(a):
    """Return the arm with the highest mean recorded reward.

    Args:
        a: 2-D array of play history with one row per play; column 0 holds
            the arm identifier and column 1 the reward obtained.

    Returns:
        The identifier (taken from column 0) of the arm whose mean reward is
        largest.  Ties go to the arm that appears first in ``a``.  If no arm
        has a mean reward above 0, or ``a`` has no rows, 0 is returned.
    """
    best_arm = 0
    best_mean = 0
    evaluated = set()
    for row in a:
        arm = row[0]
        # Every row of the same arm yields the same mean, so evaluate each
        # arm only once (on its first appearance).
        if arm in evaluated:
            continue
        evaluated.add(arm)
        # Select the rows for this arm by matching on the arm column
        # (a[:, 0]); matching on a[0, :] would compare against the first
        # row's [arm, reward] pair instead.
        avg = np.mean(a[a[:, 0] == arm][:, 1])
        if best_mean < avg:
            best_mean = avg
            best_arm = arm
    return best_arm
if __name__ == "__main__":
    # Run 500 epsilon-greedy plays and scatter the running mean reward
    # against the play number.
    plt.xlabel("Plays")
    plt.ylabel("Avg Reward")
    # The arm with the highest payout probability does not change during the
    # run, so look it up once instead of on every play.
    optimalArm = np.argmax(arms)
    for i in range(500):
        if random.random() > eps:
            # Exploit: play the arm with the best mean reward so far.
            choice = bestArm(av)
        else:
            # Explore: pick an arm index uniformly at random.
            choice = np.random.randint(len(arms))
        # Record the play as an [arm index, reward] row in the memory array.
        thisAv = np.array([[choice, reward(arms[choice])]])
        av = np.concatenate((av, thisAv), axis=0)
        # Fraction of recorded plays that used the optimal arm.  The arm
        # index lives in column 0 of av, so compare av[:, 0] (not av[0, :],
        # which is the first row).  Computed for inspection; not plotted.
        pCorrect = np.mean(av[:, 0] == optimalArm)
        runningMean = np.mean(av[:, 1])
        plt.scatter(i, runningMean)