mpc4j: PIR computation flow

Server-side PIR

Core computation method

computeResponse

Query parameters

  - query_params: {
     - ps_low_degree : 44
     - query_powers : [1, 3, 11, 18, 45, 225]
  }

encryptedQuery size: 24 (4 ciphertext groups × 6 source powers)
    /**
     * generate response.
     *
     * @param encryptedQuery client query.
     * @throws MpcAbortException the protocol failure aborts.
     */
    private void computeResponse(List<byte[]> encryptedQuery) throws MpcAbortException {
        // read binSize and labelPartitionCount from Redis
        this.binSize = Integer.parseInt(stringRedisUtil.getStringValue(BIN_SIZE));
        int labelPartitionCount = Integer.parseInt(stringRedisUtil.getStringValue(LABEL_PARTITION_COUNT));

        int partitionCount = CommonUtils.getUnitNum(binSize, params.getMaxPartitionSizePerBin());
        int[][] powerDegree;
        if (params.getPsLowDegree() > 0) {
            Set<Integer> innerPowersSet = new HashSet<>();
            Set<Integer> outerPowersSet = new HashSet<>();
            for (int i = 0; i < params.getQueryPowers().length; i++) {
                if (params.getQueryPowers()[i] <= params.getPsLowDegree()) {
                    innerPowersSet.add(params.getQueryPowers()[i]);
                } else {
                    outerPowersSet.add(params.getQueryPowers()[i] / (params.getPsLowDegree() + 1));
                }
            }
            System.out.println("innerPowersSet: " + innerPowersSet);
            System.out.println("outerPowersSet: " + outerPowersSet);

            PowerNode[] innerPowerNodes = PowerUtils.computePowers(innerPowersSet, params.getPsLowDegree());
            PowerNode[] outerPowerNodes = PowerUtils.computePowers(
                    outerPowersSet, params.getMaxPartitionSizePerBin() / (params.getPsLowDegree() + 1));
            System.out.println("innerPowerNodes: " + JSON.toJSONString(innerPowerNodes)+"--size: "+innerPowerNodes.length);
            System.out.println("outerPowerNodes: " + JSON.toJSONString(outerPowerNodes)+"--size: "+outerPowerNodes.length);

            powerDegree = new int[innerPowerNodes.length + outerPowerNodes.length][2];
            int[][] innerPowerNodesDegree = Arrays.stream(innerPowerNodes).map(PowerNode::toIntArray).toArray(int[][]::new);
            int[][] outerPowerNodesDegree = Arrays.stream(outerPowerNodes).map(PowerNode::toIntArray).toArray(int[][]::new);
            System.arraycopy(innerPowerNodesDegree, 0, powerDegree, 0, innerPowerNodesDegree.length);
            System.arraycopy(outerPowerNodesDegree, 0, powerDegree, innerPowerNodesDegree.length, outerPowerNodesDegree.length);

            System.out.println("innerPowerNodesDegree: ");
            for (int[] degree : innerPowerNodesDegree) {
                System.out.println(Arrays.toString(degree));
            }
            System.out.println("outerPowerNodesDegree: ");
            for (int[] degree : outerPowerNodesDegree) {
                System.out.println(Arrays.toString(degree));
            }
            System.out.println("powerDegree size: " + powerDegree.length);
        } else {
            Set<Integer> sourcePowersSet = Arrays.stream(params.getQueryPowers())
                    .boxed()
                    .collect(Collectors.toCollection(HashSet::new));
            PowerNode[] powerNodes = PowerUtils.computePowers(sourcePowersSet, params.getMaxPartitionSizePerBin());
            powerDegree = Arrays.stream(powerNodes).map(PowerNode::toIntArray).toArray(int[][]::new);
        }
//        int labelPartitionCount = CommonUtils.getUnitNum((labelByteLength + ivByteLength) * Byte.SIZE,
//                (PirUtils.getBitLength(params.getPlainModulus()) - 1) * params.getItemEncodedSlotSize());
//        System.out.println("labelPartitionCount size: " + labelPartitionCount);
        IntStream queryIntStream = parallel ?
                IntStream.range(0, params.getCiphertextNum()).parallel() : IntStream.range(0, params.getCiphertextNum());
        List<byte[]> queryPowers = queryIntStream
                .mapToObj(i -> Cmg21KwPirNativeUtils.computeEncryptedPowers(
                        params.getEncryptionParams(),
                        relinKeys,
                        encryptedQuery.subList(i * params.getQueryPowers().length, (i + 1) * params.getQueryPowers().length),
                        powerDegree,
                        params.getQueryPowers(),
                        params.getPsLowDegree()))
                .flatMap(Collection::stream)
                .collect(Collectors.toCollection(ArrayList::new));

        System.out.println("queryPowers size: " + queryPowers.size() + "\nvalue: ");
        queryPowers.forEach(array -> System.out.println(Arrays.toString(array)));
        System.out.println("CiphertextNum size: " + params.getCiphertextNum());

        if (params.getPsLowDegree() > 0) {
            keywordResponsePayload = IntStream.range(0, params.getCiphertextNum())
                    .mapToObj(i ->
                            (parallel ? IntStream.range(0, partitionCount).parallel() : IntStream.range(0, partitionCount))
                                    .mapToObj(j ->
                                            Cmg21KwPirNativeUtils.optComputeMatches(
                                                    params.getEncryptionParams(),
                                                    publicKey,
                                                    relinKeys,
                                                    getEncodeDataFromRedis("serverKeywordEncode", i * partitionCount + j),
                                                    //serverKeywordEncode.get(i * partitionCount + j),
                                                    queryPowers.subList(i * powerDegree.length, (i + 1) * powerDegree.length),
                                                    params.getPsLowDegree()))
                                    .toArray(byte[][]::new))
                    .flatMap(Arrays::stream)
                    .collect(Collectors.toCollection(ArrayList::new));

//            System.out.println("keywordResponsePayload size: " + keywordResponsePayload.size() + " " + keywordResponsePayload.get(0).length);
            labelResponsePayload = IntStream.range(0, params.getCiphertextNum())
                    .mapToObj(i ->
                            (parallel ? IntStream.range(0, partitionCount * labelPartitionCount).parallel() :
                                    IntStream.range(0, partitionCount * labelPartitionCount))
                                    .mapToObj(j ->
                                            Cmg21KwPirNativeUtils.optComputeMatches(
                                                    params.getEncryptionParams(),
                                                    publicKey,
                                                    relinKeys,
                                                    getEncodeDataFromRedis("serverLabelEncode", i * partitionCount * labelPartitionCount + j),
                                                    //serverLabelEncode.get(i * partitionCount * labelPartitionCount + j),
                                                    queryPowers.subList(i * powerDegree.length, (i + 1) * powerDegree.length),
                                                    params.getPsLowDegree()))
                                    .toArray(byte[][]::new))
                    .flatMap(Arrays::stream)
                    .collect(Collectors.toCollection(ArrayList::new));
//            System.out.println("labelResponsePayload size: " + labelResponsePayload.size() + "\nvalue: ");
//            labelResponsePayload.forEach(array-> System.out.println(Arrays.toString(array)));
        } else if (params.getPsLowDegree() == 0) {
            keywordResponsePayload = IntStream.range(0, params.getCiphertextNum())
                    .mapToObj(i ->
                            (parallel ? IntStream.range(0, partitionCount).parallel() : IntStream.range(0, partitionCount))
                                    .mapToObj(j ->
                                            Cmg21KwPirNativeUtils.naiveComputeMatches(
                                                    params.getEncryptionParams(),
                                                    publicKey,
                                                    serverKeywordEncode.get(i * partitionCount + j),
                                                    queryPowers.subList(i * powerDegree.length, (i + 1) * powerDegree.length)))
                                    .toArray(byte[][]::new))
                    .flatMap(Arrays::stream)
                    .collect(Collectors.toCollection(ArrayList::new));
            labelResponsePayload = IntStream.range(0, params.getCiphertextNum())
                    .mapToObj(i ->
                            (parallel ? IntStream.range(0, partitionCount * labelPartitionCount).parallel() :
                                    IntStream.range(0, partitionCount * labelPartitionCount))
                                    .mapToObj(j ->
                                            Cmg21KwPirNativeUtils.naiveComputeMatches(
                                                    params.getEncryptionParams(),
                                                    publicKey,
                                                    serverLabelEncode.get(i * partitionCount * labelPartitionCount + j),
                                                    queryPowers.subList(i * powerDegree.length, (i + 1) * powerDegree.length)))
                                    .toArray(byte[][]::new))
                    .flatMap(Arrays::stream)
                    .collect(Collectors.toCollection(ArrayList::new));
        } else {
            throw new MpcAbortException("ps_low_degree is incorrect.");
        }
    }

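First, compute the powers. The source powers are split into an inner set (powers ≤ ps_low_degree = 44) and an outer set (each higher power divided by ps_low_degree + 1 = 45):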
innerPowersSet: [1, 18, 3, 11]
outerPowersSet: [1, 5]
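A minimal sketch of that split, restating the loop at the top of computeResponse. The class name PsPowerSplit is hypothetical; a TreeSet is used so the printed order is sorted, unlike the HashSet order in the debug output.

```java
import java.util.Set;
import java.util.TreeSet;

/** Hypothetical helper restating the inner/outer split done in computeResponse. */
final class PsPowerSplit {
    private PsPowerSplit() {}

    /** Source powers evaluated directly as low powers x^p (p <= psLowDegree). */
    static Set<Integer> inner(int[] queryPowers, int psLowDegree) {
        Set<Integer> result = new TreeSet<>();
        for (int p : queryPowers) {
            if (p <= psLowDegree) {
                result.add(p);
            }
        }
        return result;
    }

    /** Source powers above psLowDegree, expressed as exponents of y = x^(psLowDegree + 1). */
    static Set<Integer> outer(int[] queryPowers, int psLowDegree) {
        int psHighDegree = psLowDegree + 1;
        Set<Integer> result = new TreeSet<>();
        for (int p : queryPowers) {
            if (p > psLowDegree) {
                // integer division, as in the original code: 45 -> 1, 225 -> 5
                result.add(p / psHighDegree);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[] queryPowers = {1, 3, 11, 18, 45, 225};
        System.out.println("inner: " + inner(queryPowers, 44)); // inner: [1, 3, 11, 18]
        System.out.println("outer: " + outer(queryPowers, 44)); // outer: [1, 5]
    }
}
```

The division is exact here only because 45 and 225 are multiples of 45; a high source power that is not a multiple of ps_low_degree + 1 would be silently truncated.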

innerPowerNodes: --size: 44

[{"depth":0,"leftPower":0,"power":1,"rightPower":0},{"depth":1,"leftPower":1,"power":2,"rightPower":1},{"depth":0,"leftPower":0,"power":3,"rightPower":0},{"depth":1,"leftPower":1,"power":4,"rightPower":3},{"depth":2,"leftPower":1,"power":5,"rightPower":4},{"depth":1,"leftPower":3,"power":6,"rightPower":3},{"depth":2,"leftPower":1,"power":7,"rightPower":6},{"depth":2,"leftPower":2,"power":8,"rightPower":6},{"depth":2,"leftPower":3,"power":9,"rightPower":6},{"depth":2,"leftPower":4,"power":10,"rightPower":6},{"depth":0,"leftPower":0,"power":11,"rightPower":0},{"depth":1,"leftPower":1,"power":12,"rightPower":11},{"depth":2,"leftPower":1,"power":13,"rightPower":12},{"depth":1,"leftPower":3,"power":14,"rightPower":11},{"depth":2,"leftPower":1,"power":15,"rightPower":14},{"depth":2,"leftPower":2,"power":16,"rightPower":14},{"depth":2,"leftPower":3,"power":17,"rightPower":14},{"depth":0,"leftPower":0,"power":18,"rightPower":0},{"depth":1,"leftPower":1,"power":19,"rightPower":18},{"depth":2,"leftPower":1,"power":20,"rightPower":19},{"depth":1,"leftPower":3,"power":21,"rightPower":18},{"depth":1,"leftPower":11,"power":22,"rightPower":11},{"depth":2,"leftPower":1,"power":23,"rightPower":22},{"depth":2,"leftPower":2,"power":24,"rightPower":22},{"depth":2,"leftPower":3,"power":25,"rightPower":22},{"depth":2,"leftPower":4,"power":26,"rightPower":22},{"depth":2,"leftPower":6,"power":27,"rightPower":21},{"depth":2,"leftPower":6,"power":28,"rightPower":22},{"depth":1,"leftPower":11,"power":29,"rightPower":18},{"depth":2,"leftPower":1,"power":30,"rightPower":29},{"depth":2,"leftPower":2,"power":31,"rightPower":29},{"depth":2,"leftPower":3,"power":32,"rightPower":29},{"depth":2,"leftPower":4,"power":33,"rightPower":29},{"depth":2,"leftPower":12,"power":34,"rightPower":22},{"depth":2,"leftPower":6,"power":35,"rightPower":29},{"depth":1,"leftPower":18,"power":36,"rightPower":18},{"depth":2,"leftPower":1,"power":37,"rightPower":36},{"depth":2,"leftPower":2,"power":38,"rightPower":36},{"depth":2,"leftPower":3,"power":39,"rightPower":36},{"depth":2,"leftPower":4,"power":40,"rightPower":36},{"depth":2,"leftPower":12,"power":41,"rightPower":29},{"depth":2,"leftPower":6,"power":42,"rightPower":36},{"depth":2,"leftPower":14,"power":43,"rightPower":29},{"depth":2,"leftPower":22,"power":44,"rightPower":22}]

PowerNode[] outerPowerNodes = PowerUtils.computePowers(
        outerPowersSet, params.getMaxPartitionSizePerBin() / (params.getPsLowDegree() + 1));

Here maxPartitionSizePerBin = 1304, so the bound for the outer powers is 1304 / 45 = 28 (integer division).

outerPowerNodes: --size: 28
[{"depth":0,"leftPower":0,"power":1,"rightPower":0},{"depth":1,"leftPower":1,"power":2,"rightPower":1},{"depth":2,"leftPower":2,"power":3,"rightPower":1},{"depth":2,"leftPower":2,"power":4,"rightPower":2},{"depth":0,"leftPower":0,"power":5,"rightPower":0},{"depth":1,"leftPower":1,"power":6,"rightPower":5},{"depth":2,"leftPower":1,"power":7,"rightPower":6},{"depth":2,"leftPower":2,"power":8,"rightPower":6},{"depth":3,"leftPower":1,"power":9,"rightPower":8},{"depth":1,"leftPower":5,"power":10,"rightPower":5},{"depth":2,"leftPower":1,"power":11,"rightPower":10},{"depth":2,"leftPower":2,"power":12,"rightPower":10},{"depth":3,"leftPower":1,"power":13,"rightPower":12},{"depth":3,"leftPower":2,"power":14,"rightPower":12},{"depth":2,"leftPower":5,"power":15,"rightPower":10},{"depth":2,"leftPower":6,"power":16,"rightPower":10},{"depth":3,"leftPower":1,"power":17,"rightPower":16},{"depth":3,"leftPower":2,"power":18,"rightPower":16},{"depth":3,"leftPower":3,"power":19,"rightPower":16},{"depth":2,"leftPower":10,"power":20,"rightPower":10},{"depth":3,"leftPower":1,"power":21,"rightPower":20},{"depth":3,"leftPower":2,"power":22,"rightPower":20},{"depth":3,"leftPower":3,"power":23,"rightPower":20},{"depth":3,"leftPower":4,"power":24,"rightPower":20},{"depth":3,"leftPower":5,"power":25,"rightPower":20},{"depth":3,"leftPower":6,"power":26,"rightPower":20},{"depth":3,"leftPower":7,"power":27,"rightPower":20},{"depth":3,"leftPower":8,"power":28,"rightPower":20}]

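innerPowerNodesDegree:

As the output shows, each row of innerPowerNodesDegree is the [leftPower, rightPower] pair of the corresponding innerPowerNode (the rows are listed under parent_powers below).

outerPowerNodesDegree:

Built the same way from the [leftPower, rightPower] pairs of outerPowerNodes.

powerDegree size: 72 (44 + 28): innerPowerNodesDegree and outerPowerNodesDegree concatenated, inner rows first.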
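PowerUtils.computePowers itself is not shown in these notes. As a rough model of what its output encodes, the sketch below assigns every power 1..maxPower a [leftPower, rightPower] pair that sums to it, choosing the pair with the smallest multiplicative depth; source powers get [0, 0] and depth 0. This dynamic-programming approach is an assumption for illustration, not mpc4j's algorithm. For the two sets above its depths agree with the printed output (max depth 2 inner, 3 outer), while individual parent pairs may differ.

```java
import java.util.Arrays;
import java.util.Set;

/** Hypothetical model of a power DAG: each row is {leftPower, rightPower, depth}. */
final class PowerDagSketch {
    private PowerDagSketch() {}

    /** Rows indexed by power - 1; source powers get {0, 0, 0} like the [0, 0] rows of powerDegree. */
    static int[][] build(Set<Integer> sources, int maxPower) {
        int[] depth = new int[maxPower + 1];
        Arrays.fill(depth, Integer.MAX_VALUE);
        int[][] rows = new int[maxPower][];
        for (int p = 1; p <= maxPower; p++) {
            if (sources.contains(p)) {
                depth[p] = 0;
                rows[p - 1] = new int[] {0, 0, 0};
                continue;
            }
            int bestLeft = -1;
            for (int left = 1; left <= p / 2; left++) {
                int right = p - left; // both parents are < p, so already processed
                if (depth[left] == Integer.MAX_VALUE || depth[right] == Integer.MAX_VALUE) {
                    continue;
                }
                int d = Math.max(depth[left], depth[right]) + 1;
                if (d < depth[p]) {
                    depth[p] = d;
                    bestLeft = left;
                }
            }
            if (bestLeft < 0) {
                throw new IllegalArgumentException("power " + p + " is not reachable from " + sources);
            }
            rows[p - 1] = new int[] {bestLeft, p - bestLeft, depth[p]};
        }
        return rows;
    }

    /** Largest depth, i.e. the longest chain of sequential ciphertext multiplications. */
    static int maxDepth(int[][] rows) {
        return Arrays.stream(rows).mapToInt(r -> r[2]).max().orElse(0);
    }

    public static void main(String[] args) {
        int[][] inner = build(Set.of(1, 3, 11, 18), 44);
        int[][] outer = build(Set.of(1, 5), 28);
        System.out.println("inner max depth = " + maxDepth(inner));             // 2
        System.out.println("outer max depth = " + maxDepth(outer));             // 3
        System.out.println("powerDegree rows = " + (inner.length + outer.length)); // 72
    }
}
```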

queryPowers size: 288 (4 ciphertext groups × 72 powers each; the flattened result of the map)
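The grouping behind these sizes can be sketched as plain index arithmetic: group i of encryptedQuery is the 6 entries subList(i * 6, (i + 1) * 6), computeEncryptedPowers expands it to 72 power ciphertexts, and flatMap concatenates the groups, so later code selects group i with subList(i * 72, (i + 1) * 72). The helper name QueryLayout is hypothetical.

```java
/** Hypothetical helper restating the subList arithmetic in computeResponse. */
final class QueryLayout {
    private QueryLayout() {}

    /** Half-open range [from, to) holding group i when every group has groupSize entries. */
    static int[] range(int i, int groupSize) {
        return new int[] {i * groupSize, (i + 1) * groupSize};
    }

    public static void main(String[] args) {
        int ciphertextNum = 4;
        int sourcePowerCount = 6;   // params.getQueryPowers().length
        int powerDegreeLength = 72; // powerDegree.length
        for (int i = 0; i < ciphertextNum; i++) {
            int[] q = range(i, sourcePowerCount);
            int[] p = range(i, powerDegreeLength);
            System.out.printf("group %d: encryptedQuery[%d, %d) -> queryPowers[%d, %d)%n",
                    i, q[0], q[1], p[0], p[1]);
        }
        System.out.println("encryptedQuery size = " + ciphertextNum * sourcePowerCount);  // 24
        System.out.println("queryPowers size = " + ciphertextNum * powerDegreeLength);    // 288
    }
}
```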

   List<byte[]> queryPowers = queryIntStream
                .mapToObj(i -> Cmg21KwPirNativeUtils.computeEncryptedPowers(
                        params.getEncryptionParams(),
                        relinKeys,
                        encryptedQuery.subList(i * params.getQueryPowers().length, (i + 1) * params.getQueryPowers().length),
                        powerDegree,
                        params.getQueryPowers(),
                        params.getPsLowDegree()))
                .flatMap(Collection::stream)
                .collect(Collectors.toCollection(ArrayList::new));

computeEncryptedPowers is a JNI native method implemented in

edu_alibaba_mpc4j_s2pc_pir_keyword_cmg21_Cmg21KwPirNativeUtils.cpp:

[[maybe_unused]] JNIEXPORT
jobject JNICALL Java_edu_alibaba_mpc4j_s2pc_pir_keyword_cmg21_Cmg21KwPirNativeUtils_computeEncryptedPowers(
    JNIEnv *env, jclass, jbyteArray parms_bytes, jbyteArray relin_keys_bytes, jobject query_list,
    jobjectArray jparent_powers, jintArray jsource_power_index, jint ps_low_power) {
    EncryptionParameters parms = deserialize_encryption_parms(env, parms_bytes);
    SEALContext context = SEALContext(parms);
    RelinKeys relin_keys = deserialize_relin_keys(env, relin_keys_bytes, context);
    Evaluator evaluator(context);
    vector<Ciphertext> query = deserialize_ciphertexts(env, query_list, context);
    // compute all the powers of the receiver's input.
    jint* index_ptr = env->GetIntArrayElements(jsource_power_index, nullptr);
    jsize source_power_count = env->GetArrayLength(jsource_power_index);
    vector<uint32_t> source_power_index;
    source_power_index.reserve(source_power_count);
    for (jsize i = 0; i < source_power_count; i++) {
        source_power_index.push_back(index_ptr[i]);
    }
    // release the Java array (JNI_ABORT: it was only read, nothing to copy back)
    env->ReleaseIntArrayElements(jsource_power_index, index_ptr, JNI_ABORT);
    uint32_t target_power_size = env->GetArrayLength(jparent_powers);
    vector<vector<uint32_t>> parent_powers(target_power_size);
    for (uint32_t i = 0; i < target_power_size; i++) {
        parent_powers[i].reserve(2);
        auto rows = (jintArray) env->GetObjectArrayElement(jparent_powers, (jint) i);
        jint* cols = env->GetIntArrayElements(rows, nullptr);
        parent_powers[i].push_back(cols[0]);
        parent_powers[i].push_back(cols[1]);
        // release each row; DeleteLocalRef keeps the number of live local references bounded in the loop
        env->ReleaseIntArrayElements(rows, cols, JNI_ABORT);
        env->DeleteLocalRef(rows);
    }
    vector<Ciphertext> encrypted_powers = compute_encrypted_powers(parms, query, parent_powers, source_power_index, ps_low_power, relin_keys);
    return serialize_ciphertexts(env, encrypted_powers);
}

The core routine compute_encrypted_powers is in apsi.cpp:


vector<Ciphertext> compute_encrypted_powers(const EncryptionParameters& parms, vector<Ciphertext> query,
                                            vector<vector<uint32_t>> parent_powers, vector<uint32_t> source_power_index,
                                            uint32_t ps_low_power, const RelinKeys& relin_keys) {
    SEALContext context(parms);
    Evaluator evaluator(context);
    uint32_t target_power_size = parent_powers.size();
    auto high_powers_parms_id = get_parms_id_for_chain_idx(context, 1);
    auto low_powers_parms_id = get_parms_id_for_chain_idx(context, 2);
    vector<Ciphertext> encrypted_powers;
    encrypted_powers.resize(target_power_size);
    if (ps_low_power > 0) {
        // Paterson-Stockmeyer algorithm
        uint32_t ps_high_degree = ps_low_power + 1;
        for (uint32_t i = 0; i < query.size(); i++) {
            if (source_power_index[i] <= ps_low_power) {
                encrypted_powers[source_power_index[i] - 1] = query[i];
            } else {
                encrypted_powers[ps_low_power + (source_power_index[i] / ps_high_degree) - 1] = query[i];
            }
        }
        print_encrypted_powers(encrypted_powers, "After initial assignment");
        for (uint32_t i = 0; i < ps_low_power; i++) {
            if (parent_powers[i][1] != 0) {
                if (parent_powers[i][0] - 1 == parent_powers[i][1] - 1) {
                    evaluator.square(encrypted_powers[parent_powers[i][0] - 1], encrypted_powers[i]);
                } else {
                    evaluator.multiply(encrypted_powers[parent_powers[i][0] - 1],
                                       encrypted_powers[parent_powers[i][1] - 1], encrypted_powers[i]);
                }
                evaluator.relinearize_inplace(encrypted_powers[i], relin_keys);
            }
        }
        for (uint32_t i = ps_low_power; i < target_power_size; i++) {
            if (parent_powers[i][1] != 0) {
                if (parent_powers[i][0] - 1 == parent_powers[i][1] - 1) {
                    evaluator.square(encrypted_powers[parent_powers[i][0] - 1 + ps_low_power], encrypted_powers[i]);
                } else {
                    evaluator.multiply(encrypted_powers[parent_powers[i][0] - 1 + ps_low_power],
                                       encrypted_powers[parent_powers[i][1] - 1 + ps_low_power], encrypted_powers[i]);
                }
                evaluator.relinearize_inplace(encrypted_powers[i], relin_keys);
            }
        }
        for (uint32_t i = 0; i < ps_low_power; i++) {
            // Low powers must be at a higher level than high powers
            evaluator.mod_switch_to_inplace(encrypted_powers[i], low_powers_parms_id);
            // Low powers must be in NTT form
            evaluator.transform_to_ntt_inplace(encrypted_powers[i]);
        }
        for (uint32_t i = ps_low_power; i < target_power_size; i++) {
            // High powers are only modulus switched
            evaluator.mod_switch_to_inplace(encrypted_powers[i], high_powers_parms_id);
        }
    } else {
        // naive algorithm
        for (uint32_t i = 0; i < query.size(); i++) {
            encrypted_powers[source_power_index[i] - 1] = query[i];
        }
        for (uint32_t i = 0; i < target_power_size; i++) {
            if (parent_powers[i][1] != 0) {
                if (parent_powers[i][0] - 1 == parent_powers[i][1] - 1) {
                    evaluator.square(encrypted_powers[parent_powers[i][0] - 1], encrypted_powers[i]);
                } else {
                    evaluator.multiply(encrypted_powers[parent_powers[i][0] - 1],
                                       encrypted_powers[parent_powers[i][1] - 1], encrypted_powers[i]);
                }
                evaluator.relinearize_inplace(encrypted_powers[i], relin_keys);
            }
        }
        for (auto &encrypted_power: encrypted_powers) {
            // Only one ciphertext-plaintext multiplication is needed after this
            evaluator.mod_switch_to_inplace(encrypted_power, high_powers_parms_id);
            // All powers must be in NTT form
            evaluator.transform_to_ntt_inplace(encrypted_power);
        }
    }
    return encrypted_powers;
}
ps_low_power: 44
high_powers_parms_id: 9726251831826291138 6271946707726836937 2621402725302960276 275072808666134953 
low_powers_parms_id: 15386047074277402743 18044400843720409247 6471315762005443664 14829797434412469570 
parent_powers: (identical to powerDegree)
0 0 
1 1 
0 0 
1 3 
1 4 
3 3 
1 6 
2 6 
3 6 
4 6 
0 0 
1 11 
1 12 
3 11 
1 14 
2 14 
3 14 
0 0 
1 18 
1 19 
3 18 
11 11 
1 22 
2 22 
3 22 
4 22 
6 21 
6 22 
11 18 
1 29 
2 29 
3 29 
4 29 
12 22 
6 29 
18 18 
1 36 
2 36 
3 36 
4 36 
12 29 
6 36 
14 29 
22 22 
0 0 
1 1 
2 1 
2 2 
0 0 
1 5 
1 6 
2 6 
1 8 
5 5 
1 10 
2 10 
1 12 
2 12 
5 10 
6 10 
1 16 
2 16 
3 16 
10 10 
1 20 
2 20 
3 20 
4 20 
5 20 
6 20 
7 20 
8 20 
source_power_index: 1 3 11 18 45 225 
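Step 1: initialization

Each query group contains 6 ciphertexts, one per source power. Each ciphertext is copied into encrypted_powers at the index given by its source_power_index p: index p - 1 for a low power (p ≤ 44), and index ps_low_power + p / 45 - 1 for a high power.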

query[0] (power 1)   -> encrypted_powers[0]    (1 - 1 = 0)
query[1] (power 3)   -> encrypted_powers[2]    (3 - 1 = 2)
query[2] (power 11)  -> encrypted_powers[10]   (11 - 1 = 10)
query[3] (power 18)  -> encrypted_powers[17]   (18 - 1 = 17)
query[4] (power 45)  -> encrypted_powers[44]   (44 + 45/45 - 1 = 44)
query[5] (power 225) -> encrypted_powers[48]   (44 + 225/45 - 1 = 48)
Every query[i] and every filled encrypted_powers entry printed here is [size=2, coeff_modulus_size=3].
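The index rule of the initial assignment can be restated in a few lines of Java. This is only a sketch of the formula in the Paterson-Stockmeyer branch of compute_encrypted_powers; the class name SourceSlot is hypothetical.

```java
/** Hypothetical restatement of the initial assignment in compute_encrypted_powers. */
final class SourceSlot {
    private SourceSlot() {}

    static int slotOf(int sourcePower, int psLowPower) {
        int psHighDegree = psLowPower + 1;
        if (sourcePower <= psLowPower) {
            return sourcePower - 1;                          // low power x^p lives at index p - 1
        }
        return psLowPower + sourcePower / psHighDegree - 1;  // y^k with y = x^45 lives at 44 + k - 1
    }

    public static void main(String[] args) {
        for (int p : new int[] {1, 3, 11, 18, 45, 225}) {
            System.out.println("power " + p + " -> encrypted_powers[" + slotOf(p, 44) + "]");
        }
    }
}
```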
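Step 2: compute the remaining powers

For example, parent_powers[0] = [0, 0].
If parent_powers[i][1] is 0, entry i is skipped; these are exactly the 6 powers filled during initialization.

If parent_powers[i][0] equals parent_powers[i][1]:

encrypted_powers[i] = encrypted_powers[parent_powers[i][0] - 1]^2
For example, parent_powers[5] = [3, 3]:
encrypted_powers[5] = encrypted_powers[2]^2   (x^6 = (x^3)^2)
parent_powers[2] = [0, 0], so encrypted_powers[2] was filled during initialization.

If parent_powers[i][0] and parent_powers[i][1] differ:

encrypted_powers[i] = encrypted_powers[parent_powers[i][0] - 1] * encrypted_powers[parent_powers[i][1] - 1]
For example, parent_powers[25] = [4, 22]:
encrypted_powers[25] = encrypted_powers[3] * encrypted_powers[21]   (x^26 = x^4 * x^22)

parent_powers[3] = [1, 3]:
encrypted_powers[3] = encrypted_powers[0] * encrypted_powers[2]   (both operands filled during initialization)

parent_powers[21] = [11, 11]:
encrypted_powers[21] = encrypted_powers[10]^2   (encrypted_powers[10] was filled during initialization)

Both parents of a power are smaller than the power itself, so they sit at lower indices and are always computed before they are needed; every product is relinearized immediately. For the high part (i ≥ ps_low_power) the same rules apply, but the parent indices are offset by ps_low_power: parent_powers[45] = [1, 1] gives encrypted_powers[45] = encrypted_powers[44]^2, i.e. (x^45)^2 = x^90.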
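To see that steps 1 and 2 really produce x^1..x^44 followed by y^1..y^28 (y = x^45), here is a plaintext model of the Paterson-Stockmeyer branch: each ciphertext is replaced by an integer mod t, and square/multiply become modular products. Relinearization, modulus switching and NTT are not modelled. The tiny parameters (psLowPower = 3, t = 65537, x = 7) and the hand-written parent table are illustrative assumptions.

```java
import java.util.Arrays;

/** Plaintext model of the Paterson-Stockmeyer branch of compute_encrypted_powers. */
final class PowersPlainModel {
    private PowersPlainModel() {}

    static long[] compute(long[] query, int[][] parentPowers, int[] sourcePowerIndex,
                          int psLowPower, long t) {
        int psHighDegree = psLowPower + 1;
        long[] powers = new long[parentPowers.length];
        // Step 1: copy each received source power into its slot
        for (int i = 0; i < query.length; i++) {
            int p = sourcePowerIndex[i];
            int slot = p <= psLowPower ? p - 1 : psLowPower + p / psHighDegree - 1;
            powers[slot] = query[i];
        }
        // Step 2: derive every other slot from its two parents (parents have lower indices)
        for (int i = 0; i < parentPowers.length; i++) {
            int left = parentPowers[i][0];
            int right = parentPowers[i][1];
            if (right == 0) {
                continue; // source power, filled in step 1
            }
            // parents of a high power are exponents of y, stored after the psLowPower low powers
            int offset = i < psLowPower ? 0 : psLowPower;
            powers[i] = powers[left - 1 + offset] * powers[right - 1 + offset] % t;
        }
        return powers;
    }

    /** Square-and-multiply modular exponentiation (t < 2^31 keeps products inside a long). */
    static long modPow(long base, long exp, long t) {
        long result = 1 % t;
        long b = base % t;
        for (long e = exp; e > 0; e >>= 1) {
            if ((e & 1) == 1) {
                result = result * b % t;
            }
            b = b * b % t;
        }
        return result;
    }

    public static void main(String[] args) {
        long t = 65537;
        long x = 7;
        int psLowPower = 3;                                         // toy value; the notes use 44
        int[] sources = {1, 3, 4};                                  // client sends x, x^3 and y = x^4
        int[][] parents = {{0, 0}, {1, 1}, {0, 0}, {0, 0}, {1, 1}}; // x, x^2, x^3 | y, y^2
        long[] query = new long[sources.length];
        for (int i = 0; i < sources.length; i++) {
            query[i] = modPow(x, sources[i], t);
        }
        System.out.println(Arrays.toString(compute(query, parents, sources, psLowPower, t)));
        // [7, 49, 343, 2401, 63082]  i.e. x, x^2, x^3, x^4, x^8 mod 65537
    }
}
```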

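Step 3: modulus switching and NTT

The low powers (indices 0 to ps_low_power - 1) are modulus-switched to low_powers_parms_id (chain index 2) and then transformed to NTT form. As the code comment says, the low powers must stay at a higher level than the high powers: in polynomial_evaluation they are first multiplied by plaintext coefficients with multiply_plain (pointwise, and SEAL requires the ciphertext and plaintext to be in the same NTT form), and only afterwards are the sums switched down to chain index 1 and multiplied by a high power.

The high powers (indices ps_low_power and up) are only modulus-switched, to high_powers_parms_id (chain index 1). They stay in coefficient form because they take part in ciphertext-ciphertext multiplications (multiply_inplace), which BFV in SEAL does not accept in NTT form.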

CiphertextNum size: 4

keywordResponsePayload

        if (params.getPsLowDegree() > 0) {
            keywordResponsePayload = IntStream.range(0, params.getCiphertextNum())
                    .mapToObj(i ->
                            (parallel ? IntStream.range(0, partitionCount).parallel() : IntStream.range(0, partitionCount))
                                    .mapToObj(j ->
                                            Cmg21KwPirNativeUtils.optComputeMatches(
                                                    params.getEncryptionParams(),
                                                    publicKey,
                                                    relinKeys,
                                                    getEncodeDataFromRedis("serverKeywordEncode", i * partitionCount + j),
                                                    //serverKeywordEncode.get(i * partitionCount + j),
                                                    queryPowers.subList(i * powerDegree.length, (i + 1) * powerDegree.length),
                                                    params.getPsLowDegree()))
                                    .toArray(byte[][]::new))
                    .flatMap(Arrays::stream)
                    .collect(Collectors.toCollection(ArrayList::new));

//            System.out.println("keywordResponsePayload size: " + keywordResponsePayload.size() + " " + keywordResponsePayload.get(0).length);
            labelResponsePayload = IntStream.range(0, params.getCiphertextNum())
                    .mapToObj(i ->
                            (parallel ? IntStream.range(0, partitionCount * labelPartitionCount).parallel() :
                                    IntStream.range(0, partitionCount * labelPartitionCount))
                                    .mapToObj(j ->
                                            Cmg21KwPirNativeUtils.optComputeMatches(
                                                    params.getEncryptionParams(),
                                                    publicKey,
                                                    relinKeys,
                                                    getEncodeDataFromRedis("serverLabelEncode", i * partitionCount * labelPartitionCount + j),
                                                    //serverLabelEncode.get(i * partitionCount * labelPartitionCount + j),
                                                    queryPowers.subList(i * powerDegree.length, (i + 1) * powerDegree.length),
                                                    params.getPsLowDegree()))
                                    .toArray(byte[][]::new))
                    .flatMap(Arrays::stream)
                    .collect(Collectors.toCollection(ArrayList::new));
//            System.out.println("labelResponsePayload size: " + labelResponsePayload.size() + "\nvalue: ");
//            labelResponsePayload.forEach(array-> System.out.println(Arrays.toString(array)));
        }
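The two nested streams flatten (ciphertext i, partition j) into one list: keyword responses use index i * partitionCount + j, label responses i * partitionCount * labelPartitionCount + j. The sketch below checks that the label indexing hits every encoded database plaintext exactly once. CommonUtils.getUnitNum is assumed here to be ceiling division, and the sizes in main (binSize 2000, labelPartitionCount 3) are made up.

```java
/** Hypothetical helper restating how the opt branch flattens (i, j) into one response list. */
final class ResponseLayout {
    private ResponseLayout() {}

    /** Assumed ceiling division, standing in for CommonUtils.getUnitNum. */
    static int unitNum(int total, int unitSize) {
        return (total + unitSize - 1) / unitSize;
    }

    static int keywordIndex(int i, int j, int partitionCount) {
        return i * partitionCount + j;
    }

    static int labelIndex(int i, int j, int partitionCount, int labelPartitionCount) {
        return i * partitionCount * labelPartitionCount + j;
    }

    /** True if labelIndex visits every position of the flat label list exactly once. */
    static boolean labelIndexIsBijective(int ciphertextNum, int partitionCount, int labelPartitionCount) {
        int perCiphertext = partitionCount * labelPartitionCount;
        boolean[] seen = new boolean[ciphertextNum * perCiphertext];
        for (int i = 0; i < ciphertextNum; i++) {
            for (int j = 0; j < perCiphertext; j++) {
                int k = labelIndex(i, j, partitionCount, labelPartitionCount);
                if (k < 0 || k >= seen.length || seen[k]) {
                    return false;
                }
                seen[k] = true;
            }
        }
        return true; // n distinct hits in an array of size n cover all of it
    }

    public static void main(String[] args) {
        int ciphertextNum = 4;
        int partitionCount = unitNum(2000, 1304); // made-up binSize 2000 -> 2 partitions
        int labelPartitionCount = 3;              // made-up value
        System.out.println("keyword responses: " + ciphertextNum * partitionCount);
        System.out.println("label responses: " + ciphertextNum * partitionCount * labelPartitionCount);
        System.out.println("bijective: " + labelIndexIsBijective(ciphertextNum, partitionCount, labelPartitionCount));
    }
}
```

A stride of only partitionCount for an inner range of length partitionCount * labelPartitionCount would reuse indices across ciphertexts whenever labelPartitionCount > 1, which is why the label stride includes labelPartitionCount.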

optComputeMatches is a JNI native method implemented in

edu_alibaba_mpc4j_s2pc_pir_keyword_cmg21_Cmg21KwPirNativeUtils.cpp:
[[maybe_unused]] JNIEXPORT
jbyteArray JNICALL Java_edu_alibaba_mpc4j_s2pc_pir_keyword_cmg21_Cmg21KwPirNativeUtils_optComputeMatches(
    JNIEnv *env, jclass, jbyteArray parms_bytes, jbyteArray pk_bytes, jbyteArray relin_keys_bytes,
    jobject database_coeffs, jobject query_list, jint ps_low_power) {
    EncryptionParameters parms = deserialize_encryption_parms(env, parms_bytes);
    SEALContext context(parms);
    PublicKey public_key = deserialize_public_key(env, pk_bytes, context);
    RelinKeys relin_keys = deserialize_relin_keys(env, relin_keys_bytes, context);
    Evaluator evaluator(context);
    // encrypted query powers
    vector<Ciphertext> query_powers = deserialize_ciphertexts(env, query_list, context);
    auto low_powers_parms_id = get_parms_id_for_chain_idx(context, 2);
    vector<Plaintext> plaintexts = deserialize_plaintexts(env, database_coeffs, context);
    Ciphertext f_evaluated = polynomial_evaluation(parms, query_powers, plaintexts, ps_low_power, relin_keys, public_key);
    try_clear_irrelevant_bits(parms, f_evaluated);
    return serialize_ciphertext(env, f_evaluated);
}
The core routine polynomial_evaluation is in apsi.cpp:
Ciphertext polynomial_evaluation(const EncryptionParameters& parms, vector<Ciphertext> encrypted_powers,
                                 vector<Plaintext> coeff_plaintexts, uint32_t ps_low_power, const RelinKeys& relin_keys,
                                 const PublicKey& public_key) {
    SEALContext context(parms);
    Encryptor encryptor(context, public_key);
    Evaluator evaluator(context);
    auto parms_id = get_parms_id_for_chain_idx(context, 1);
    uint32_t ps_high_degree = ps_low_power + 1;
    uint32_t degree = coeff_plaintexts.size() - 1;
    Ciphertext f_evaluated, cipher_temp, temp_in;
    f_evaluated.resize(context, parms_id, 3);
    f_evaluated.is_ntt_form() = false;
    uint32_t ps_high_degree_powers = degree / ps_high_degree;
    // Calculate polynomial for i = 1,...,ps_high_degree_powers-1
    for (uint32_t i = 1; i < ps_high_degree_powers; i++) {
        // Evaluate inner polynomial. The free term is left out and added later on.
        // The evaluation result is stored in temp_in.
        for (uint32_t j = 1; j < ps_high_degree; j++) {
            evaluator.multiply_plain(encrypted_powers[j - 1], coeff_plaintexts[j + i * ps_high_degree], cipher_temp);
            if (j == 1) {
                temp_in = cipher_temp;
            } else {
                evaluator.add_inplace(temp_in, cipher_temp);
            }
        }
        // Transform inner polynomial to coefficient form
        evaluator.transform_from_ntt_inplace(temp_in);
        evaluator.mod_switch_to_inplace(temp_in, parms_id);
        // The high powers are already in coefficient form
        evaluator.multiply_inplace(temp_in, encrypted_powers[i - 1 + ps_low_power]);
        evaluator.add_inplace(f_evaluated, temp_in);
    }
    // Calculate polynomial for i = ps_high_degree_powers.
    // Done separately because here the degree of the inner poly is degree % ps_high_degree.
    // Once again, the free term will only be added later on.
    if (degree % ps_high_degree > 0 && ps_high_degree_powers > 0) {
        for (uint32_t i = 1; i <= degree % ps_high_degree; i++) {
            evaluator.multiply_plain(encrypted_powers[i - 1],
                                     coeff_plaintexts[ps_high_degree * ps_high_degree_powers + i],
                                     cipher_temp);
            if (i == 1) {
                temp_in = cipher_temp;
            } else {
                evaluator.add_inplace(temp_in, cipher_temp);
            }
        }
        // Transform inner polynomial to coefficient form
        evaluator.transform_from_ntt_inplace(temp_in);
        evaluator.mod_switch_to_inplace(temp_in, parms_id);
        // The high powers are already in coefficient form
        evaluator.multiply_inplace(temp_in, encrypted_powers[ps_high_degree_powers - 1 + ps_low_power]);
        evaluator.add_inplace(f_evaluated, temp_in);
    }
    // Relinearize sum of ciphertext-ciphertext products
    if (!f_evaluated.is_transparent()) {
        evaluator.relinearize_inplace(f_evaluated, relin_keys);
    }
    // Calculate inner polynomial for i=0.
    // Done separately since there is no multiplication with a power of high-degree
    uint32_t length = ps_high_degree_powers == 0 ? degree : ps_low_power;
    for (uint32_t j = 1; j <= length; j++) {
        evaluator.multiply_plain(encrypted_powers[j-1], coeff_plaintexts[j], cipher_temp);
        evaluator.transform_from_ntt_inplace(cipher_temp);
        evaluator.mod_switch_to_inplace(cipher_temp, parms_id);
        evaluator.add_inplace(f_evaluated, cipher_temp);
    }
    // Add the constant coefficients of the inner polynomials multiplied by the respective powers of high-degree
    for (uint32_t i = 1; i < ps_high_degree_powers + 1; i++) {
        evaluator.multiply_plain(encrypted_powers[i - 1 + ps_low_power], coeff_plaintexts[ps_high_degree * i], cipher_temp);
        evaluator.mod_switch_to_inplace(cipher_temp, parms_id);
        evaluator.add_inplace(f_evaluated, cipher_temp);
    }
    // Add the constant coefficient
    if (degree > 0) {
        evaluator.add_plain_inplace(f_evaluated, coeff_plaintexts[0]);
    } else {
        encryptor.encrypt(coeff_plaintexts[0], f_evaluated);
    }
    while (f_evaluated.parms_id() != context.last_parms_id()) {
        evaluator.mod_switch_to_next_inplace(f_evaluated);
    }
    return f_evaluated;
}
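The algebra behind polynomial_evaluation can be checked with a plaintext model in which integers mod t replace ciphertexts and plaintexts. Each exponent k is split as k = a + b * 45 with 0 ≤ a ≤ 44, so f(x) = Σ_b (Σ_a c[a + 45b] x^a) · y^b with y = x^45. The example assumes, as the 28 outer powers above suggest, that each partition polynomial has degree 1304 = 28 × 45 + 44. The C++ code groups the work differently (it adds the constant of each inner polynomial separately, because low powers are in NTT form and high powers are not) but evaluates the same decomposition. The class name and the modulus 65537 are illustrative.

```java
import java.util.Random;

/** Plaintext model of the Paterson-Stockmeyer evaluation done by polynomial_evaluation. */
final class PsEvalPlainModel {
    private PsEvalPlainModel() {}

    /** f(x) mod t with k = a + b * (psLowPower + 1); coefficients and x must lie in [0, t). */
    static long evaluate(long[] coeffs, long x, int psLowPower, long t) {
        int degree = coeffs.length - 1;
        int psHighDegree = psLowPower + 1;
        int highPowers = degree / psHighDegree;  // ps_high_degree_powers in the C++ code
        long[] low = new long[psHighDegree];     // low[a] = x^a for a = 0..psLowPower
        low[0] = 1 % t;
        for (int a = 1; a <= psLowPower; a++) {
            low[a] = low[a - 1] * x % t;
        }
        long y = low[psLowPower] * x % t;        // y = x^(psLowPower + 1)
        long result = 0;
        long yPow = 1 % t;                       // y^b
        for (int b = 0; b <= highPowers; b++) {
            // inner polynomial of block b: sum_a c[a + b * psHighDegree] * x^a
            long inner = 0;
            for (int a = 0; a <= psLowPower && a + b * psHighDegree <= degree; a++) {
                inner = (inner + coeffs[a + b * psHighDegree] * low[a]) % t;
            }
            result = (result + inner * yPow) % t; // one "ciphertext-ciphertext" product per block
            yPow = yPow * y % t;
        }
        return result;
    }

    /** Reference value via Horner's rule. */
    static long horner(long[] coeffs, long x, long t) {
        long acc = 0;
        for (int k = coeffs.length - 1; k >= 0; k--) {
            acc = (acc * x + coeffs[k]) % t;
        }
        return acc;
    }

    public static void main(String[] args) {
        long t = 65537;
        int psLowPower = 44;
        int degree = 1304;
        long[] coeffs = new Random(1).longs(degree + 1, 0, t).toArray();
        System.out.println("high powers used: " + degree / (psLowPower + 1)); // 28
        System.out.println("last block degree: " + degree % (psLowPower + 1)); // 44
        long x = 12345;
        System.out.println("matches Horner: " + (evaluate(coeffs, x, psLowPower, t) == horner(coeffs, x, t)));
    }
}
```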

If apsi.cpp is modified, upload it to /fourQTest/mpc4j-native-fhe and rebuild the native library:

cd cmake-build-release    # delete the previously built libmpc4j-native-fhe.so first

cmake ..

make

For a fully clean rebuild with explicit compilers and C++17:

rm -rf /home/fourQTest/mpc4j-native-fhe/cmake-build-release/*

cd /home/fourQTest/mpc4j-native-fhe/cmake-build-release
cmake -DCMAKE_CXX_COMPILER=g++ -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_STANDARD=17 ..
make

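Client-side query

1 Blind the query data

2 Cuckoo hashing

For example, if the query item is 2, the hash function yields a BinIndex and the item is placed into that hash bin. There are still 6552 hash bins and each bin holds only one element; all other bins are filled with padding.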
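The bin count is consistent with the encoding used in step 3: with 8192 slots per ciphertext and 5 slots per item, one ciphertext holds 8192 / 5 = 1638 bins (2 slots left over), and 4 ciphertexts hold 4 × 1638 = 6552. Deriving the bin count this way is an assumption that merely matches these numbers. The padding helper, the bin index 42 and the dummy value -1 are hypothetical; a real implementation must pick padding values that cannot collide with real items.

```java
import java.util.Arrays;

/** Hypothetical restatement of the bin arithmetic and padding. */
final class BinLayout {
    private BinLayout() {}

    /** Bins that fit into one ciphertext when each item occupies slotsPerItem slots. */
    static int itemsPerCiphertext(int polyModulusDegree, int slotsPerItem) {
        return polyModulusDegree / slotsPerItem;
    }

    /** Bin array holding the single real query item at binIndex and a dummy value everywhere else. */
    static long[] fillBins(int binNum, int binIndex, long item, long dummy) {
        long[] bins = new long[binNum];
        Arrays.fill(bins, dummy);
        bins[binIndex] = item;
        return bins;
    }

    public static void main(String[] args) {
        int perCiphertext = itemsPerCiphertext(8192, 5); // 1638
        int binNum = 4 * perCiphertext;                  // 6552
        System.out.println(perCiphertext + " bins per ciphertext, " + binNum + " bins in total");
        long[] bins = fillBins(binNum, 42, 2L, -1L);
        System.out.println("real items: " + Arrays.stream(bins).filter(v -> v != -1L).count()); // 1
    }
}
```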

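3 Generate the query (generateQuery)

3.1 encodeQuery: encode the hash bins into items[4][8192]

The outer loop runs over the ciphertexts.
The inner loop runs over the items of each ciphertext. Since one item corresponds to exactly one hash bin, the element in the bin is taken out, encoded into an array of length 5 (5 slots per item), and written into the corresponding 5 columns of that ciphertext's row in items.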
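A sketch of this layout: bin b lands in row b / 1638 at columns (b % 1638) * 5 through (b % 1638) * 5 + 4. How mpc4j actually splits an item into its 5 slot values is not shown in these notes; the fixed-width bit chunks (bitsPerSlot) and the use of a long as the blinded item are assumptions for illustration only.

```java
import java.util.Arrays;

/** Hypothetical sketch of the encodeQuery layout; the chunked item encoding is assumed. */
final class QueryEncodingSketch {
    private QueryEncodingSketch() {}

    static long[][] encode(long[] bins, int ciphertextNum, int polyModulusDegree,
                           int slotsPerItem, int bitsPerSlot) {
        if (bitsPerSlot <= 0 || bitsPerSlot >= Long.SIZE) {
            throw new IllegalArgumentException("bitsPerSlot must be in [1, 63]");
        }
        int itemsPerCiphertext = polyModulusDegree / slotsPerItem;
        if (bins.length != ciphertextNum * itemsPerCiphertext) {
            throw new IllegalArgumentException("bin count does not match the layout");
        }
        long mask = (1L << bitsPerSlot) - 1;
        long[][] items = new long[ciphertextNum][polyModulusDegree]; // unused tail slots stay 0
        for (int c = 0; c < ciphertextNum; c++) {              // outer loop: ciphertexts
            for (int k = 0; k < itemsPerCiphertext; k++) {     // inner loop: items of this ciphertext
                long item = bins[c * itemsPerCiphertext + k];
                for (int s = 0; s < slotsPerItem; s++) {
                    int shift = s * bitsPerSlot;
                    // Java masks long shift distances to 6 bits, so shifts >= 64 are handled explicitly
                    items[c][k * slotsPerItem + s] = shift >= Long.SIZE ? 0 : (item >>> shift) & mask;
                }
            }
        }
        return items;
    }

    public static void main(String[] args) {
        long[] bins = new long[4 * 1638];
        bins[1639] = 0x123456789L; // made-up item in bin 1639 -> row 1, columns 5..9
        long[][] items = encode(bins, 4, 8192, 5, 12);
        System.out.println(Arrays.toString(Arrays.copyOfRange(items[1], 5, 10))); // [1929, 1110, 291, 0, 0]
    }
}
```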

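3.1.1 Computing the powers
For each ciphertext row (already split into slots), a 2-D array result[6][8192] is built. For each of the query powers [1, 3, 11, 18, 45, 225], every slot value of the row is raised to that power (for example, cubed for power 3).
encodeQuery returns a List<long[][]> of length 4, where each element is a [6][8192] array.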
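This sketch shows the power computation. It assumes every slot is raised to each query power modulo the plaintext modulus t. The value 40961 in the test is a hypothetical choice, not taken from the post. Presumably the client precomputes these few powers so the server can derive the remaining powers homomorphically with fewer multiplications.

```java
import java.math.BigInteger;

/** Raises every slot of one encoded ciphertext row to each query power mod t (illustration). */
final class ToyQueryPowers {
    static long[][] computePowers(long[] row, int[] queryPowers, long plainModulus) {
        BigInteger t = BigInteger.valueOf(plainModulus);
        long[][] result = new long[queryPowers.length][row.length];
        for (int p = 0; p < queryPowers.length; p++) {
            BigInteger e = BigInteger.valueOf(queryPowers[p]);
            for (int s = 0; s < row.length; s++) {
                // (slot value)^(queryPowers[p]) mod t
                result[p][s] = BigInteger.valueOf(row[s]).modPow(e, t).longValueExact();
            }
        }
        return result;
    }
}
```

Calling this once per row of items[4][8192] gives the 4 arrays of shape [6][8192] that encodeQuery returns.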

3.2 Encrypt the plaintexts with the secret key: Cmg21KwPirNativeUtils.generateQuery

[[maybe_unused]] JNIEXPORT
jobject JNICALL Java_edu_alibaba_mpc4j_s2pc_pir_keyword_cmg21_Cmg21KwPirNativeUtils_generateQuery(
    JNIEnv *env, jclass, jbyteArray parms_bytes, jbyteArray pk_bytes, jbyteArray sk_bytes, jobjectArray coeffs_array) {
    // Rebuild the SEAL context and keys from the bytes passed in from Java.
    EncryptionParameters parms = deserialize_encryption_parms(env, parms_bytes);
    SEALContext context = SEALContext(parms);
    PublicKey public_key = deserialize_public_key(env, pk_bytes, context);
    SecretKey secret_key = deserialize_secret_key(env, sk_bytes, context);
    // One plaintext per row of coeffs_array (the power rows of one ciphertext).
    vector<Plaintext> plain_query = deserialize_plaintexts_from_coeff(env, coeffs_array, context);
    BatchEncoder encoder(context);
    Encryptor encryptor(context, public_key);
    encryptor.set_secret_key(secret_key);
    // Symmetric (secret-key) encryption; Serializable<Ciphertext> is serialized in seeded, compressed form.
    vector<Serializable<Ciphertext>> query;
    for (auto & plaintext : plain_query) {
        query.push_back(encryptor.encrypt_symmetric(plaintext));
    }
    return serialize_ciphertexts(env, query);
}

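The core method, encrypt_symmetric, encrypts each plaintext with the secret key. In the process, the plaintext polynomial is scaled and rounded.
Cmg21KwPirNativeUtils.generateQuery returns a List<byte[]> of size 6 (element length: 19602).

The query returned by generateQuery has size 24 (element length: 196802), i.e. 4 ciphertexts × 6 powers.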
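Client-side encryption and decryption use the BFV scheme. For background, see:
BFV加密方案 (BFV encryption scheme), CSDN blog
BFV原理(seal库)_seal库进行bvf加解密 (BFV principles with the SEAL library: BFV encryption and decryption in SEAL), CSDN blog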
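To make "scaling and rounding" concrete, below is a degree-0 (scalar) toy version of textbook BFV symmetric encryption. Encryption uses Δ = ⌊q/t⌋ and c = ([Δ·m − (a·s + e)]_q, a). Decryption computes round(t·[c0 + c1·s]_q / q) mod t. The modulus sizes, noise range, and class name are illustrative assumptions. Real BFV, as in SEAL, works with polynomials of degree N (8192 here), and these toy parameters offer no security.

```java
import java.util.Random;

/** Scalar (degree-0) toy of BFV symmetric encryption; insecure, for illustration only. */
final class ToyScalarBfv {
    final long q;      // ciphertext modulus
    final long t;      // plaintext modulus (assumes t * q < 2^63)
    final long delta;  // scaling factor floor(q / t)
    final long s;      // secret key, drawn from {-1, 0, 1}
    private final Random rnd;

    ToyScalarBfv(long q, long t, long seed) {
        this.q = q;
        this.t = t;
        this.delta = q / t;
        this.rnd = new Random(seed);
        this.s = rnd.nextInt(3) - 1;
    }

    /** Encrypt m in [0, t): c0 = delta*m - (a*s + e) mod q, c1 = a. */
    long[] encryptSymmetric(long m) {
        long a = Math.floorMod(rnd.nextLong(), q);   // uniform mask
        long e = rnd.nextInt(7) - 3;                 // small noise in [-3, 3]
        long c0 = Math.floorMod(delta * m - (a * s + e), q);
        return new long[] {c0, a};
    }

    /** Decrypt: v = c0 + c1*s = delta*m - e (mod q); round t*v/q and reduce mod t. */
    long decrypt(long[] c) {
        long v = Math.floorMod(c[0] + c[1] * s, q);
        long rounded = (t * v + q / 2) / q;          // round(t * v / q)
        return Math.floorMod(rounded, t);
    }
}
```

Scaling by Δ places the message in the high bits, where it survives the small noise e. The final rounding removes that noise.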

4 Handle the response (handleResponse)

    private Map<ByteBuffer, ByteBuffer> handleResponse(List<byte[]> keyResponse, List<byte[]> valueResponse,
                                                       Map<ByteBuffer, ByteBuffer> prfKeyMap,
                                                       CuckooHashBin<ByteBuffer> cuckooHashBin)
            throws MpcAbortException {
        MpcAbortPreconditions.checkArgument(keyResponse.size() % params.getCiphertextNum() == 0);
        MpcAbortPreconditions.checkArgument(valueResponse.size() % params.getCiphertextNum() == 0);
        Stream<byte[]> keyResponseStream = keyResponse.stream();
        keyResponseStream = parallel ? keyResponseStream.parallel() : keyResponseStream;
        List<long[]> decryptedKeyResponse = keyResponseStream
                .map(i -> Cmg21KwPirNativeUtils.decodeReply(params.getEncryptionParams(), secretKey, i))
                .collect(Collectors.toCollection(ArrayList::new));
        Stream<byte[]> valueResponseStream = valueResponse.stream();
        valueResponseStream = parallel ? valueResponseStream.parallel() : valueResponseStream;
        List<long[]> decryptedValueResponse = valueResponseStream
                .map(i -> Cmg21KwPirNativeUtils.decodeReply(params.getEncryptionParams(), secretKey, i))
                .collect(Collectors.toCollection(ArrayList::new));
        return recoverPirResult(decryptedKeyResponse, decryptedValueResponse, prfKeyMap, cuckooHashBin);
    }
The recoverPirResult method:

    private Map<ByteBuffer, ByteBuffer> recoverPirResult(List<long[]> decryptedKeyReply,
                                                         List<long[]> decryptedValueReply,
                                                         Map<ByteBuffer, ByteBuffer> prfKeyMap,
                                                         CuckooHashBin<ByteBuffer> cuckooHashBin) {
        Map<ByteBuffer, ByteBuffer> resultMap = new HashMap<>(retrievalKeySize);
        System.out.println("decryptedKeyReply size: "+decryptedKeyReply.size()+" "+decryptedKeyReply.get(0).length);
        System.out.println("decryptedValueReply size: "+decryptedValueReply.size()+" "+decryptedValueReply.get(0).length);
        int itemPartitionNum = decryptedKeyReply.size() / params.getCiphertextNum();
        int labelPartitionNum = CommonUtils.getUnitNum((valueByteLength + ivByteLength) * Byte.SIZE,
                (LongUtils.ceilLog2(params.getPlainModulus()) - 1) * params.getItemEncodedSlotSize());
        System.out.println("itemPartitionNum size: "+itemPartitionNum);
        System.out.println("labelPartitionNum size: "+labelPartitionNum);
        int shiftBits = CommonUtils.getUnitNum((valueByteLength + ivByteLength) * Byte.SIZE,
                params.getItemEncodedSlotSize() * labelPartitionNum);

        for (int i = 0; i < decryptedKeyReply.size(); i++) {
            List<Integer> matchedItem = new ArrayList<>();
            for (int j = 0; j < params.getItemEncodedSlotSize() * params.getItemPerCiphertext(); j++) {
                System.out.println("decryptedKeyReply.get(" + i + ")[" + j + "]: " + decryptedKeyReply.get(i)[j]);
                if (decryptedKeyReply.get(i)[j] == 0) {
                    matchedItem.add(j);
                }
            }
            System.out.println("Current matchedItem list: " + matchedItem);
            for (int j = 0; j < matchedItem.size() - params.getItemEncodedSlotSize() + 1; j++) {
                if (matchedItem.get(j) % params.getItemEncodedSlotSize() == 0) {
                    if (matchedItem.get(j + params.getItemEncodedSlotSize() - 1) - matchedItem.get(j) ==
                            params.getItemEncodedSlotSize() - 1) {
                        int hashBinIndex = matchedItem.get(j) / params.getItemEncodedSlotSize() + (i / itemPartitionNum)
                                * params.getItemPerCiphertext();

                        System.out.println("Valid matchedItem found from index: " + matchedItem.get(j)+"--j: "+j);
                        System.out.println("Hash bin index: " + hashBinIndex);

                        BigInteger label = BigInteger.ZERO;
                        int index = 0;
                        for (int l = 0; l < labelPartitionNum; l++) {
                            for (int k = 0; k < params.getItemEncodedSlotSize(); k++) {
                                BigInteger temp = BigInteger.valueOf(
                                        decryptedValueReply.get(i * labelPartitionNum + l)[matchedItem.get(j + k)]
                                ).shiftLeft(shiftBits * index);
                                label = label.add(temp);
                                System.out.println("Label partial value: " + temp + " after shifting " + (shiftBits * index) + " bits");
                                index++;
                            }
                        }
                        System.out.println("Reconstructed label: " + label);
                        byte[] oprf = cuckooHashBin.getHashBinEntry(hashBinIndex).getItem().array();
                        byte[] keyBytes = new byte[CommonConstants.BLOCK_BYTE_LENGTH];
                        System.arraycopy(oprf, 0, keyBytes, 0, CommonConstants.BLOCK_BYTE_LENGTH);
                        byte[] ciphertextLabel = BigIntegerUtils.nonNegBigIntegerToByteArray(
                                label, valueByteLength + ivByteLength
                        );
                        byte[] paddingCipher = BytesUtils.paddingByteArray(
                                ciphertextLabel, valueByteLength + CommonConstants.BLOCK_BYTE_LENGTH
                        );
                        byte[] plaintextLabel = streamCipher.ivDecrypt(keyBytes, paddingCipher);
                        System.out.println("Decrypted plaintext label: " + Arrays.toString(plaintextLabel));

                        resultMap.put(
                                prfKeyMap.get(cuckooHashBin.getHashBinEntry(hashBinIndex).getItem()),
                                ByteBuffer.wrap(plaintextLabel)
                        );
                        System.out.println("Added to resultMap: key=" + prfKeyMap.get(cuckooHashBin.getHashBinEntry(hashBinIndex).getItem()) +
                                ", value=" + ByteBuffer.wrap(plaintextLabel));
                        j = j + params.getItemEncodedSlotSize() - 1;
                    }
                }
            }
        }
        return resultMap;
    }

decryptedKeyReply size: 12 8192
decryptedValueReply size: 12 8192
itemPartitionNum size: 3
labelPartitionNum size: 1

Worked example for the matched bin:

matchedItem.get(j) = 2905, params.getItemEncodedSlotSize() = 5, i = 0, itemPartitionNum = 3, params.getItemPerCiphertext() = 1638

hashBinIndex = 2905 / 5 + (0 / 3) * 1638 = 581 + 0 = 581 (integer division)

Current matchedItem list: [1335, 2684, 2757, 2905, 2906, 2907, 2908, 2909, 2945, 3785, 4922, 4978]
Valid matchedItem found from index: 2905--j: 3
Hash bin index: 581
Label partial value: 0 after shifting 0 bits
Label partial value: 0 after shifting 13 bits
Label partial value: 0 after shifting 26 bits
Label partial value: 0 after shifting 39 bits
Label partial value: 0 after shifting 52 bits
Reconstructed label: 0
Decrypted plaintext label: [91, -85, 98, -36, 70, -86, 99, -92]

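This code recovers the result of Private Information Retrieval (PIR). Specifically, it reconstructs the actual data from the decrypted key replies and the decrypted value replies, and stores it in the result map (resultMap). A detailed walkthrough follows: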

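Parameters and initialization
  • decryptedKeyReply: the list of decrypted key replies.
  • decryptedValueReply: the list of decrypted value replies.
  • prfKeyMap: the pseudorandom function (PRF) key map, which maps an item stored in a hash bin back to the original key.
  • cuckooHashBin: the cuckoo hash table whose bins store the data items.

The code first creates an empty resultMap to hold the recovered data.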

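Partition counts
  • itemPartitionNum: the number of item partitions, obtained by dividing the size of decryptedKeyReply by params.getCiphertextNum().
  • labelPartitionNum: the number of label partitions, computed with CommonUtils.getUnitNum from the label bit length (valueByteLength + ivByteLength) * 8 and the bits one item's slots can carry, (ceilLog2(plainModulus) - 1) * itemEncodedSlotSize.
  • shiftBits: the number of label bits stored per slot, also computed with CommonUtils.getUnitNum.

These values determine how the actual data is extracted from the decrypted replies.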
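The partition arithmetic can be checked with a small sketch. It assumes that CommonUtils.getUnitNum is ceiling division and that LongUtils.ceilLog2 computes ⌈log2 x⌉. Both are reimplemented here rather than called from mpc4j. The sketch takes valueByteLength + ivByteLength = 8 bytes, which is consistent with the 8-byte decrypted label and the 13-bit shifts in the log above. The plaintext modulus 40961 is a hypothetical choice.

```java
/** Re-derives the partition counts used in recoverPirResult (illustrative reimplementation). */
final class ToyPartitionMath {
    /** Assumed semantics of CommonUtils.getUnitNum: ceil(num / unitSize). */
    static int getUnitNum(int num, int unitSize) {
        return (num + unitSize - 1) / unitSize;
    }

    /** ceil(log2(x)) for x >= 1. */
    static int ceilLog2(long x) {
        return 64 - Long.numberOfLeadingZeros(x - 1);
    }

    /** Number of key replies per query ciphertext. */
    static int itemPartitionNum(int keyReplySize, int ciphertextNum) {
        return keyReplySize / ciphertextNum;
    }

    /** Label bits divided by the bits that one item's slots can carry. */
    static int labelPartitionNum(int labelBytes, long plainModulus, int itemEncodedSlotSize) {
        return getUnitNum(labelBytes * Byte.SIZE, (ceilLog2(plainModulus) - 1) * itemEncodedSlotSize);
    }

    /** Label bits stored in each slot. */
    static int shiftBits(int labelBytes, int itemEncodedSlotSize, int labelPartitionNum) {
        return getUnitNum(labelBytes * Byte.SIZE, itemEncodedSlotSize * labelPartitionNum);
    }
}
```

With 12 key replies and 4 query ciphertexts this gives itemPartitionNum = 3, and a 64-bit label spread over 5 slots gives shiftBits = ⌈64 / 5⌉ = 13. Both values match the log.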

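Iterating over the decrypted key replies

The code walks through every element of decryptedKeyReply, looking for matches:

  • For each decrypted key reply, it creates a matchedItem list to hold the indices of matching slots.
  • It scans every element of decryptedKeyReply.get(i). Whenever an element is 0, its index is added to matchedItem.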
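Validating and extracting matches
  • It scans matchedItem for runs of itemEncodedSlotSize consecutive matching slots that start at a multiple of itemEncodedSlotSize.
  • For each such run, it computes the hash bin index, which identifies the hash bin the match belongs to.
    (This hash table was built on the client in step 2, and each bin holds a single element, so a matching bin identifies exactly which query item matched.)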
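The matching rule can be isolated as shown below. A bin matches when its itemEncodedSlotSize consecutive slots, starting at a multiple of itemEncodedSlotSize, all decrypt to 0. This is a simplified restatement of the loop in recoverPirResult, with the same index formula and the debug output removed. The class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Simplified restatement of the matching loop in recoverPirResult (illustration). */
final class ToyMatchFinder {
    /**
     * @param decrypted        one decrypted key reply (slot values)
     * @param replyIndex       index i of this reply in decryptedKeyReply
     * @param itemPartitionNum number of key replies per query ciphertext
     * @return hash-bin indices whose slots all decrypted to zero
     */
    static List<Integer> findMatchedBins(long[] decrypted, int replyIndex, int itemPartitionNum,
                                         int slotSize, int itemPerCiphertext) {
        // Collect the indices of all zero slots, in increasing order.
        List<Integer> zeros = new ArrayList<>();
        for (int j = 0; j < slotSize * itemPerCiphertext; j++) {
            if (decrypted[j] == 0) {
                zeros.add(j);
            }
        }
        List<Integer> bins = new ArrayList<>();
        for (int j = 0; j + slotSize - 1 < zeros.size(); j++) {
            int start = zeros.get(j);
            // Aligned to an item boundary and followed by slotSize - 1 consecutive zero slots.
            if (start % slotSize == 0 && zeros.get(j + slotSize - 1) - start == slotSize - 1) {
                bins.add(start / slotSize + (replyIndex / itemPartitionNum) * itemPerCiphertext);
                j += slotSize - 1; // skip the remaining slots of this item
            }
        }
        return bins;
    }
}
```

If you feed it the zero positions from the log above for reply i = 0, it recovers bin 581. Isolated zeros such as 1335 or 2945 are rejected because they are not followed by four more consecutive zeros.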
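Reconstructing the label
  • The label is rebuilt from the decrypted value replies. Each partial label value is shifted left by shiftBits * index bits, and the results are added together.
  • Once the full (still encrypted) label is assembled, the code fetches the OPRF value stored in the matching cuckoo hash bin and uses its first BLOCK_BYTE_LENGTH bytes as the key. It then decrypts the label with streamCipher.ivDecrypt.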
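The shift-and-add reconstruction can be sketched as a round trip. The label is split into shiftBits-bit chunks across the matched slots, then recombined as Σ value · 2^(shiftBits · index). The `split` method is an assumed inverse, written only for the check. The `reconstruct` method mirrors the inner loops of recoverPirResult, using labelPartitionNum = 1 and shiftBits = 13 as in the log above.

```java
import java.math.BigInteger;

/** Shift-and-add label reconstruction, mirroring recoverPirResult's inner loops (illustration). */
final class ToyLabelCodec {
    /** label = sum over (l, k) of valueReplies[l][firstSlot + k] << (shiftBits * index). */
    static BigInteger reconstruct(long[][] valueReplies, int firstSlot, int slotSize, int shiftBits) {
        BigInteger label = BigInteger.ZERO;
        int index = 0;
        for (long[] reply : valueReplies) {              // one reply per label partition
            for (int k = 0; k < slotSize; k++) {
                label = label.add(BigInteger.valueOf(reply[firstSlot + k]).shiftLeft(shiftBits * index));
                index++;
            }
        }
        return label;
    }

    /** Assumed inverse used only for testing: writes shiftBits-bit chunks into the matched slots. */
    static long[][] split(BigInteger label, int partitions, int slotCount,
                          int firstSlot, int slotSize, int shiftBits) {
        long[][] replies = new long[partitions][slotCount];
        BigInteger mask = BigInteger.ONE.shiftLeft(shiftBits).subtract(BigInteger.ONE);
        int index = 0;
        for (int l = 0; l < partitions; l++) {
            for (int k = 0; k < slotSize; k++) {
                replies[l][firstSlot + k] = label.shiftRight(shiftBits * index).and(mask).longValueExact();
                index++;
            }
        }
        return replies;
    }
}
```

Five 13-bit chunks cover 65 bits, which is enough for a 64-bit label. All-zero slots reconstruct to 0, as in the "Reconstructed label: 0" line of the log.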
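Storing the result
  • The decrypted label is stored in resultMap. The key is looked up in prfKeyMap, and the value is the decrypted label.

Returning the result

Finally, the code returns resultMap, which contains all the recovered data.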

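On the server side, the data is an array that is turned into a keyMap of (array index, array content) pairs. Only the keys go through the PRF; the labels do not.

On the client side, the query is an array index.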
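Here is a sketch of that server-side layout. Array positions become 4-byte keys, array contents become labels, and a PRF is applied to the keys only. HMAC-SHA256 serves purely as a stand-in PRF; the post does not show which PRF mpc4j applies. The key encoding and the class name are also assumptions.

```java
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.HashMap;
import java.util.Map;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

/** Index-as-key map with a PRF applied to keys only (illustration). */
final class ToyIndexKeyMap {
    /** keyMap: 4-byte big-endian array index -> array content. */
    static Map<ByteBuffer, ByteBuffer> buildKeyMap(byte[][] array) {
        Map<ByteBuffer, ByteBuffer> keyMap = new HashMap<>();
        for (int i = 0; i < array.length; i++) {
            keyMap.put(indexKey(i), ByteBuffer.wrap(array[i].clone()));
        }
        return keyMap;
    }

    /** Encodes an array index as a 4-byte key. */
    static ByteBuffer indexKey(int index) {
        return ByteBuffer.wrap(ByteBuffer.allocate(Integer.BYTES).putInt(index).array());
    }

    /** Applies HMAC-SHA256 (stand-in PRF) to every key; labels are carried over unchanged. */
    static Map<ByteBuffer, ByteBuffer> prfKeysOnly(Map<ByteBuffer, ByteBuffer> keyMap, byte[] prfKey) {
        try {
            Mac mac = Mac.getInstance("HmacSHA256");
            mac.init(new SecretKeySpec(prfKey, "HmacSHA256"));
            Map<ByteBuffer, ByteBuffer> out = new HashMap<>();
            for (Map.Entry<ByteBuffer, ByteBuffer> e : keyMap.entrySet()) {
                byte[] keyBytes = new byte[e.getKey().remaining()];
                e.getKey().duplicate().get(keyBytes);   // read without moving the original position
                out.put(ByteBuffer.wrap(mac.doFinal(keyBytes)), e.getValue());
            }
            return out;
        } catch (GeneralSecurityException ex) {
            throw new IllegalStateException(ex);
        }
    }
}
```

Under this layout, a client query for index 2 corresponds to looking up the PRF of indexKey(2).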
