import numpy as np
from collections import Counter
# Training set: each row is [feature 1, feature 2, class label].
data = [
    [1, 'S', -1], [1, 'M', -1], [1, 'M', 1], [1, 'S', 1], [1, 'S', -1],
    [2, 'S', -1], [2, 'M', -1], [2, 'M', 1], [2, 'L', 1], [2, 'L', 1],
    [3, 'L', 1], [3, 'L', 1], [3, 'M', 1], [3, 'M', 1], [3, 'L', -1],
]

# The rows mix ints and strings, so NumPy stores every cell as a string.
# All comparisons below are therefore string comparisons.
data_new = np.array(data)

# Every column except the last is a feature; the last column is the label.
n_features = data_new.shape[1] - 1

# Per feature: the column itself, value frequencies, number of distinct
# values, and the distinct values in order of first appearance.
x = [data_new[:, col] for col in range(n_features)]
count_x = [Counter(column) for column in x]
x_num = [len(counts) for counts in count_x]
x_elem = [list(counts) for counts in count_x]

# Same bookkeeping for the class labels.
y = data_new[:, -1]
count_y = Counter(y)
y_num = len(count_y)
y_elem = list(count_y)

# Prior probability of each class: class frequency / number of samples.
p_y = [count_y[label] / len(y) for label in y_elem]

# One boolean row mask per class, built once instead of inside the loops.
class_masks = [y == label for label in y_elem]

# Conditional probabilities, indexed as p_x_y[feature][value][class]:
# (rows having this value AND this class) / (rows having this class).
p_x_y = []
for column, values in zip(x, x_elem):
    per_value = []
    for value in values:
        value_mask = column == value
        per_value.append([
            int(np.count_nonzero(value_mask & class_mask)) / count_y[label]
            for label, class_mask in zip(y_elem, class_masks)
        ])
    p_x_y.append(per_value)

# Sample to classify: one value per feature, in column order.
data_pred = [2, 'S']
def predict(data_pred, p_x_y, p_y, x_elem, y_elem):
    """Return the index of the most probable class for one sample.

    Each class ``c`` is scored with the naive Bayes product
    ``p_y[c] * prod_j p_x_y[j][v_j][c]``, where ``v_j`` is the position of the
    sample's j-th feature value inside ``x_elem[j]``. The index of the highest
    score is returned.

    Args:
        data_pred: Feature values of the sample, one per feature. Values are
            matched by their ``str()`` form.
        p_x_y: Conditional probabilities indexed as
            ``p_x_y[feature][value][class]``.
        p_y: Prior probability of each class.
        x_elem: For each feature, its known values in the same order as the
            second axis of ``p_x_y``. The per-feature lists may differ in
            length.
        y_elem: Class labels. Not used by the computation; kept so existing
            calls keep working.

    Returns:
        Index into ``p_y`` of the class with the largest score. Ties,
        including the case where every score is zero, go to the lowest index,
        so the result is always a valid index.

    Raises:
        IndexError: If ``data_pred`` or ``x_elem`` has fewer entries than
            ``p_x_y``, or a value of ``data_pred`` is not listed in
            ``x_elem`` for its feature.
        ValueError: If ``p_y`` is empty.
    """
    if len(p_y) == 0:
        raise ValueError("p_y must contain at least one class prior")

    # For every feature, pick the row of per-class probabilities that belongs
    # to the sample's value. Each feature is looked up on its own so features
    # with different numbers of distinct values are supported.
    rows = []
    for j in range(len(p_x_y)):
        wanted = str(data_pred[j])
        known = [str(value) for value in x_elem[j]]
        try:
            position = known.index(wanted)
        except ValueError:
            raise IndexError(
                f"feature {j}: value {wanted!r} is not listed in x_elem[{j}]"
            ) from None
        rows.append(p_x_y[j][position])

    # Unnormalised posterior of each class: prior times the conditionals.
    scores = []
    for c, prior in enumerate(p_y):
        score = prior
        for row in rows:
            score *= row[c]
        scores.append(score)

    # max() keeps the first maximum, so ties resolve to the lowest index.
    return max(range(len(scores)), key=scores.__getitem__)
# Classify the sample in data_pred with the tables computed above and print
# the label stored at the returned class index.
index = predict(
    data_pred=data_pred,
    p_x_y=p_x_y,
    p_y=p_y,
    x_elem=x_elem,
    y_elem=y_elem,
)
print(y_elem[index])
# "Statistical Learning Methods" (《统计学习方法》), Chapter 4 — naive Bayes, Python implementation
# Latest recommended article published on 2024-11-15 21:18:47