第一问
package org.lenskit.mooc.nonpers.Imp;
/* Compute the average rating of every movie in data/ratings.csv and
   print the movies ordered from the highest average to the lowest. */
import org.lenskit.mooc.nonpers.Util.MapSortByValue;
import org.lenskit.mooc.nonpers.Util.Rating;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.*;
/**
 * Command-line program that prints the mean rating of every movie found in
 * {@code data/ratings.csv}, ordered from the highest mean to the lowest.
 *
 * <p>The first line of the file is skipped as a header. Every other non-blank line is split on
 * commas; the second field is read as the movie id and the third as the rating.
 */
public class HighestAvgRating {
    /** Ratings file, resolved against the working directory. */
    private static final String RATINGS_FILE = "data/ratings.csv";

    /**
     * Loads the ratings file, averages the ratings per movie and prints one
     * {@code movieId-mean} line per movie, highest mean first.
     *
     * @param args ignored
     * @throws Exception if the file cannot be read or a row cannot be parsed
     */
    public static void main(String[] args) throws Exception {
        List<String> lines =
                Files.readAllLines(new File(RATINGS_FILE).toPath(), StandardCharsets.UTF_8);
        Map<Long, Double> mean = averageByMovie(parseRatings(lines));
        // Sort once and walk the sorted entries; sorting again for every key is wasted work.
        for (Map.Entry<Long, Double> entry : MapSortByValue.sortByValue(mean).entrySet()) {
            System.out.print(entry.getKey() + "-" + entry.getValue() + "\n");
        }
    }

    /**
     * Converts the raw CSV lines into ratings, skipping the header line and blank lines.
     *
     * @param lines every line of the ratings file, header included
     * @return one {@code Rating} per data row, in file order
     * @throws IllegalArgumentException if a row has fewer than three fields or a field is not
     *     a valid number
     */
    private static List<Rating> parseRatings(List<String> lines) {
        List<Rating> ratings = new ArrayList<>();
        // Index 0 is the header row, so data starts at index 1.
        for (int i = 1; i < lines.size(); i++) {
            String line = lines.get(i);
            if (line.trim().isEmpty()) {
                continue;
            }
            String[] fields = line.split(",");
            if (fields.length < 3) {
                throw new IllegalArgumentException(
                        "Line " + (i + 1) + " of " + RATINGS_FILE
                                + " has fewer than 3 fields: " + line);
            }
            ratings.add(new Rating(
                    Long.parseLong(fields[1].trim()),
                    Double.parseDouble(fields[2])));
        }
        return ratings;
    }

    /**
     * Computes the arithmetic mean of the ratings of each movie in a single pass.
     *
     * @param ratings the ratings to aggregate
     * @return a map from movie id to the mean of that movie's ratings
     */
    private static Map<Long, Double> averageByMovie(List<Rating> ratings) {
        // Holds the running sum per movie first; converted to the mean in place below.
        HashMap<Long, Double> mean = new HashMap<>();
        Map<Long, Integer> counts = new HashMap<>();
        for (Rating rating : ratings) {
            Long movie = rating.getMovieid();
            Double sumSoFar = mean.get(movie);
            if (sumSoFar == null) {
                mean.put(movie, rating.getRatings());
                counts.put(movie, 1);
            } else {
                mean.put(movie, sumSoFar + rating.getRatings());
                counts.put(movie, counts.get(movie) + 1);
            }
        }
        for (Map.Entry<Long, Double> entry : mean.entrySet()) {
            entry.setValue(entry.getValue() / counts.get(entry.getKey()));
        }
        return mean;
    }
}
Utils:
package org.lenskit.mooc.nonpers.Util;
import java.util.*;
/**
 * Utility for ordering the entries of a map by their values.
 */
public class MapSortByValue {
    /**
     * Returns a new insertion-ordered map containing the entries of {@code map} sorted by value
     * in descending order. Entries whose values compare as equal keep the relative order in
     * which {@code map} iterates them. The input map is not modified.
     *
     * @param map the map whose entries are to be ordered; its values must be non-null
     * @param <K> key type
     * @param <V> value type, compared through its natural ordering
     * @return a {@link LinkedHashMap} iterating from the largest value to the smallest
     */
    public static <K, V extends Comparable<? super V>> Map<K, V> sortByValue(Map<K, V> map) {
        List<Map.Entry<K, V>> entries = new ArrayList<Map.Entry<K, V>>(map.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<K, V>>() {
            @Override
            public int compare(Map.Entry<K, V> left, Map.Entry<K, V> right) {
                // Swap the operands instead of negating the result: -Integer.MIN_VALUE
                // overflows and stays negative, which would break the ordering.
                return right.getValue().compareTo(left.getValue());
            }
        });
        Map<K, V> result = new LinkedHashMap<K, V>();
        for (Map.Entry<K, V> entry : entries) {
            result.put(entry.getKey(), entry.getValue());
        }
        return result;
    }
}
package org.lenskit.mooc.nonpers.Util;
/**
 * Mutable holder that pairs a movie id with a rating value.
 */
public class Rating {
    private Long movieid;
    private Double ratings;

    /** Creates an instance whose movie id and rating value are both {@code null}. */
    public Rating() {
    }

    /**
     * Creates an instance with the given movie id and rating value.
     *
     * @param id the movie id to store
     * @param value the rating value to store
     */
    public Rating(Long id, Double value) {
        movieid = id;
        ratings = value;
    }

    /** @return the stored movie id */
    public Long getMovieid() {
        return this.movieid;
    }

    /** @param id the movie id to store */
    public void setMovieid(Long id) {
        movieid = id;
    }

    /** @return the stored rating value */
    public Double getRatings() {
        return this.ratings;
    }

    /** @param value the rating value to store */
    public void setRatings(Double value) {
        ratings = value;
    }
}
最终部分输出:
第二问
package org.lenskit.mooc.nonpers.Imp;
import org.lenskit.mooc.nonpers.Util.MapSortByValue;
import org.lenskit.mooc.nonpers.Util.Rating;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Command-line program that prints the damped mean rating of every movie found in
 * {@code data/ratings.csv}, ordered from the highest damped mean to the lowest.
 *
 * <p>For a movie with {@code n} ratings summing to {@code s}, the damped mean is
 * {@code (s + d * g) / (n + d)}, where {@code g} is the mean of all ratings in the file and
 * {@code d} is the damping factor.
 *
 * <p>The first line of the file is skipped as a header. Every other non-blank line is split on
 * commas; the second field is read as the movie id and the third as the rating.
 */
public class DampedMeanRating {
    /** Damping factor used by {@link #main}. */
    private static final double DAMPINGFACTOR = 5.0;

    /** Ratings file, resolved against the working directory. */
    private static final String RATINGS_FILE = "data/ratings.csv";

    /**
     * Loads the ratings once, computes the global mean and prints one
     * {@code movieId-dampedMean} line per movie, highest value first.
     *
     * @param args ignored
     * @throws Exception if the file cannot be read or a row cannot be parsed
     */
    public static void main(String[] args) throws Exception {
        List<Rating> ratings = loadRatings();
        double globalAvg = globalMean(ratings);
        printDescending(dampedMeans(ratings, globalAvg, DAMPINGFACTOR));
    }

    /**
     * Reads the ratings file, computes the damped mean of every movie, prints the results in
     * descending order and returns them.
     *
     * @param globalAvg mean of all ratings, used as the prior; must not be {@code null}
     * @param dampingFactor weight given to the prior; must not be {@code null}
     * @return an unsorted map from movie id to damped mean
     * @throws IOException if the ratings file cannot be read
     */
    public static Map<Long, Double> DampedAvgRating(Double globalAvg, Double dampingFactor)
            throws IOException {
        Map<Long, Double> mean = dampedMeans(loadRatings(), globalAvg, dampingFactor);
        printDescending(mean);
        return mean;
    }

    /**
     * Reads and parses the ratings file, skipping the header line and blank lines.
     *
     * @return one {@code Rating} per data row, in file order
     * @throws IOException if the file cannot be read
     * @throws IllegalArgumentException if a row has fewer than three fields or a field is not
     *     a valid number
     */
    private static List<Rating> loadRatings() throws IOException {
        List<String> lines =
                Files.readAllLines(new File(RATINGS_FILE).toPath(), StandardCharsets.UTF_8);
        List<Rating> ratings = new ArrayList<>();
        // Index 0 is the header row, so data starts at index 1.
        for (int i = 1; i < lines.size(); i++) {
            String line = lines.get(i);
            if (line.trim().isEmpty()) {
                continue;
            }
            String[] fields = line.split(",");
            if (fields.length < 3) {
                throw new IllegalArgumentException(
                        "Line " + (i + 1) + " of " + RATINGS_FILE
                                + " has fewer than 3 fields: " + line);
            }
            ratings.add(new Rating(
                    Long.parseLong(fields[1].trim()),
                    Double.parseDouble(fields[2])));
        }
        return ratings;
    }

    /**
     * Computes the mean over all rating values.
     *
     * @param ratings the ratings to average
     * @return the mean rating, or {@code 0.0} when {@code ratings} is empty
     */
    private static double globalMean(List<Rating> ratings) {
        if (ratings.isEmpty()) {
            return 0.0;
        }
        double sum = 0.0;
        for (Rating rating : ratings) {
            sum += rating.getRatings();
        }
        return sum / ratings.size();
    }

    /**
     * Computes the damped mean of each movie in a single pass over the ratings.
     *
     * @param ratings the ratings to aggregate
     * @param globalAvg prior mean blended into every movie
     * @param dampingFactor weight of the prior
     * @return a map from movie id to {@code (sum + dampingFactor * globalAvg) /
     *     (count + dampingFactor)}
     */
    private static Map<Long, Double> dampedMeans(
            List<Rating> ratings, double globalAvg, double dampingFactor) {
        // Holds the running sum per movie first; converted to the damped mean in place below.
        HashMap<Long, Double> mean = new HashMap<>();
        Map<Long, Integer> counts = new HashMap<>();
        for (Rating rating : ratings) {
            Long movie = rating.getMovieid();
            Double sumSoFar = mean.get(movie);
            if (sumSoFar == null) {
                mean.put(movie, rating.getRatings());
                counts.put(movie, 1);
            } else {
                mean.put(movie, sumSoFar + rating.getRatings());
                counts.put(movie, counts.get(movie) + 1);
            }
        }
        for (Map.Entry<Long, Double> entry : mean.entrySet()) {
            int count = counts.get(entry.getKey());
            entry.setValue(
                    (entry.getValue() + (dampingFactor * globalAvg)) / (count + dampingFactor));
        }
        return mean;
    }

    /** Prints {@code key-value} lines for {@code mean}, largest value first. */
    private static void printDescending(Map<Long, Double> mean) {
        // Sort once and walk the sorted entries; sorting again for every key is wasted work.
        for (Map.Entry<Long, Double> entry : MapSortByValue.sortByValue(mean).entrySet()) {
            System.out.print(entry.getKey() + "-" + entry.getValue() + "\n");
        }
    }
}
部分输出: