关闭

A Review on Multi-Label Learning Algorithms - KNN

标签: 机器学习
263人阅读 评论(0) 收藏 举报
分类:

在多标签分类中,有一种方法就是按照KNN的类似算法去求出每一个维度的结果。也是看周志华老师的review突然就想实现以下,然后实现了一个相当简单的。

首先我们需要进行计算的是在近邻数目为k的情况下的贝叶斯分布的可能。

也就是,首先对于每一个样本求其近邻,然后按照近邻在这一维度上的分类进行朴素贝叶斯的统计,遇到一个新样本,首先按照最近邻来计算近邻的集合,然后在每一个维度上根据其朴素贝叶斯的统计来进行计算。

下面是一个玩具版的代码实现,首先根据样本来计算贝叶斯分布的结果,同时计算近邻的集合。
有了一个实例样本之后,可以计算近邻集合并且根据近邻集合计算样本结果。

import numpy as np

def NB(X,Y,k,NN):
    NBdis = [];
    for i in range(0,Y.shape[1]):
        NBdis.append( (np.sum(Y[:,i])+1) /float(Y.shape[0]+2));

    NBtable = [];
    for i in range(0,Y.shape[1]):
        dis = np.zeros((k+1,2));
        for j in range(0,X.shape[0]):
            neighbours = NN[j];
            tmpX = np.sum(Y[neighbours,i]);
            if Y[j,i] == 0:
                dis[tmpX,1] += 1;
            else:
                dis[tmpX,0] += 1;
        smooth = 1;
        dis = dis+1 / np.sum(dis+1,axis = 1,keepdims = True);
        NBtable.append(dis);

    return (NBdis,NBtable);

def findKNN(X,k):
    NN = [];
    for x in X:
        tmpX = X.copy();
        tmpX -= x;
        tmpX = tmpX * tmpX;
        distance = np.sum(tmpX,axis = 1);
        NN.append(np.argsort(distance)[1:k+1]);
    return(NN);

def predictFindNN(X,x,k):
    tmpX = X.copy();
    tmpX -= x;
    tmpX = tmpX * tmpX;
    distance = np.sum(tmpX,axis = 1);
    return(np.argsort(distance)[0:k]);

def predictLabel(nn,NBdis,NBtable,Y):
    tmpY = Y[nn];
    tmpY = np.sum(tmpY,axis = 0);
    labels = np.zeros((1,Y.shape[1]));
    for i in range(labels.shape[1]):
        if NBdis[i]*NBtable[i][tmpY[i],0] > (1-NBdis[i])*NBtable[i][tmpY[i],1]:
            labels[0][i] = 1;
        else:
            labels[0][i] = 0;
    return labels;


X = np.array([[1,0,1,1,0],[0,1,1,1,0],[1,0,1,0,1]])
Y = np.array([[1,0,1,1],[1,0,1,0],[1,0,0,0]]);
k = 2;
NN = findKNN(X,k);
print(NN);
print('\n');
(NBdis,NBtable) = NB(X,Y,k,NN);


print('\nNaive Bayes probs');
print(NBdis);


print('\nNaive Bayes table:');
for table in NBtable:
    print(table);
    print('\n');
x = ([0,1,1,1,0]);
nn = predictFindNN(X,x,k);


print('\nNearest neighbours');
print(nn);
label = predictLabel(nn,NBdis,NBtable,Y);


print('\nLabel predict');
print(label);
0
0

查看评论
* 以上用户言论只代表其个人观点,不代表CSDN网站的观点或立场
    个人资料
    • 访问:71718次
    • 积分:2366
    • 等级:
    • 排名:第15848名
    • 原创:230篇
    • 转载:13篇
    • 译文:0篇
    • 评论:6条
    文章分类
    最新评论