简单的推荐算法–通过余弦相似度
gitee仓库:https://gitee.com/dobest-li/java-tools/tree/master/%E7%AE%97%E6%B3%95/%E6%8E%A8%E8%8D%90%E7%AE%97%E6%B3%95
利用余弦相似度，实现一个根据用户评分进行推荐的简单算法。
实现效果:
环境（Maven 依赖）
<!-- lombok-->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.24</version>
</dependency>
实现代码
package com.gdpi.utils;
import lombok.Data;
import java.util.*;
/**
 * Entity holding one user's ratings, keyed by item id.
 *
 * <p>Lombok's {@code @Data} generates the getter/setter for {@code ratings} as well as
 * {@code equals}, {@code hashCode} and {@code toString}.
 */
@Data
class UserRatings {
    /** Item id -> rating. Package-private because RecommenderSystem fills it directly. */
    Map<Integer, Double> ratings;

    /** Creates a user with no ratings, backed by a {@link HashMap}. */
    public UserRatings() {
        ratings = new HashMap<>();
    }

    /** Records the rating for {@code itemId}, overwriting any previous value. */
    public void addRating(int itemId, double rating) {
        this.ratings.put(Integer.valueOf(itemId), Double.valueOf(rating));
    }

    /**
     * Returns the rating for {@code itemId}, or {@code 0.0} when the item has no entry.
     * An explicit 0.0 rating is therefore indistinguishable from "not rated" through this method.
     */
    public double getRating(int itemId) {
        return ratings.containsKey(itemId) ? ratings.get(itemId) : 0.0;
    }

    /** Returns the rated item ids as a live view of the backing map's key set (not a copy). */
    public Set<Integer> getItemIds() {
        return ratings.keySet();
    }

    /**
     * Returns the rating values only, in the backing map's iteration order
     * (unspecified for the default HashMap, so positions do not correspond to item ids).
     */
    public double[] toArray() {
        return ratings.values()
                .stream()
                .mapToDouble(Double::doubleValue)
                .toArray();
    }
}
/**
 * Minimal user-based collaborative-filtering recommender.
 *
 * <p>Two users are compared by the cosine similarity of their ratings on the items they have
 * both rated. For a target user, every other user with positive similarity contributes a score
 * of {@code similarity * rating} for each item that user rated and the target has not. An item's
 * predicted score is the arithmetic mean of its contributions.
 *
 * <p>State is kept in unsynchronized {@link HashMap}s, so an instance must not be shared between
 * threads without external locking.
 */
class RecommenderSystem {
    /** User id -> that user's ratings. */
    private final Map<Integer, UserRatings> userRatings;

    /**
     * Predicted scores (item id -> score) from the most recent call to
     * {@link #generateRecommendations(int, int)}. Cleared at the start of every call so that
     * results for one target user never leak into the next call.
     */
    Map<Integer, Double> predictions = new HashMap<>();

    public RecommenderSystem() {
        this.userRatings = new HashMap<>();
    }

    /**
     * Stores a user's ratings, replacing any ratings previously stored for the same id.
     * The map is copied, so later changes made by the caller do not affect this recommender.
     *
     * @param userId  id of the user
     * @param ratings item id -> rating
     * @throws NullPointerException if {@code ratings} is null
     */
    public void addUserRatings(int userId, Map<Integer, Double> ratings) {
        Objects.requireNonNull(ratings, "ratings");
        UserRatings userRatingsObj = new UserRatings();
        userRatingsObj.ratings.putAll(ratings);
        this.userRatings.put(userId, userRatingsObj);
    }

    /**
     * Cosine similarity of two users, computed only over the items both have rated.
     *
     * <p>Returns 0 when they share no items, or when either co-rated vector has zero norm
     * (which would otherwise produce NaN). With a single co-rated item the result is always
     * 1 for two positive ratings, so very small overlaps can overstate similarity.
     */
    private double cosineSimilarity(UserRatings user1, UserRatings user2) {
        Set<Integer> commonItems = new HashSet<>(user1.getItemIds());
        commonItems.retainAll(user2.getItemIds());
        if (commonItems.isEmpty()) {
            return 0.0;
        }
        double dotProduct = 0.0;
        double squaredNorm1 = 0.0;
        double squaredNorm2 = 0.0;
        for (int itemId : commonItems) {
            double r1 = user1.getRating(itemId);
            double r2 = user2.getRating(itemId);
            dotProduct += r1 * r2;
            squaredNorm1 += r1 * r1;
            squaredNorm2 += r2 * r2;
        }
        if (squaredNorm1 == 0.0 || squaredNorm2 == 0.0) {
            return 0.0; // avoid 0/0 = NaN
        }
        return dotProduct / (Math.sqrt(squaredNorm1) * Math.sqrt(squaredNorm2));
    }

    /**
     * Recommends items the target user has not rated yet.
     *
     * <p>Predictions are computed as described on the class and stored in {@link #predictions}.
     * Results are sorted by predicted score (descending), with ties broken by ascending item id,
     * and truncated to {@code numRecommendations}.
     *
     * @param targetUserId       user to recommend for
     * @param numRecommendations maximum number of items to return (must be >= 0)
     * @return (item id, predicted score) pairs, best first. The list and its entries are
     *         detached copies, unaffected by later calls.
     * @throws IllegalArgumentException if {@code targetUserId} has no stored ratings or
     *         {@code numRecommendations} is negative
     */
    public List<Map.Entry<Integer, Double>> generateRecommendations(int targetUserId, int numRecommendations) {
        UserRatings targetUserRatings = userRatings.get(targetUserId);
        if (targetUserRatings == null) {
            throw new IllegalArgumentException("User not found");
        }
        if (numRecommendations < 0) {
            throw new IllegalArgumentException("numRecommendations must be >= 0: " + numRecommendations);
        }
        // Recompute from scratch: stale scores from previous calls must not be mixed in.
        predictions.clear();
        predictions.putAll(predictUnratedItems(targetUserId, targetUserRatings));
        return topN(predictions, numRecommendations);
    }

    /** Mean of similarity-weighted neighbour ratings, for each item the target has not rated. */
    private Map<Integer, Double> predictUnratedItems(int targetUserId, UserRatings targetUserRatings) {
        Map<Integer, Double> scoreSums = new HashMap<>();
        Map<Integer, Integer> scoreCounts = new HashMap<>();
        Set<Integer> ratedByTarget = targetUserRatings.getItemIds();
        for (Map.Entry<Integer, UserRatings> entry : userRatings.entrySet()) {
            if (entry.getKey() == targetUserId) {
                continue; // skip the target user itself
            }
            UserRatings neighbour = entry.getValue();
            double similarity = cosineSimilarity(targetUserRatings, neighbour);
            if (!(similarity > 0)) {
                continue; // only positively similar users contribute
            }
            for (int itemId : neighbour.getItemIds()) {
                // Key membership, not "rating == 0": an explicit 0.0 rating still means "rated".
                if (!ratedByTarget.contains(itemId)) {
                    scoreSums.merge(itemId, neighbour.getRating(itemId) * similarity, Double::sum);
                    scoreCounts.merge(itemId, 1, Integer::sum);
                }
            }
        }
        Map<Integer, Double> means = new HashMap<>();
        for (Map.Entry<Integer, Double> sum : scoreSums.entrySet()) {
            means.put(sum.getKey(), sum.getValue() / scoreCounts.get(sum.getKey()));
        }
        return means;
    }

    /** Ranks scores descending (ties: ascending item id) and keeps at most {@code limit} of them. */
    private static List<Map.Entry<Integer, Double>> topN(Map<Integer, Double> scores, int limit) {
        List<Map.Entry<Integer, Double>> ranked = new ArrayList<>(scores.size());
        for (Map.Entry<Integer, Double> e : scores.entrySet()) {
            // Copy the entry so callers never hold live nodes of the reusable predictions map.
            ranked.add(new AbstractMap.SimpleEntry<>(e.getKey(), e.getValue()));
        }
        ranked.sort((a, b) -> {
            int byScore = Double.compare(b.getValue(), a.getValue());
            return byScore != 0 ? byScore : Integer.compare(a.getKey(), b.getKey());
        });
        return new ArrayList<>(ranked.subList(0, Math.min(limit, ranked.size())));
    }

    /** Demo: registers four users' ratings and prints the top-2 recommendations for user 1. */
    public static void main(String[] args) {
        RecommenderSystem recommenderSystem = new RecommenderSystem();
        // Sample data
        Map<Integer, Double> user1Ratings = new HashMap<>();
        user1Ratings.put(1, 5.0);
        user1Ratings.put(3, 3.0);
        Map<Integer, Double> user2Ratings = new HashMap<>();
        user2Ratings.put(1, 1.0);
        user2Ratings.put(2, 4.0);
        user2Ratings.put(3, 3.0);
        user2Ratings.put(4, 1.0);
        user2Ratings.put(6, 1.0);
        Map<Integer, Double> user3Ratings = new HashMap<>();
        user3Ratings.put(1, 2.0);
        user3Ratings.put(2, 5.0);
        user3Ratings.put(3, 2.0);
        user3Ratings.put(4, 2.0);
        user3Ratings.put(5, 1.0);
        Map<Integer, Double> user4Ratings = new HashMap<>();
        user4Ratings.put(1, 1.0);
        user4Ratings.put(2, 3.0);
        user4Ratings.put(3, 1.0);
        user4Ratings.put(4, 2.0);
        user4Ratings.put(6, 1.0);
        recommenderSystem.addUserRatings(1, user1Ratings);
        recommenderSystem.addUserRatings(2, user2Ratings);
        recommenderSystem.addUserRatings(3, user3Ratings);
        recommenderSystem.addUserRatings(4, user4Ratings);
        // Generate recommendations for user 1
        List<Map.Entry<Integer, Double>> recommendations = recommenderSystem.generateRecommendations(1, 2);
        // Print the recommendations
        for (Map.Entry<Integer, Double> entry : recommendations) {
            System.out.println("Recommend item ID: " + entry.getKey() + ", Predicted rating: " + entry.getValue());
        }
    }
}