思维导图
- 离散贝叶斯原理示例
- 实例展示
原理示例
- 假设训练集里有两类,一类叫做苹果,一类叫做香蕉
- 训练集中有两个特征,一个是颜色(取值为黄色和红色),一个是长度(取值为长和短)
- 现在给定测试集数据特征 (黄色,长),判断是什么水果
分析与解答
- 先计算出先验概率 $P(y)$:
- $P(y=\text{苹果})=\frac{\text{训练集中的苹果数}}{\text{训练集中总数}}$,$P(y=\text{香蕉})=\frac{\text{训练集中的香蕉数}}{\text{训练集中总数}}$
- 根据离散的朴素贝叶斯公式,后验概率
$$P(y|\vec{x})=\frac{P(\vec{x}|y)P(y)}{P(\vec{x})}$$
- 由于假设条件中设定不同的特征之间相互独立,所以
$$P(\vec{x}|y)=P(x_{1}=v|y)\,P(x_{2}=u|y)$$
其中,$x_{1}$ 为颜色特征,$x_{2}$ 为长度特征,$v\in\{\text{黄色},\text{红色}\}$,$u\in\{\text{长},\text{短}\}$。
- 故而,拿到特征 (黄色,长)之后,需要先计算以下两个式子
$$P(y=\text{苹果}|\vec{x})=\frac{P(x_1=\text{黄色}|y=\text{苹果})\,P(x_2=\text{长}|y=\text{苹果})\,P(y=\text{苹果})}{P(\vec{x})}$$
$$P(y=\text{香蕉}|\vec{x})=\frac{P(x_1=\text{黄色}|y=\text{香蕉})\,P(x_2=\text{长}|y=\text{香蕉})\,P(y=\text{香蕉})}{P(\vec{x})}$$
- 由于 $\arg\max P(y|\vec{x})=\arg\max P(\vec{x}|y)P(y)$,故而上边两个式子中哪个计算出的条件概率大,就取哪一类。
代码展示
/**
 * Transpose a rectangular 2-D array (rows become columns).
 * @function transposeIris
 * @param {Array<Array<*>>} arr - matrix with at least one row
 * @returns {Array<Array<*>>} the transposed matrix
 */
let transposeIris = (arr) =>
  arr[0].map((_, colIdx) => arr.map((row) => row[colIdx]));
/**
 * Estimate the conditional probability P(x = ele | y) from one feature
 * column of a single class's training samples, as relative frequency.
 * If the value never occurs in the column, Laplace smoothing is applied
 * to avoid a zero probability.
 * @function countFeatures
 * @param {Array<number|string>} irisFeatures - feature column (one value per training sample of the class)
 * @param {number|string} ele - feature value to look up
 * @param {number} [sampleCount=50] - number of training samples in the class (was hard-coded to 50)
 * @returns {number} estimated conditional probability
 */
let countFeatures = (irisFeatures, ele, sampleCount = 50) => {
  // Bugfix: the original assigned `countVec = {}` without let/const,
  // leaking an implicit global variable.
  const countVec = {};
  for (const value of irisFeatures) {
    countVec[value] = (countVec[value] ?? 0) + 1;
  }
  const key = ele.toString();
  return countVec[key] === undefined
    ? 1 / (sampleCount + Object.keys(countVec).length) // Laplace smoothing for unseen values
    : countVec[key] / sampleCount;
};
/**
 * Multiplication callback for Array.prototype.reduce (running product).
 * @function nn
 * @param {number} x - accumulated product so far
 * @param {number} y - current element
 * @returns {number} the product of x and y
 */
let nn = (x, y) => x * y;
/**
 * Compute the unnormalized posterior P(x|y)*P(y) for every class.
 * The shared denominator P(x) is omitted because only the argmax over
 * classes matters for classification.
 * @function maxProbCategory
 * @param {Object<number, Object<string, number>>} featurMatrix - per class, the conditional probability of each feature value
 * @param {Object<number, number>} prioriVec - prior probability P(y) per class index
 * @param {number} categoryNums - number of classes
 * @returns {Object<number, number>} class index -> unnormalized posterior
 */
let maxProbCategory = (featurMatrix, prioriVec, categoryNums) => {
  // The original pre-initialized every entry to 0 in a separate loop and
  // immediately overwrote it; a single loop suffices.
  const stepProb = {};
  for (let i = 0; i < categoryNums; i++) {
    const likelihood = Object.values(featurMatrix[i]).reduce(
      (acc, p) => acc * p,
      1
    );
    stepProb[i] = prioriVec[i] * likelihood;
  }
  return stepProb;
};
/**
 * Classify one sample with discrete naive Bayes: compute the unnormalized
 * posterior for every class. Reads the module-level training data
 * `irisTrains` and uses transposeIris/countFeatures/maxProbCategory.
 * @function judgeCategory
 * @param {Object<number, number>} priorProb - prior probability P(y) per class index
 * @param {Array<number>} features - 4-dimensional feature vector of the test sample
 * @param {number} categoryNums - number of classes
 * @returns {Object<number, number>|false} class index -> unnormalized posterior, or false for invalid input
 */
let judgeCategory = (priorProb, features, categoryNums) => {
  if (!Array.isArray(features) || features.length !== 4) return false;
  // Transpose each class's training matrix so one column of samples
  // becomes one array per feature.
  // Bugfix: the original hard-coded { 0: [], 1: [], 2: [] }, silently
  // assuming exactly 3 classes; build the map from categoryNums instead.
  let irisTranspose = {};
  for (let i = 0; i < categoryNums; i++) {
    irisTranspose[i] = transposeIris(irisTrains[i]);
  }
  // P(x_t = features[t] | y = k) for every class k and feature t.
  // countFeatures applies Laplace smoothing when a value was never seen.
  // (The original zero-initialized cateFeatures in a separate loop and
  // immediately overwrote it; one pass suffices.)
  let cateFeatures = {};
  for (let k = 0; k < categoryNums; k++) {
    cateFeatures[k] = {};
    for (let t = 0; t < features.length; t++) {
      cateFeatures[k][features[t]] = countFeatures(
        irisTranspose[k][t],
        features[t]
      );
    }
  }
  return maxProbCategory(cateFeatures, priorProb, categoryNums);
};
//calculate the priori probs P(y)
/**
 * Compute the prior probability of each class:
 * P(y) = samples in class / total samples.
 * @function genPrioriProb
 * @param {number} categoryNums - number of classes
 * @param {Object<number, Array<Array<number>>>} irisTrains - training samples grouped by class index
 * @returns {Object<number, number>} class index -> prior probability
 */
let genPrioriProb = (categoryNums, irisTrains) => {
  let priorProb = {};
  // Bugfix: the total was hard-coded to 150; derive it from the data so
  // the function works for any training set (identical result for the
  // 150-sample Iris data below).
  let total = 0;
  for (let j = 0; j < categoryNums; j++) {
    total += irisTrains[j].length;
  }
  for (let j = 0; j < categoryNums; j++) {
    priorProb[j] = irisTrains[j].length / total;
  }
  return priorProb;
};
// Training data: Iris measurements grouped by class label (0, 1, 2).
// Each sample is a 4-element vector — presumably [sepal length, sepal
// width, petal length, petal width] in cm; TODO confirm feature order.
// Each class appears to hold 50 samples, matching the hard-coded 50 in
// countFeatures and 150 in genPrioriProb.
const irisTrains = {
  0: [ [5.1, 3.5, 1.4, 0.2], [4.9, 3, 1.4, 0.2], [4.7, 3.2, 1.3, 0.2], [4.6, 3.1, 1.5, 0.2], [5, 3.6, 1.4, 0.2], [5.4, 3.9, 1.7, 0.4], [4.6, 3.4, 1.4, 0.3], [5, 3.4, 1.5, 0.2], [4.4, 2.9, 1.4, 0.2], [4.9, 3.1, 1.5, 0.1], [5.4, 3.7, 1.5, 0.2], [4.8, 3.4, 1.6, 0.2], [4.8, 3, 1.4, 0.1], [4.3, 3, 1.1, 0.1], [5.8, 4, 1.2, 0.2], [5.7, 4.4, 1.5, 0.4], [5.4, 3.9, 1.3, 0.4], [5.1, 3.5, 1.4, 0.3], [5.7, 3.8, 1.7, 0.3], [5.1, 3.8, 1.5, 0.3], [5.4, 3.4, 1.7, 0.2], [5.1, 3.7, 1.5, 0.4], [4.6, 3.6, 1, 0.2], [5.1, 3.3, 1.7, 0.5], [4.8, 3.4, 1.9, 0.2], [5, 3, 1.6, 0.2], [5, 3.4, 1.6, 0.4], [5.2, 3.5, 1.5, 0.2], [5.2, 3.4, 1.4, 0.2], [4.7, 3.2, 1.6, 0.2], [4.8, 3.1, 1.6, 0.2], [5.4, 3.4, 1.5, 0.4], [5.2, 4.1, 1.5, 0.1], [5.5, 4.2, 1.4, 0.2], [4.9, 3.1, 1.5, 0.2], [5, 3.2, 1.2, 0.2], [5.5, 3.5, 1.3, 0.2], [4.9, 3.6, 1.4, 0.1], [4.4, 3, 1.3, 0.2], [5.1, 3.4, 1.5, 0.2], [5, 3.5, 1.3, 0.3], [4.5, 2.3, 1.3, 0.3], [4.4, 3.2, 1.3, 0.2], [5, 3.5, 1.6, 0.6], [5.1, 3.8, 1.9, 0.4], [4.8, 3, 1.4, 0.3], [5.1, 3.8, 1.6, 0.2], [4.6, 3.2, 1.4, 0.2], [5.3, 3.7, 1.5, 0.2], [5, 3.3, 1.4, 0.2],
  ],
  1: [ [7, 3.2, 4.7, 1.4], [6.4, 3.2, 4.5, 1.5], [6.9, 3.1, 4.9, 1.5], [5.5, 2.3, 4, 1.3], [6.5, 2.8, 4.6, 1.5], [5.7, 2.8, 4.5, 1.3], [6.3, 3.3, 4.7, 1.6], [4.9, 2.4, 3.3, 1], [6.6, 2.9, 4.6, 1.3], [5.2, 2.7, 3.9, 1.4], [5, 2, 3.5, 1], [5.9, 3, 4.2, 1.5], [6, 2.2, 4, 1], [6.1, 2.9, 4.7, 1.4], [5.6, 2.9, 3.6, 1.3], [6.7, 3.1, 4.4, 1.4], [5.6, 3, 4.5, 1.5], [5.8, 2.7, 4.1, 1], [6.2, 2.2, 4.5, 1.5], [5.6, 2.5, 3.9, 1.1], [5.9, 3.2, 4.8, 1.8], [6.1, 2.8, 4, 1.3], [6.3, 2.5, 4.9, 1.5], [6.1, 2.8, 4.7, 1.2], [6.4, 2.9, 4.3, 1.3], [6.6, 3, 4.4, 1.4], [6.8, 2.8, 4.8, 1.4], [6.7, 3, 5, 1.7], [6, 2.9, 4.5, 1.5], [5.7, 2.6, 3.5, 1], [5.5, 2.4, 3.8, 1.1], [5.5, 2.4, 3.7, 1], [5.8, 2.7, 3.9, 1.2], [6, 2.7, 5.1, 1.6], [5.4, 3, 4.5, 1.5], [6, 3.4, 4.5, 1.6], [6.7, 3.1, 4.7, 1.5], [6.3, 2.3, 4.4, 1.3], [5.6, 3, 4.1, 1.3], [5.5, 2.5, 4, 1.3], [5.5, 2.6, 4.4, 1.2], [6.1, 3, 4.6, 1.4], [5.8, 2.6, 4, 1.2], [5, 2.3, 3.3, 1], [5.6, 2.7, 4.2, 1.3], [5.7, 3, 4.2, 1.2], [5.7, 2.9, 4.2, 1.3], [6.2, 2.9, 4.3, 1.3], [5.1, 2.5, 3, 1.1], [5.7, 2.8, 4.1, 1.3],
  ],
  2: [ [6.3, 3.3, 6, 2.5], [5.8, 2.7, 5.1, 1.9], [7.1, 3, 5.9, 2.1], [6.3, 2.9, 5.6, 1.8], [6.5, 3, 5.8, 2.2], [7.6, 3, 6.6, 2.1], [4.9, 2.5, 4.5, 1.7], [7.3, 2.9, 6.3, 1.8], [6.7, 2.5, 5.8, 1.8], [7.2, 3.6, 6.1, 2.5], [6.5, 3.2, 5.1, 2], [6.4, 2.7, 5.3, 1.9], [6.8, 3, 5.5, 2.1], [5.7, 2.5, 5, 2], [5.8, 2.8, 5.1, 2.4], [6.4, 3.2, 5.3, 2.3], [6.5, 3, 5.5, 1.8], [7.7, 3.8, 6.7, 2.2], [7.7, 2.6, 6.9, 2.3], [6, 2.2, 5, 1.5], [6.9, 3.2, 5.7, 2.3], [5.6, 2.8, 4.9, 2], [7.7, 2.8, 6.7, 2], [6.3, 2.7, 4.9, 1.8], [6.7, 3.3, 5.7, 2.1], [7.2, 3.2, 6, 1.8], [6.2, 2.8, 4.8, 1.8], [6.1, 3, 4.9, 1.8], [6.4, 2.8, 5.6, 2.1], [7.2, 3, 5.8, 1.6], [7.4, 2.8, 6.1, 1.9], [7.9, 3.8, 6.4, 2], [6.4, 2.8, 5.6, 2.2], [6.3, 2.8, 5.1, 1.5], [6.1, 2.6, 5.6, 1.4], [7.7, 3, 6.1, 2.3], [6.3, 3.4, 5.6, 2.4], [6.4, 3.1, 5.5, 1.8], [6, 3, 4.8, 1.8], [6.9, 3.1, 5.4, 2.1], [6.7, 3.1, 5.6, 2.4], [6.9, 3.1, 5.1, 2.3], [5.8, 2.7, 5.1, 1.9], [6.8, 3.2, 5.9, 2.3], [6.7, 3.3, 5.7, 2.5], [6.7, 3, 5.2, 2.3], [6.3, 2.5, 5, 1.9], [6.5, 3, 5.2, 2], [6.2, 3.4, 5.4, 2.3], [5.9, 3, 5.1, 1.8],
  ],
};
// Driver: derive the class count and priors from the training data,
// then classify one test sample.
const categoryNums = Object.keys(irisTrains).length;
const priorProb = genPrioriProb(categoryNums, irisTrains);
// Returns the unnormalized posterior per class (shown in the result
// section below); the value is not captured or printed here.
judgeCategory(priorProb, [5.8, 2.6, 4, 1.2], categoryNums);
- 结果展示
{
'0': 3.0572064470369556e-8,
'1': 0.000012,
'2': 1.8433179723502302e-7
}