package com.example.util;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.springframework.util.CollectionUtils;
import com.alibaba.fastjson.JSON;
import com.example.domain.dto.ShapleyParam;
import com.example.domain.dto.ShapleyResult;
/**
 * Shapley value calculation.
 * <p>Title: ShapleyUtils</p>
 * <p>Description: computes each participant's Shapley value from the values of the
 * coalitions of participants.</p>
 * @author huanghz
 * @date 2021-07-20
 */
public class ShapleyUtils {

    /** Separator used to join the (sorted) member names of a coalition into a map key. */
    private static final String CONCAT = "@-@";

    /**
     * Computes the Shapley value of every individual participant from coalition values.
     *
     * <p>Participants are the distinct members of the entries whose {@code members} list has
     * exactly one element. The order of members inside an entry does not matter. Each
     * participant's value is the average, over every ordering of the participants, of its
     * marginal contribution v(prefix including it) - v(prefix before it). The empty
     * coalition is taken to be worth zero. The work grows with n! for n participants, so keep
     * n small (the original note suggested at most 8).
     *
     * <p>The {@code members} lists of the supplied params are not modified.
     *
     * @param list  coalition values; every coalition reachable as a prefix of some ordering of
     *              the participants must be present
     * @param scale number of decimal places kept in the results, rounded HALF_UP
     * @return one result per participant, in the order participants first appear in
     *         {@code list}; empty if {@code list} is null or empty
     * @throws IllegalArgumentException if the value of a required coalition is missing
     */
    public static List<ShapleyResult> calculate(List<ShapleyParam> list, int scale) {
        List<ShapleyResult> result = new ArrayList<>();
        if (CollectionUtils.isEmpty(list)) {
            return result;
        }
        if (list.size() == 1) {
            // A single entry is treated as a one-player game: its own value is the answer.
            ShapleyParam only = list.get(0);
            result.add(toResult(only.getMembers().get(0),
                    only.getValue().setScale(scale, RoundingMode.HALF_UP)));
            return result;
        }

        // Distinct single-member coalitions define the players; LinkedHashSet keeps input order
        // and prevents duplicate entries from corrupting the permutation count.
        Set<String> participantSet = new LinkedHashSet<>();
        // Coalition key (sorted members joined by CONCAT) -> coalition value.
        Map<String, BigDecimal> valueMap = new HashMap<>();
        for (ShapleyParam param : list) {
            List<String> members = param.getMembers();
            if (members.size() == 1) {
                participantSet.add(members.get(0));
            }
            // A later entry for the same coalition overwrites an earlier one.
            valueMap.put(getKey(members), param.getValue());
        }
        List<String> participants = new ArrayList<>(participantSet);
        List<List<String>> permutations = fullPermutation(participants);

        // Running sum of each participant's marginal contributions over all permutations.
        Map<String, BigDecimal> marginalSums = new LinkedHashMap<>();
        for (String participant : participants) {
            marginalSums.put(participant, BigDecimal.ZERO);
        }
        for (List<String> permutation : permutations) {
            // Walk the prefixes once: the marginal contribution of permutation[i] is
            // v(prefix up to i) - v(prefix up to i-1), with v(empty coalition) = 0.
            BigDecimal previous = BigDecimal.ZERO;
            for (int i = 0; i < permutation.size(); i++) {
                BigDecimal current = coalitionValue(permutation.subList(0, i + 1), valueMap);
                marginalSums.merge(permutation.get(i), current.subtract(previous), BigDecimal::add);
                previous = current;
            }
        }

        // Shapley value = sum of marginal contributions / number of permutations.
        BigDecimal permutationCount = BigDecimal.valueOf(permutations.size());
        marginalSums.forEach((participant, sum) -> result.add(
                toResult(participant, sum.divide(permutationCount, scale, RoundingMode.HALF_UP))));
        return result;
    }

    /**
     * Creates the result entry for one participant.
     */
    private static ShapleyResult toResult(String member, BigDecimal value) {
        ShapleyResult shapleyResult = new ShapleyResult();
        shapleyResult.setValue(value);
        shapleyResult.setMember(member);
        return shapleyResult;
    }

    /**
     * Looks up the value of the coalition formed by {@code members}.
     *
     * @throws IllegalArgumentException if no value was supplied for that coalition
     */
    private static BigDecimal coalitionValue(List<String> members, Map<String, BigDecimal> valueMap) {
        String key = getKey(members);
        BigDecimal value = valueMap.get(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing value for coalition: " + key);
        }
        return value;
    }

    /**
     * Builds the order-independent key of a coalition, e.g. ("B","A") -> "A@-@B".
     * The argument list is not modified.
     */
    private static String getKey(List<String> members) {
        return members.stream().sorted().collect(Collectors.joining(CONCAT));
    }

    /**
     * Generates every ordering of the given participants.
     * Duplicate names are ignored. Elements are compared as whole strings, never by substring.
     * An empty input yields a single empty permutation.
     *
     * @param list participants to permute
     * @return all n! orderings of the distinct participants
     */
    private static List<List<String>> fullPermutation(List<String> list) {
        List<String> distinct = list.stream().distinct().collect(Collectors.toList());
        List<List<String>> permutations = new ArrayList<>();
        permute(distinct, new ArrayList<>(), new boolean[distinct.size()], permutations);
        return permutations;
    }

    /**
     * Backtracking helper for {@link #fullPermutation}: extends {@code current} with every
     * unused item and records each complete ordering in {@code out}.
     */
    private static void permute(List<String> items, List<String> current, boolean[] used,
                                List<List<String>> out) {
        if (current.size() == items.size()) {
            out.add(new ArrayList<>(current));
            return;
        }
        for (int i = 0; i < items.size(); i++) {
            if (used[i]) {
                continue;
            }
            used[i] = true;
            current.add(items.get(i));
            permute(items, current, used, out);
            current.remove(current.size() - 1);
            used[i] = false;
        }
    }

    /**
     * Demo: prints the Shapley values for the {@code json2} sample, rounded to 1 decimal place.
     */
    public static void main(String[] args) {
        List<ShapleyParam> list = JSON.parseArray(json2, ShapleyParam.class);
        List<ShapleyResult> calculate = calculate(list, 1);
        System.out.println(JSON.toJSONString(calculate));
    }

    // Sample data set (not used by main).
    private static final String json = "[\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\"\n" +
            " ],\n" +
            " \"value\": 100\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"B\"\n" +
            " ],\n" +
            " \"value\": 200\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 300\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\",\n" +
            " \"B\"\n" +
            " ],\n" +
            " \"value\": 500\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"B\",\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 600\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\",\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 700\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\",\n" +
            " \"B\",\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 1000\n" +
            " }\n" +
            "]";

    // Sample data set used by main.
    private static final String json2 = "[\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\"\n" +
            " ],\n" +
            " \"value\": 10\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"B\"\n" +
            " ],\n" +
            " \"value\": 30\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 5\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\",\n" +
            " \"B\"\n" +
            " ],\n" +
            " \"value\": 50\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"B\",\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 35\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\",\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 40\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\",\n" +
            " \"B\",\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 100\n" +
            " }\n" +
            "]";

    // Sample data set (not used by main).
    private static final String json3 = "[\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\"\n" +
            " ],\n" +
            " \"value\": 100\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"B\"\n" +
            " ],\n" +
            " \"value\": 125\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 50\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\",\n" +
            " \"B\"\n" +
            " ],\n" +
            " \"value\": 270\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"B\",\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 350\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\",\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 375\n" +
            " },\n" +
            " {\n" +
            " \"members\": [\n" +
            " \"A\",\n" +
            " \"B\",\n" +
            " \"C\"\n" +
            " ],\n" +
            " \"value\": 500\n" +
            " }\n" +
            "]";
}
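// Java--夏普利值计算 (Java: Shapley value calculation)
// 最新推荐文章于 2023-06-12 12:58:43 发布 (web-page residue: "latest recommended article published 2023-06-12 12:58:43")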