Java--夏普利值计算

package com.example.util;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.springframework.util.CollectionUtils;

import com.alibaba.fastjson.JSON;
import com.example.domain.dto.ShapleyParam;
import com.example.domain.dto.ShapleyResult;

/**
 * 夏普利值计算
* <p>Title: ShapleyUtils</p>  
* <p>Description: </p>  
* @author huanghz  
* @date 2021年7月20日
 */
public class ShapleyUtils {

	private static final String CONCAT = "@-@";

    /**
     * 根据策略组合值计算单个策略结果
     * @param list  策略组合值,长度不允许超过8,否则2的n次方内存会爆炸
     * @param scale 四舍五入保留小数位
     * @return 单个策略值
     */
    public static List<ShapleyResult> calculate(List<ShapleyParam> list, int scale) {
        List<ShapleyResult> result = new ArrayList<>();
        if (CollectionUtils.isEmpty(list)){
            return result;
        }
        if (list.size() == 1){
            ShapleyParam ShapleyParam = list.get(0);
            ShapleyResult ShapleyResult = new ShapleyResult();
            ShapleyResult.setValue(ShapleyParam.getValue().setScale(scale, BigDecimal.ROUND_HALF_UP));
            ShapleyResult.setMember(ShapleyParam.getMembers().get(0));
            result.add(ShapleyResult);
            return result;
        }
        // 所有的参与者
        List<String> participants = new ArrayList<>();
        // key策略,value贡献值
        Map<String, BigDecimal> valueMap = new HashMap<>();
        list.forEach(ShapleyParam -> {
            List<String> members = ShapleyParam.getMembers();
            if (members.size() == 1){
                participants.add(members.get(0));
            }
            valueMap.put(getKey((ShapleyParam.getMembers())), ShapleyParam.getValue());
        });
        // 获取所有的参与者的全组合 [A@-@B@-@C, A@-@C@-@B, B@-@A@-@C, B@-@C@-@A, C@-@A@-@B, C@-@B@-@A]
        List<List<String>> fullPermutations = fullPermutation(participants);
        // participants = [A,B,C]
        Map<String, List<BigDecimal>> permutationValueMap = new HashMap<>();
        for (String participant : participants) {
            // 成员在各个组合中的贡献值
            List<BigDecimal> bigDecimals = new ArrayList<>();
            for (List<String> permutation : fullPermutations) {
                // 成员在当前组合中的贡献值
                BigDecimal permutationValue = strategyContributioValue(participant, permutation, valueMap);
                bigDecimals.add(permutationValue);
            }
            permutationValueMap.put(participant, bigDecimals);
        }

        // 全组合的个数,作为计算最终贡献的分母
        BigDecimal fullPermutationNumber = new BigDecimal(fullPermutations.size());

        System.out.println(JSON.toJSONString(permutationValueMap));
        permutationValueMap.forEach((participant, bigDecimals) -> {
            // 所有可能组合的贡献值之和
            BigDecimal countValue = BigDecimal.ZERO;
            for (BigDecimal value : bigDecimals) {
                countValue = countValue.add(value);
            }
            ShapleyResult ShapleyResult = new ShapleyResult();
            // 所有可能组合的贡献值之和/全组合的个数
            ShapleyResult.setValue(countValue.divide(fullPermutationNumber, scale, BigDecimal.ROUND_HALF_UP));
            ShapleyResult.setMember(participant);
            result.add(ShapleyResult);
        });
        return result;
    }
    
    /**
     * 获取成员在全组合中的贡献值(边际值)
     * @param participant 成员 如:A
     * @param permutation 组合 如:[A,B,C]  [B,A,C]  [B,C,A]
     * @return
     */
    private static BigDecimal strategyContributioValue(String participant, List<String> permutation, Map<String, BigDecimal> mapValue) {
        int indexOf = permutation.indexOf(participant);
        // 获取边际组合
        List<String> newList = new ArrayList<>(permutation.subList(0, indexOf + 1));
        if (newList.size() == 1){
            return mapValue.get(getKey(newList));
        }

        // 获取边际组合中除participant成员的组合
        List<String> subList = new ArrayList<>(permutation.subList(0, indexOf));

        // 获取边际组合的值
        BigDecimal bigDecimal = mapValue.get(getKey(newList));

        // 际组合中除participant成员的组合的值
        BigDecimal subBigDecimal = mapValue.get(getKey(subList));

        return bigDecimal.subtract(subBigDecimal);
    }

    /**
     * ("A","B") -> A@-@B
     */
    private static String getKey(List<String> members) {
        // 排序
        members.sort(String::compareTo);
        return String.join(CONCAT, members);
    }

    /**
     * 排列组合(字符不重复排列)<br>
     * @param list 待排列组合字符集合(忽略重复字符)
     * @return 全组合排列的字符串集合
     */
    private static List<List<String>> fullPermutation(List<String> list) {
        List<List<String>> permutations = new ArrayList<>();
        Stream<String> stream = list.stream().distinct();
        for (int n = 1; n < list.size(); n++) {
            stream = stream.flatMap(i -> list.stream().filter(j -> !i.contains(j)).map(j -> i.concat(CONCAT).concat(j)));
        }
        List<String> collect = stream.collect(Collectors.toList());
        collect.forEach(item -> {
            List<String> strings = Arrays.asList(item.split(CONCAT));
            permutations.add(strings);
        });
        return permutations;
    }
    
    public static void main(String[] args) {
        List<ShapleyParam> list = JSON.parseArray(json2, ShapleyParam.class);

        List<ShapleyResult> calculate = calculate(list, 1);

        System.out.println(JSON.toJSONString(calculate));

    }
    
    private static final String json = "[\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\"\n" +
            "        ],\n" +
            "        \"value\": 100\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"B\"\n" +
            "        ],\n" +
            "        \"value\": 200\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 300\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\",\n" +
            "            \"B\"\n" +
            "        ],\n" +
            "        \"value\": 500\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"B\",\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 600\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\",\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 700\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\",\n" +
            "            \"B\",\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 1000\n" +
            "    }\n" +
            "]";
    
    private static final String json2 = "[\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\"\n" +
            "        ],\n" +
            "        \"value\": 10\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"B\"\n" +
            "        ],\n" +
            "        \"value\": 30\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 5\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\",\n" +
            "            \"B\"\n" +
            "        ],\n" +
            "        \"value\": 50\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"B\",\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 35\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\",\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 40\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\",\n" +
            "            \"B\",\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 100\n" +
            "    }\n" +
            "]";
    
    private static final String json3 = "[\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\"\n" +
            "        ],\n" +
            "        \"value\": 100\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"B\"\n" +
            "        ],\n" +
            "        \"value\": 125\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 50\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\",\n" +
            "            \"B\"\n" +
            "        ],\n" +
            "        \"value\": 270\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"B\",\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 350\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\",\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 375\n" +
            "    },\n" +
            "    {\n" +
            "        \"members\": [\n" +
            "            \"A\",\n" +
            "            \"B\",\n" +
            "            \"C\"\n" +
            "        ],\n" +
            "        \"value\": 500\n" +
            "    }\n" +
            "]";
}

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值