APSI--ORPF C++实现过程分析

OPRF

Oblivious Pseudo-Random Function(不经意的伪随机函数) ,简称OPRF。OPRF可以被视为一个带有密钥的哈希函数OPRF(s, -),只有发送方知道知道密钥s,接收方可以获取OPRF(s, X)而不知道函数OPRF(s, -)或密钥s,发送方也不会知道X,X是接收方发送的数据。

具体过程:

  1. 接收方将其数据X哈希到某个椭圆曲线点A上。
  2. 接收方选择一个秘密数r,计算点B = rA,并B发送给发送方。
  3. 发送方使用其密钥s计算C = sB,将其发送回接收方。
  4. 接收方收到C后,计算椭圆曲线阶的模反元素r^(-1),进一步计算r^(-1) C = r^(-1) srA = sA。
  5. 接收方然后从该点中提取出OPRF哈希值OPRF(s, X)。

发件人知道s,因此可以简单地将其项目替换{Y_i}{OPRF(s, Y_i)}。接收方需要与发送方通信获取{OPRF(s, X_i)};一旦接收器收到这些值,协议就可以如上所述继续进行。使用 OPRF,接收器了解其查询的部分内容是否匹配的问题就消失了。由于所有项目都使用只有发送者知道的散列函数进行散列,因此接收者不会从了解发送者的散列项目的部分中获得任何好处。事实上,发送者的数据集不是私人信息,原则上可以完整发送给接收者。同态加密仅保护接收者的数据。

这里还有一个必须提及的细节。选择OPRF(s, -)256 位输出,并用 表示其前 128 位ItemHash(s, -)。而不是{OPRF(s, X_i)}我们用作{ItemHash(s, X_i)}物品;稍后在标签加密中会给出原因。

APSI开源库中测试类的位置

tests/integration/src/stream_sender_receiver.cpp

分析其中RunUnlabeledTest方法

void RunUnlabeledTest(
    size_t sender_size,
    vector<pair<size_t, size_t>> client_total_and_int_sizes,
    const PSIParams &params,
    size_t num_threads,
    bool use_different_compression = false)
{
    // 设置日志输出级别和控制台是否禁用日志输出
    Log::SetConsoleDisabled(true);
    Log::SetLogLevel(Log::Level::info);

    // 设置线程池的线程数目
    ThreadPoolMgr::SetThreadCount(num_threads);

    // 创建发送方的项目
    vector<Item> sender_items;
    for (size_t i = 0; i < sender_size; i++) {
        sender_items.push_back({ i + 1, i + 1 });
    }

    // 创建发送方数据库
    auto sender_db = make_shared<SenderDB>(params, 0);
    auto oprf_key = sender_db->get_oprf_key();
    sender_db->set_data(sender_items);
    auto seal_context = sender_db->get_seal_context();

    // 创建数据流通道和接收方
    stringstream ss;
    StreamChannel chl(ss);
    Receiver receiver(params);

    // 遍历每个客户端的总大小和感兴趣项目大小
    for (auto client_total_and_int_size : client_total_and_int_sizes) {
        auto client_size = client_total_and_int_size.first;
        auto int_size = client_total_and_int_size.second;
        ASSERT_TRUE(int_size <= client_size);

        // 从发送方项目中随机选择感兴趣的项目作为接收方项目
        vector<Item> recv_int_items = rand_subset(sender_items, int_size);
        vector<Item> recv_items;
        for (auto item : recv_int_items) {
            recv_items.push_back(item);
        }
        for (size_t i = int_size; i < client_size; i++) {
            recv_items.push_back({ i + 1, ~(i + 1) });
        }

        // 创建 OPRF 接收器和请求
        oprf::OPRFReceiver oprf_receiver = Receiver::CreateOPRFReceiver(recv_items);
        Request oprf_request = Receiver::CreateOPRFRequest(oprf_receiver);

        // 发送 OPRF 请求
        ASSERT_NO_THROW(chl.send(move(oprf_request)));
        size_t bytes_sent = chl.bytes_sent();

        // 接收 OPRF 请求并处理响应
        OPRFRequest oprf_request2 =
            to_oprf_request(chl.receive_operation(nullptr, SenderOperationType::sop_oprf));
        size_t bytes_received = chl.bytes_received();
        ASSERT_EQ(bytes_sent, bytes_received);
        ASSERT_NO_THROW(Sender::RunOPRF(oprf_request2, oprf_key, chl));

        // 接收 OPRF 响应
        OPRFResponse oprf_response = to_oprf_response(chl.receive_response());
        vector<HashedItem> hashed_recv_items;
        vector<LabelKey> label_keys;
        tie(hashed_recv_items, label_keys) =
            Receiver::ExtractHashes(oprf_response, oprf_receiver);
        ASSERT_EQ(hashed_recv_items.size(), recv_items.size());

        // 创建查询并发送
        pair<Request, IndexTranslationTable> recv_query_pair =
            receiver.create_query(hashed_recv_items);
        QueryRequest recv_query = to_query_request(move(recv_query_pair.first));
        compr_mode_type expected_compr_mode = recv_query->compr_mode;

        // 检查是否使用不同的压缩模式,并切换压缩模式
        if (use_different_compression &&
            Serialization::IsSupportedComprMode(compr_mode_type::zlib) &&
            Serialization::IsSupportedComprMode(compr_mode_type::zstd)) {
            if (recv_query->compr_mode == compr_mode_type::zstd) {
                recv_query->compr_mode = compr_mode_type::zlib;
                expected_compr_mode = compr_mode_type::zlib;
            } else {
                recv_query->compr_mode = compr_mode_type::zstd;
                expected_compr_mode = compr_mode_type::zstd;
            }
        }

        IndexTranslationTable itt = move(recv_query_pair.second);
        chl.send(move(recv_query));

        // 接收查询并处理响应
        QueryRequest sender_query = to_query_request(chl.receive_operation(seal_context));
        Query query(move(sender_query), sender_db);
        ASSERT_EQ(expected_compr_mode, query.compr_mode());
        ASSERT_NO_THROW(Sender::RunQuery(query, chl));

        // 接收查询响应
        QueryResponse query_response = to_query_response(chl.receive_response());
        uint32_t package_count = query_response->package_count;

        // 接收所有结果部分并处理结果
        vector<ResultPart> rps;
        while (package_count--) {
            ASSERT_NO_THROW(rps.push_back(chl.receive_result(receiver.get_seal_context())));
        }
        auto query_result = receiver.process_result(label_keys, itt, rps);

        verify_unlabeled_results(query_result, recv_items, recv_int_items);
    }
}
  1. 设置日志级别和线程数量。
  2. 创建sender数据 sender_db,包括生成 oprf_key 和设置数据 sender_items。
  3. 初始化接收方 receiver 和接收通道 chl。
  4. 遍历每个客户端的总数和整数项大小,进行以下操作:
    • 从 sender_items 中随机选择一部分作为接收项 recv_int_items。
    • 创建 OPRF 接收器 oprf_receiver,并生成 OPRF 请求 oprf_request。
    • 发送 OPRF 请求,进行OPRF 的查询过程(发送方使用其秘密s来计算C = sB),
      接收响应,处理响应(接收器计算r^(-1)椭圆曲线阶数的逆模,并进一步计算r^(-1) C = r^(-1) srA = sA。提取 OPRF 哈希值OPRF(s, X))得到 hashed_recv_items 和 label_keys。
    • 创建查询请求 recv_query,并根据需要修改压缩模式。
    • 发送查询请求 recv_query。
    • 接收查询请求并处理响应,得到查询结果 query_result。
    • 验证未标记结果是否符合预期。

这段代码实现了一个完整的协议流程,涉及到 OPRF 协议、查询请求和响应、压缩模式等内容。在每个客户端上都执行了类似的操作,以确保通信和处理的正确性。

其中4.1-4.3为OPRF算法过程,下面进行这部分分析。

一、Sender::RunOPRF方法中的OPRF查询方法

 response_oprf->data = OPRFSender::ProcessQueries(oprf_request->data, key);
        vector<unsigned char> OPRFSender::ProcessQueries(
            gsl::span<const unsigned char> oprf_queries, const OPRFKey &oprf_key)
        {
            if (oprf_queries.size() % oprf_query_size) {
                throw invalid_argument("oprf_queries has invalid size");
            }

            STOPWATCH(sender_stopwatch, "OPRFSender::ProcessQueries");

            size_t query_count = oprf_queries.size() / oprf_query_size;
            vector<unsigned char> oprf_responses(query_count * oprf_response_size);

            auto oprf_in_ptr = oprf_queries.data();
            auto oprf_out_ptr = oprf_responses.data();

            ThreadPoolMgr tpm;
            size_t task_count = min<size_t>(ThreadPoolMgr::GetThreadCount(), query_count);
            vector<future<void>> futures(task_count);

            auto ProcessQueriesLambda = [&](size_t start_idx, size_t step) {
                for (size_t idx = start_idx; idx < query_count; idx += step) {
                    // Load the point from input buffer
                    ECPoint ecpt;
                    ecpt.load(ECPoint::point_save_span_const_type{
                        oprf_in_ptr + idx * oprf_query_size, oprf_query_size });

                    // Multiply with key
                    if (!ecpt.scalar_multiply(oprf_key.key_span(), true)) {
                        throw logic_error("scalar multiplication failed due to invalid query data");
                    }

                    // Save the result to oprf_responses
                    ecpt.save(ECPoint::point_save_span_type{
                        oprf_out_ptr + idx * oprf_response_size, oprf_response_size });
                }
            };

            for (size_t thread_idx = 0; thread_idx < task_count; thread_idx++) {
                futures[thread_idx] =
                    tpm.thread_pool().enqueue(ProcessQueriesLambda, thread_idx, task_count);
            }

            for (auto &f : futures) {
                f.get();
            }

            return oprf_responses;
        }

这段代码实现了 OPRFSender 类的 ProcessQueries 方法,用于处理 OPRF 查询并生成对应的 OPRF 响应。以下是代码的详细解释:

  1. 检查输入的 OPRF 查询的大小是否是 oprf_query_size 的倍数,如果不是则抛出 invalid_argument 异常,表示输入的查询大小无效。

  2. 使用计时器 STOPWATCH 开始计时,用于测量处理查询的时间。

  3. 计算查询数量 query_count,即输入的 OPRF 查询大小除以 oprf_query_size

  4. 创建一个大小为 query_count * oprf_response_sizevector 对象,用于存储生成的 OPRF 响应。

  5. 设置两个指针 oprf_in_ptroprf_out_ptr,分别指向输入查询数据和输出响应数据的起始位置。

  6. 创建线程池管理器 ThreadPoolMgr 对象 tpm,用于管理并发执行的任务。

  7. 计算任务数量 task_count,即线程池中线程的数量和查询数量之间的较小值,确保每个查询都能被一个线程处理。

  8. 创建一个包含 task_countfuture 对象的 vector,用于存储每个线程的异步任务。

  9. 定义一个 lambda 表达式 ProcessQueriesLambda,用于处理查询。该 lambda 表达式接受两个参数 start_idxstep,分别表示起始索引和步长,以便线程间分配查询。
    核心就是这个ProcessQueriesLambda 表达式
    1)加载查询点的数据到ECPoint 椭圆曲线
    2)使用 oprf_key 中的密钥对加载的查询点执行标量乘法操作
    3)将标量乘法的结果保存到输出缓冲区中

  10. 在每个线程中,使用 thread_pool().enqueue 方法将 lambda 表达式 ProcessQueriesLambda 添加到线程池中执行,并将 start_idxstep 作为参数传递。

  11. 等待所有线程的任务完成,通过调用 f.get() 来获取每个异步任务的结果。

  12. 返回生成的 OPRF 响应数据。

 二、从接收到的 OPRF 响应中提取哈希项和标签密钥


        OPRFResponse oprf_response = to_oprf_response(chl.receive_response());
        vector<HashedItem> hashed_recv_items;
        vector<LabelKey> label_keys;
        tie(hashed_recv_items, label_keys) =
            Receiver::ExtractHashes(oprf_response, oprf_receiver);
        ASSERT_EQ(hashed_recv_items.size(), recv_items.size());
  1. 从通道 chl 中接收 OPRF 响应,并使用 to_oprf_response 函数将其转换为 OPRFResponse 类型的对象 oprf_response

  2. 创建一个空的哈希项向量 hashed_recv_items,用于存储从 OPRF 响应中提取的哈希项。

  3. 创建一个空的标签密钥向量 label_keys,用于存储从 OPRF 响应中提取的标签密钥。

  4. 使用 Receiver::ExtractHashes 函数从 OPRF 响应和 OPRF 接收器中提取哈希项和标签密钥,并将它们分别存储到 hashed_recv_itemslabel_keys 中。这里使用 tie 函数可以将多个返回值捆绑在一起。
    重点是Receiver::ExtractHashes中的OPRFReceiver::process_responses方法

            void OPRFReceiver::process_responses(
                gsl::span<const unsigned char> oprf_responses,
                gsl::span<HashedItem> oprf_hashes,
                gsl::span<LabelKey> label_keys) const
            {
                if (oprf_hashes.size() != item_count()) {
                    throw invalid_argument("oprf_hashes has invalid size");
                }
                if (label_keys.size() != item_count()) {
                    throw invalid_argument("label_keys has invalid size");
                }
                if (oprf_responses.size() != item_count() * oprf_response_size) {
                    throw invalid_argument("oprf_responses size is incompatible with oprf_hashes size");
                }
    
                auto oprf_in_ptr = oprf_responses.data();
                for (size_t i = 0; i < item_count(); i++) {
                    // Load the point from items_buffer
                    ECPoint ecpt;
                    ecpt.load(ECPoint::point_save_span_const_type{ oprf_in_ptr, oprf_response_size });
    
                    // Multiply with inverse random scalar
                    ecpt.scalar_multiply(inv_factor_data_.get_factor(i), false);
    
                    // Extract the item hash and the label encryption key
                    array<unsigned char, ECPoint::hash_size> item_hash_and_label_key;
                    ecpt.extract_hash(item_hash_and_label_key);
    
                    // The first 16 bytes represent the item hash; the next 32 bytes represent the label
                    // encryption key
                    copy_bytes(
                        item_hash_and_label_key.data(), oprf_hash_size, oprf_hashes[i].value().data());
                    copy_bytes(
                        item_hash_and_label_key.data() + oprf_hash_size,
                        label_key_byte_count,
                        label_keys[i].data());
    
                    // Move forward
                    advance(oprf_in_ptr, oprf_response_size);
                }
            }
        } // namespace oprf
    } // namespace apsi

    方法重点是循环部分
    1)
    加载 OPRF 响应数据到临时的 ECPoint 对象 ecpt
    2)使用随机标量的逆进行标量乘法操作。
    3)从标量乘法的结果中提取项目哈希值和标签密钥,并复制到相应的 oprf_hasheslabel_keys 中。

  5. ASSERT_EQ(hashed_recv_items.size(), recv_items.size());: 断言验证提取的哈希项的数量与接收到的项目数量 recv_items.size() 相等,确保提取的哈希项数量正确。

 三、创建查询请求对象,根据需要切换压缩模式,将查询请求发送到通道中。

                // Create query and send
                pair<Request, IndexTranslationTable> recv_query_pair =
                    receiver.create_query(hashed_recv_items);

                QueryRequest recv_query = to_query_request(move(recv_query_pair.first));
                compr_mode_type expected_compr_mode = recv_query->compr_mode;

                if (use_different_compression &&
                    Serialization::IsSupportedComprMode(compr_mode_type::zlib) &&
                    Serialization::IsSupportedComprMode(compr_mode_type::zstd)) {
                    if (recv_query->compr_mode == compr_mode_type::zstd) {
                        recv_query->compr_mode = compr_mode_type::zlib;
                        expected_compr_mode = compr_mode_type::zlib;
                    } else {
                        recv_query->compr_mode = compr_mode_type::zstd;
                        expected_compr_mode = compr_mode_type::zstd;
                    }
                }

                IndexTranslationTable itt = move(recv_query_pair.second);
                chl.send(move(recv_query));
  1. 使用哈希项创建查询并发送:首先,调用 receiver.create_query(hashed_recv_items) 创建查询和索引翻译表的对,并将结果存储在 recv_query_pair 中。

  2. 将查询请求对象转换为查询请求:使用 to_query_request(move(recv_query_pair.first)) 将查询请求对象从 recv_query_pair 中提取出来,并将其转换为 QueryRequest 对象,存储在 recv_query 中。

  3. 设置预期的压缩模式:如果需要使用不同的压缩模式且系统支持 zlib 和 zstd 压缩模式,则检查接收到的查询的压缩模式是否为 zstd。如果是,将压缩模式切换为 zlib,并更新预期的压缩模式为 zlib;如果不是,则将压缩模式切换为 zstd,并更新预期的压缩模式为 zstd。

  4. 移动索引翻译表并发送查询:将索引翻译表移动到 itt 中,并使用 chl.send(move(recv_query)) 将查询发送到通道中。

四、接收查询操作并处理响应,接收查询响应提取其中的包数量


                // Receive the query and process response
                QueryRequest sender_query =                     
                to_query_request(chl.receive_operation(seal_context));
                Query query(move(sender_query), sender_db);
                ASSERT_EQ(expected_compr_mode, query.compr_mode());
                ASSERT_NO_THROW(Sender::RunQuery(query, chl));

                // Receive query response
                QueryResponse query_response = to_query_response(chl.receive_response());
                uint32_t package_count = query_response->package_count;
  1. 接收查询并处理响应:使用 chl.receive_operation(seal_context) 从通道中接收查询操作,并将其转换为查询请求对象 QueryRequest。然后,使用 Query 类构造函数创建查询对象 query,并将其移动到 query 中。接着,使用断言 ASSERT_EQ(expected_compr_mode, query.compr_mode()) 确保查询对象的压缩模式与预期的压缩模式相等。最后,调用 Sender::RunQuery 函数执行查询,并通过通道 chl 传递查询对象。

  2. 接收查询响应:使用 chl.receive_response() 从通道中接收查询响应,并使用 to_query_response 函数将其转换为查询响应对象 QueryResponse。然后,提取查询响应中的包数量 package_count。

五、接收所有结果部分、处理查询结果

            
              // Receive all result parts and process result
                vector<ResultPart> rps;
                while (package_count--) {
                    ASSERT_NO_THROW(rps.push_back(chl.receive_result(receiver.get_seal_context())));
                }
                auto query_result = receiver.process_result(label_keys, itt, rps);

                verify_labeled_results(query_result, recv_items, recv_int_items, sender_items);
  1. 接收所有结果部分并处理结果:通过一个循环,逐个接收所有结果部分,并将它们存储在名为 rps 的结果部分向量中。每次循环迭代时,调用 chl.receive_result(receiver.get_seal_context()) 从通道 chl 中接收一个结果部分,并使用 receiver.get_seal_context() 获取接收方的密封上下文。接收结果时,使用 ASSERT_NO_THROW 确保接收操作没有异常发生。

  2. 处理查询结果:使用接收到的结果部分向量 rps、标签密钥向量 label_keys 和索引翻译表 itt,调用 receiver.process_result() 函数处理查询结果,并将处理后的结果存储在 query_result 变量中。

  3. 验证未标记的结果:使用 verify_unlabeled_results 函数验证处理后的查询结果是否正确,参数包括处理后的查询结果 query_result、接收到的项目 recv_items 以及接收到的感兴趣项目 recv_int_items

  • 12
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值