package cn.tworice.recommend;

import java.util.*;

/**
 * 协同过滤推荐算法
 */
public class Recommender {

    /**
     * 用户评分数据，键为用户ID，值为产品ID和评分的映射
     */
    private final Map<String, Map<Integer, Double>> userRatings;

    /**
     * 用户平均评分，键为用户ID，值为平均评分
     */
    private final Map<String, Double> userMeans;

    /**
     * 用户观看产品记录，键为用户ID，值为产品ID的集合
     */
    private final Map<String, Set<Integer>> userMovies;

    // 构造函数，初始化用户评分数据、用户平均评分和用户观看产品记录

    /**
     * 构造函数，初始化用户评分数据、用户平均评分和用户观看产品记录
     * 用户评分数据参数格式如下：<用户标识，<产品标识，产品评分>>
     * @param userRatings 用户评分数据
     */
    public Recommender(Map<String, Map<Integer, Double>> userRatings) {
        this.userRatings = userRatings;
        this.userMeans = new HashMap<>();
        this.userMovies = new HashMap<>();

        // 计算每个用户的平均评分和观看产品记录
        for (String userId : userRatings.keySet()) {
            Map<Integer, Double> ratings = userRatings.get(userId);
            double sum = 0.0;
            for (double rating : ratings.values()) {
                sum += rating;
            }
            double mean = sum / ratings.size();
            userMeans.put(userId, mean);
            userMovies.put(userId, ratings.keySet());
        }
    }

    /**
     * 计算用户之间的相似度
     */
    private double similarity(String userId1, String userId2) {
        // 获取两个用户观看产品的交集
        Set<Integer> movies1 = userMovies.get(userId1);
        if (movies1 == null || movies1.size() == 0) {
            return 0.0;
        }
        Set<Integer> movies2 = userMovies.get(userId2);
        Set<Integer> commonMovies = new HashSet<>(movies1);
        commonMovies.retainAll(movies2);
        if (commonMovies.isEmpty()) {
            return 0.0;
        }

        // 计算余弦相似度
        double sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
        for (int movieId : commonMovies) {
            double rating1 = userRatings.get(userId1).get(movieId) - userMeans.get(userId1);
            double rating2 = userRatings.get(userId2).get(movieId) - userMeans.get(userId2);
            sum1 += rating1 * rating2;
            sum2 += rating1 * rating1;
            sum3 += rating2 * rating2;
        }
        double sim = sum1 / Math.sqrt(sum2 * sum3);
        return sim;
    }

    /**
     * 基于用户的协同过滤推荐
     * @param userId 用户ID
     * @param num 返回推荐数量
     * @return 推荐标识集合
     */
    public List<Integer> recommend(String userId, int num) {
        List<Integer> recommendations = new ArrayList<>();
        Map<Integer, Double> scores = new HashMap<>();

        // 遍历所有用户，计算相似度
        for (String otherUserId : userRatings.keySet()) {
            if (otherUserId.equals(userId)) {
                continue;
            }
            double sim = similarity(userId, otherUserId);
            if (sim <= 0.0) {
                continue;
            }

            // 对于每个相似用户，计算产品评分得分
            Map<Integer, Double> ratings = userRatings.get(otherUserId);
            for (int movieId : ratings.keySet()) {
                if (userRatings.get(userId).containsKey(movieId)) {
                    continue;
                }
                double score = sim * (ratings.get(movieId) - userMeans.get(otherUserId));
                if (scores.containsKey(movieId)) {
                    scores.put(movieId, scores.get(movieId) + score);
                } else {
                    scores.put(movieId, score);
                }
            }
        }

        // 对产品评分得分进行排序，返回前num个产品ID
        List<Map.Entry<Integer, Double>> sortedScores = new ArrayList<>(scores.entrySet());
        Collections.sort(sortedScores, (a, b) -> Double.compare(b.getValue(), a.getValue()));
        for (int i = 0; i < Math.min(num, sortedScores.size()); i++) {
            recommendations.add(sortedScores.get(i).getKey());
        }
        return recommendations;
    }

    public static void main(String[] args) {
        // 构造一个用户产品评分数据集
        Map<String, Map<Integer, Double>> userRatings = new HashMap<>();
        Map<Integer, Double> user1 = new HashMap<>();
        user1.put(1, 3.0);
        user1.put(2, 4.0);
        user1.put(3, 5.0);
        user1.put(4, 3.0);
        Map<Integer, Double> user2 = new HashMap<>();
        user2.put(2, 3.0);
        user2.put(3, 4.0);
        user2.put(4, 2.0);
        user2.put(5, 5.0);
        Map<Integer, Double> user3 = new HashMap<>();
        user3.put(1, 5.0);
        user3.put(2, 4.0);
        user3.put(4, 2.0);
        user3.put(5, 3.0);
        Map<Integer, Double> user4 = new HashMap<>();
        user4.put(1, 4.0);
        user4.put(3, 5.0);
        user4.put(4, 3.0);
        user4.put(5, 4.0);
        Map<Integer, Double> user5 = new HashMap<>();
        user5.put(1, 2.0);
        user5.put(3, 3.0);
        user5.put(4, 4.0);
        user5.put(5, 5.0);
        userRatings.put("1", user1);
        userRatings.put("2", user2);
        userRatings.put("3", user3);
        userRatings.put("4", user4);
        userRatings.put("5", user5);

        // 创建一个Recommender对象，进行产品推荐
        Recommender recommender = new Recommender(userRatings);
        List<Integer> recommendations = recommender.recommend("1", 3);
        System.out.println(recommendations); // 输出 [5, 2]
    }
}