matlab:
step1: function definition
F:\code_computer vision models learning and inference\cvm_recode_matlab
step2: example
F:\code_computer vision models learning and inference\cvm_recode_matlab
map_norm_ex1.m
python:
step1: function definition
C:\Users\Administrator\Algorithm.py
ps: 还没有加入validate_input42() (validate_input42() has not been added yet)
step2: example
C:\Users\Administrator\Algorithm.py\cvm_example.ipynb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import Algorithm as mine
import scipy.stats as st #也可以使用scipy库中的相关api
%matplotlib inline
# Author: hiddry, with reference to the original by Stefan Stavrev (2013).
#
# Demo: draw samples from a known univariate normal distribution, estimate
# its mean and variance with MAP (mine.A422) and with ML (mine.Mln), print
# the estimation errors, and plot the true and fitted densities together.
# NOTE(review): A422 / Mln are presumably Algorithms 4.2 (MAP) and 4.1 (ML)
# from "Computer Vision: Models, Learning, and Inference"; confirm against
# Algorithm.py.

# Generate random values from the normal distribution with
# mean value original_mu and standard deviation original_sig.
original_mu = 5
original_sig = 8
# I can be modified in order to see how MAP behaves for small
# vs. large amounts of data.
I = 100
# Set seed to an integer for a reproducible sample; None (the default)
# draws fresh random values on every run.
seed = None
r = st.norm(original_mu, original_sig).rvs(I, random_state=seed)
print("type of r:", type(r))
print("r.shape:", r.shape)
print("r.ndim:", r.ndim)
# min()/max() must be called: without parentheses print() shows the bound
# method object rather than the sample's smallest/largest value.
print("r.min()", r.min())
print("r.max()", r.max())

# Estimate the mean and variance for the data in r.
# Values used for alpha, beta, gamma and delta are (1, 1, 1, 0)
# for the sake of the example. Other values can be tried too.
estimated_mu, estimated_var = mine.A422(r, 1, 1, 1, 0)
print("Estimated mean:", estimated_mu)
print("Estimated_var:", estimated_var)
estimated_sig = np.sqrt(estimated_var)

# Use algorithm 4.1 (maximum likelihood) to compare.
mle_estimated_mu, mle_estimated_var = mine.Mln(r)
mle_estimated_sig = np.sqrt(mle_estimated_var)

# Estimate and print the error for the mean and the standard deviation (MAP).
muError = np.abs(original_mu - estimated_mu)
sigError = np.abs(original_sig - estimated_sig)
print("muError:", muError)
print("sigError", sigError)
# Same errors for the ML estimate, so the two methods can be compared.
mle_muError = np.abs(original_mu - mle_estimated_mu)
mle_sigError = np.abs(original_sig - mle_estimated_sig)
print("mle_muError:", mle_muError)
print("mle_sigError:", mle_sigError)

# Plot the original and the estimated models for comparison over the
# observed data range.
x = np.linspace(r.min(), r.max())
plt.plot(x, st.norm(original_mu, original_sig).pdf(x), c='g', label='original')
plt.plot(x, st.norm(estimated_mu, estimated_sig).pdf(x), c='b', label='MAP')
plt.plot(x, st.norm(mle_estimated_mu, mle_estimated_sig).pdf(x), c='r', label='MLE')
plt.xlabel('x', fontsize=16)
plt.ylabel('', fontsize=16)
plt.legend(loc='upper right')
plt.xticks([])
plt.yticks([])
plt.title('(d)', fontsize=14)
# Render the figure explicitly. In the notebook this also suppresses the
# echoed Text repr of the last call.
plt.show()