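余弦相似度计算原理:物品 i 与物品 j 的相似度为
$$w_{ij}=\frac{\sum_{u\in U_i\cap U_j} r_{ui}\,r_{uj}}{\sqrt{\sum_{u\in U_i} r_{ui}^{2}}\;\sqrt{\sum_{u\in U_j} r_{uj}^{2}}}$$
其中 $U_i$ 为对物品 i 评过分(或购买过)的用户集合,$r_{ui}$ 为用户 u 对物品 i 的评分(购买记录可视为评分 1)。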
余弦相似度Python代码:
C 保存分子部分的数据:C[i][j] 为同时评价过物品 i 和 j 的所有用户的评分乘积之和。N 保存分母部分的数据:N[i] 为物品 i 所有评分的平方和。如果下面的 Python 代码看起来比较费劲,可以参考后面的 Java 代码,可能更容易理解一些。
## ItemCF-余弦算法
import math
def ItemSimilarity_cos(train):
    """Compute item-item cosine similarity from user ratings (ItemCF).

    For every ordered pair of distinct items ``i`` and ``j`` that at least one
    user rated together::

        W[i][j] = sum_u(r_ui * r_uj) / (sqrt(sum_u r_ui**2) * sqrt(sum_u r_uj**2))

    The numerator sums over users who rated both items. Each denominator
    factor sums over every user who rated that item.

    Args:
        train: Mapping ``user -> {item: rating}``. For plain purchase logs use
            a rating of 1 for every purchased item.

    Returns:
        Nested dict ``{i: {j: similarity}}``. Only items that co-occur with at
        least one other item for some user appear as keys. If an item's
        squared ratings sum to zero (all its ratings are 0), its similarities
        are reported as 0.0 instead of raising ZeroDivisionError.
    """
    C = dict()  # C[i][j]: sum over users of r_ui * r_uj (numerator)
    N = dict()  # N[i]: sum over users of r_ui ** 2 (squared norm of item i)
    for items in train.values():
        for i, r_i in items.items():
            N[i] = N.get(i, 0) + r_i * r_i
            for j, r_j in items.items():
                if i == j:
                    continue
                # Create C[i] lazily so items that never co-occur stay absent.
                row = C.setdefault(i, dict())
                row[j] = row.get(j, 0) + r_i * r_j

    W = dict()  # W[i][j]: cosine similarity of items i and j
    for i, related_items in C.items():
        W[i] = dict()
        for j, cij in related_items.items():
            denom = math.sqrt(N[i]) * math.sqrt(N[j])
            # A zero norm means every rating of that item is 0, so cij is 0
            # as well; report 0.0 instead of dividing 0 by 0.
            W[i][j] = cij / denom if denom else 0.0
    return W
if __name__ == '__main__':
    # Toy purchase log: user -> {book: 1}; a rating of 1 just means "bought".
    train_data = {
        'A': {'i1': 1, 'i2': 1, 'i4': 1},
        'B': {'i1': 1, 'i4': 1},
        'C': {'i1': 1, 'i2': 1, 'i5': 1},
        'D': {'i2': 1, 'i3': 1},
        'E': {'i3': 1, 'i5': 1},
        'F': {'i2': 1, 'i4': 1},
    }
    similarities = ItemSimilarity_cos(train_data)
    print(similarities)
结果:
{'i1': {'i2': 0.5773502691896258, 'i4': 0.6666666666666667, 'i5': 0.40824829046386296}, 'i2': {'i1': 0.5773502691896258, 'i4': 0.5773502691896258, 'i5': 0.35355339059327373, 'i3': 0.35355339059327373}, 'i4': {'i1': 0.6666666666666667, 'i2': 0.5773502691896258}, 'i5': {'i1': 0.40824829046386296, 'i2': 0.35355339059327373, 'i3': 0.4999999999999999}, 'i3': {'i2': 0.35355339059327373, 'i5': 0.4999999999999999}}
该代码来源于《推荐系统与深度学习》(黄昕、赵伟、王本友 编著)。
因为实验需要,我又参照上述 Python 代码给出了 Java 版本,用 HashMap 实现。代码如下:
输入 map 的结构为<用户ID,<电影ID,评分>>
C的结构为<电影ID,<电影ID,同时评价过这两部电影的用户的评分乘积之和>>
N的结构为<电影ID,该电影所有评分的平方和>
W的结构为<电影ID,<电影ID,相似度>>
Java代码
import java.util.HashMap;
import java.util.Map;
public class ItemSimilarity_cos {
public static HashMap<Integer, HashMap<Integer, Float>> C;
public static HashMap<Integer, Float> N;
public static HashMap<Integer, HashMap<Integer, Float>> W;
public static HashMap<Integer, HashMap<Integer, Float>> getC() {
return C;
}
public static HashMap<Integer, Float> getN() {
return N;
}
public static HashMap<Integer, HashMap<Integer, Float>> getW() {
return W;
}
public ItemSimilarity_cos(HashMap<Integer, HashMap<Integer, Float>> map) {
this.jisuan(map);
}
public static HashMap<Integer, HashMap<Integer, Float>> jisuan(HashMap<Integer, HashMap<Integer, Float>> map) {//map的输入是用户ID,电影ID和电影评分
C = new HashMap<Integer, HashMap<Integer, Float>>();
N = new HashMap<Integer, Float>();
W = new HashMap<Integer, HashMap<Integer, Float>>();
for (Map.Entry<Integer, HashMap<Integer, Float>> entry : map.entrySet()) {
for (Map.Entry<Integer, Float> entry1 : entry.getValue().entrySet()) {
if (!N.containsKey(entry1.getKey())) {
N.put(entry1.getKey(), 0F);
}
Float rate = N.get(entry1.getKey());
rate += entry1.getValue()*entry1.getValue();
// System.out.println(entry1.getKey()+"----------"+rate);
N.put(entry1.getKey(), rate);
for (Map.Entry<Integer, Float> entry2 : entry.getValue().entrySet()) {
if (entry1.getKey() == entry2.getKey()) {
continue;
}
if (!C.containsKey(entry1.getKey())) {
C.put(entry1.getKey(), new HashMap<Integer, Float>());
}
if (!C.get(entry1.getKey()).containsKey(entry2.getKey())) {
C.get(entry1.getKey()).put(entry2.getKey(),0F);
}
Float result= C.get(entry1.getKey()).get(entry2.getKey());
result += entry1.getValue() * entry2.getValue();
C.get(entry1.getKey()).put(entry2.getKey(),result);
}
}
}
for (Map.Entry<Integer, HashMap<Integer, Float>> entry : C.entrySet()) {
if (!W.containsKey(entry.getKey())) {
W.put(entry.getKey(), new HashMap<Integer, Float>());
}
for (Map.Entry<Integer, Float> entry1 : entry.getValue().entrySet()) {
double result = entry1.getValue() / (Math.sqrt(N.get(entry.getKey()))* Math.sqrt(N.get(entry1.getKey())));
W.get(entry.getKey()).put(entry1.getKey(), (float) result);
// System.out.println(result);
}
}
return W;
}
public static void main(String[] args) {
HashMap<Integer, HashMap<Integer, Float>> map = new HashMap<Integer, HashMap<Integer, Float>>();
map.put(0, new HashMap<Integer, Float>());
map.put(1, new HashMap<Integer, Float>());
map.put(2, new HashMap<Integer, Float>());//用户ID
map.get(0).put(1, 9.5F);
map.get(0).put(2, 3.5F);
map.get(0).put(3, 8.5F);
map.get(0).put(4, 7.0F);
map.get(1).put(1, 9.0F);
map.get(1).put(2, 5.0F);
map.get(2).put(1, 3.5F);
map.get(2).put(2, 6.0F);
map.get(2).put(4, 8.0F);
ItemSimilarity_cos cos=new ItemSimilarity_cos(map);
// System.out.println(cos.getW());
// System.out.println(cos.getC());
}
}
结果:
{1={2=0.8560688, 3=0.7013028, 4=0.6562579}, 2={1=0.8560688, 3=0.40894437, 4=0.79688376}, 3={1=0.7013028, 2=0.40894437, 4=0.6585046}, 4={1=0.6562579, 2=0.79688376, 3=0.6585046}}