# KDD-98 experiment: use a synthetic biased sample
# (whose selection probability is known) to estimate
# the generalization error of a hypothesis.
import numpy as np
import random
import scipy.stats as stats
import matplotlib.pyplot as plt
# Imitate the KDD-98 dataset; assume it has 10000 records.
m = 10000
X, loss = np.empty(m), np.empty(m)
for i in range(m):
    # Feature value: an integer drawn uniformly from {0, ..., 7}.
    X[i] = random.randint(0, 7)
    # Directly create the loss l(h(x), y) for each X[i]:
    # linear in x plus Gaussian noise with sigma = 4.
    loss[i] = 1 + X[i] * 5 + random.gauss(0, 4)
# Mean loss over the full (unbiased) sample = reference generalization error.
true_err = loss.mean()
# print() call (not the py2 print statement) so the script runs on Python 3.
print('the estimated generalization error via unbiased sample :', true_err)
# Generate n biased samples, i.e. n biased testing sets.
n = 1000
uncorrected_loss, corrected_loss = np.empty(n), np.empty(n)
# Known per-record selection probability: records with x in {0,1,2,3} are
# kept with probability 0.3, records with x in {4,5,6,7} with probability
# 0.9 -- so the selected sample over-represents the high-loss records.
selection_prob = np.where(X <= 3, 0.3, 0.9)
for k in range(n):
    # Vectorized Bernoulli selection: one scipy call per test set instead
    # of m per-record rvs() calls (the original double loop issued ~10^7
    # scipy calls); each record is still kept i.i.d. with its own p.
    keep = stats.bernoulli.rvs(selection_prob).astype(bool)
    loss_biased = loss[keep]
    # Naive estimate: plain mean of the loss over the biased sample.
    uncorrected_loss[k] = loss_biased.mean()
    # Importance-weighted (bias-corrected) estimate: reweight each kept
    # loss by P(selected) / p(x), undoing the selection bias in expectation.
    overall_selection_prob = float(keep.sum()) / m
    corrected_loss[k] = (overall_selection_prob * loss_biased
                         / selection_prob[keep]).mean()
# Obtain the bin edges and the corresponding frequency values for both
# estimators.  np.histogram returns the same counts/edges as plt.hist but
# without drawing a throwaway figure, and it avoids plt.hist's `normed`
# kwarg, which was removed in matplotlib 3.1 (passing it raises TypeError).
fre, bins = np.histogram(uncorrected_loss, bins=10)
# float() guard keeps the normalization correct even under integer counts.
frequency = fre / float(fre.sum())
fre1, bins1 = np.histogram(corrected_loss, bins=10)
frequency1 = fre1 / float(fre1.sum())
# ==============================================================================
# Use the scalar true_err and the numpy arrays uncorrected_loss and
# corrected_loss (via their binned frequencies) to draw side-by-side
# histograms, each with a dashed red line marking the unbiased estimate.
fig, ax = plt.subplots(nrows=1, ncols=2)
# Histogram of the uncorrected estimate (generalization error from the
# biased sample).  Modern matplotlib renamed bar()'s `left` kwarg to `x`
# (passing left= raises TypeError); align='edge' keeps each bar anchored
# at its bin's left edge, matching the behaviour of the old left= call.
ax[0].bar(bins[:-1], frequency, width=0.05, align='edge')
# Vertical dashed marker at the true (unbiased) generalization error.
ax[0].plot([true_err] * 5,
           np.linspace(0, frequency.max() + 0.1, 5, endpoint=True), 'r--')
ax[0].set_xlabel('Uncorrected estimate of the generalization error')
ax[0].set_ylabel('Frequency')
# Histogram of the corrected estimate (generalization error from the
# importance-weighted sample).
ax[1].bar(bins1[:-1], frequency1, width=0.05, align='edge')
ax[1].plot([true_err] * 5,
           np.linspace(0, frequency1.max() + 0.1, 5, endpoint=True), 'r--')
ax[1].set_xlabel('Corrected estimate of the generalization error')
ax[1].set_ylabel('Frequency')
# Tweak spacing to prevent clipping of the ylabel.
fig.tight_layout()
# Show the target graph.
plt.show()