一、总体流程:
总体流程为:
1.参数设定:对基础的一些参数进行设定,例如多项式模、明文系数模等;
2.数据库预处理:将数据库进行多项式展开处理,不在以实质处理;
3.客户端发起single query请求:生成index,封装后查询;
4.服务器端返回reply query回应:返回查询结果;
5.客户端进行decode reply:客户端针对返回的值进行处理,的到最后所需要的值;
二、基础参数设定:
主要进行以下参数设定:
1.数据库设定:
Demo中给定的是 2 18 2^{18} 218条element数据,每条数据为1024个单元的数组,每个单元为uint8_t类型,也就是8位无符号数;
2.多项式模:
即论文中的N,这里采取N=4096;
3.明文系数模:
明文的每个系数的最大模数,在程序中常采用最大二进制位数对系数位数进行约束,这里logt=20;
4.hyper-cube维数设定:
也就是把一维坐标拓展到多维度的维度个数,通常选取d=2;
也就是整体数据分布为一个矩阵;
5.密钥以及其他参数设定:
这里主要设定是否采取batch和对称加密设定,具体行为取决于实际需求和Seal库,这里后续学习Seal库补一下;
bool use_symmetric = true; // use symmetric encryption instead of public key (recommended for smaller query)
bool use_batching = true; // pack as many elements as possible into a BFV plaintext (recommended)
bool use_recursive_mod_switching = true;
三、参数设定:
对基础参数进行设定;
入口函数:
gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params, use_symmetric, use_batching, use_recursive_mod_switching);
这里先注意一下数据库的表示形式;
论文中主要阐述了,如何对数据库处理变为hyper-cube以及batch编码形式;
对于hyper-cube,可以理解为直接对索引进行降级;
对于d条数据,索引下标形式为0~d-1,此时需要的的二进制位数位 2 d 2^d 2d位才可以进行检索;
如果按照二维进行切分,为一个方阵,便可以通过横纵坐标进行访问,此时所需要表达的二进制位数只有 2 ∗ 2 d 2*2^{\sqrt d} 2∗2d位,此时可以看到坐标的空间得到了极大的节省;
对于batch编码,则是在hyper-cube的基础上在进行压缩;
一般情况下,一个plaintext应该存储一个数据,但是如果将plaintext存储多个数据,就可以使得 d ∗ d \sqrt d *\sqrt d d∗d的矩阵再次压缩;
如果想要检索一个元素,只需要算出他在第几个plaintext中,按照横纵坐标检索到该plaintext,计算内部偏移,即可找到该元素;
因此基于以上理论,需要计算以下两个参数:
std::uint64_t elements_per_plaintext;//多个element打包到一个明文中;
std::uint64_t num_of_plaintexts;//完整表示数据库需要多少明文;
elements_per_plaintext = elements_per_ptxt(logt, N, ele_size);//计算一个FV plaintext能够表示多少个转为多项式系数的element;
num_of_plaintexts = plaintexts_per_db(logt, N, ele_num, ele_size);//计算整体database需要多少个FV plaintext进行表示;
其中注意一下基本参数和elements_per_ptxt等参数的关系;
N位多项式模,意味着明文大小和密文大小不能超过N,因此对于一个plaintext,能够存放多少个取决于每个element的系数个数和N的关系;
其中每个plaintext能够存放element数量由以下计算:
uint64_t elements_per_ptxt(uint32_t logt, uint64_t N, uint64_t ele_size) {
/*
备注一下:对于给定的系数表示每一个多项式的系数应该由logt位表示;
因为每个element的单元位unit_8,因此一个8位数据应该由8/logt个单元进行表示;
又因为element的拥有ele_szie个单元,所以每个element的需要的系数位数应该为ceil(8 * ele_size / (double)logt)个
*/
uint64_t coeff_per_ele = coefficients_per_element(logt, ele_size);//获取每个element存储系数所需单元;
uint64_t ele_per_ptxt = N / coeff_per_ele;//计算每个FV plaintext能够表示多少个element;
assert(ele_per_ptxt > 0);
return ele_per_ptxt;
}
cofficients_per_element为每个元素用多项式表示,需要用几个系数空间;
// Number of coefficients needed to represent a database element
uint64_t coefficients_per_element(uint32_t logt, uint64_t ele_size) {
return ceil(8 * ele_size / (double)logt);
}
这里说明一下,当时乍一看没看懂;
对每个elemnt的多项式转换的目标是element的每个存储空间;
由于element的每个存储单元为uint_8类型,也就是八位存储空间,当存在logt<8的情况,就必须要要将这8为分开存储,拆成多个系数单元;
因此对于一个element所需系数空间,应该为:(8/logt)*ele_size;
计算处每个明文能够表达多少个element之后,就可以计算出表达整个DB需要多少个plaintext;
之后根据位数d进行每个维度的大小确定:
vector<uint64_t> nvec = get_dimensions(num_of_plaintexts, d);//通过给定的维度获得hyper-cube下的具体维度;
最后,生成密文系数模数组,计算膨胀系数;
这里使用seal库默认生成,并且存储为一个coff_modulus的拆分形式,为何拆分而不是整体记录则是因为防止同态计算导致的噪音大小问题,这里后续补基础之后再说;
一个coff_modulus存有多个密文系数模,因此膨胀系数直接采用分别计算整体加和形式,总体为 l o g q / l o g t log q/log t logq/logt;
uint32_t expansion_ratio = 0;
//coff_modulus返回logq的数组,并且计算膨胀倍数,密文大小/明文大小;
for (uint32_t i = 0; i < enc_params.coeff_modulus().size(); ++i) {
double logqi = log2(enc_params.coeff_modulus()[i].value());
expansion_ratio += ceil(logqi / logt);
}
最后对pir_params参数进行设置:
pir_params.enable_symmetric = enable_symmetric;
pir_params.enable_batching = enable_batching;
pir_params.enable_mswitching = enable_mswitching;
pir_params.ele_num = ele_num;
pir_params.ele_size = ele_size;
pir_params.elements_per_plaintext = elements_per_plaintext;
pir_params.num_of_plaintexts = num_of_plaintexts;
pir_params.d = d;
pir_params.expansion_ratio = expansion_ratio << 1;
pir_params.nvec = nvec;
pir_params.slot_count = N;
四、数据库初始化:
主要将数据转换为多项式系数格式,并且装在进plaintext内;
本样例的数据库生成:
// Create test database
auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
// Copy of the database. We use this at the end to make sure we retrieved
// the correct element.
auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
random_device rd;
for (uint64_t i = 0; i < number_of_items; i++) {
for (uint64_t j = 0; j < size_per_item; j++) {
uint8_t val = rd() % 256;
db.get()[(i * size_per_item) + j] = val;
db_copy.get()[(i * size_per_item) + j] = val;
}
}
之后进行setbase入口函数进行初始化:
server.set_database(move(db), number_of_items, size_per_item);//将数据转为系数多项式形式
server.preprocess_database();//转到NNT域便于后续计算
set_database中,具体操作为:
for (uint64_t i = 0; i < num_of_plaintexts; i++) {
uint64_t process_bytes = 0;
if (db_size <= offset) {
//处理结束
break;
} else if (db_size < offset + bytes_per_ptxt) {
//最后一轮处理的element不满一个plaintext;
process_bytes = db_size - offset;
} else {
process_bytes = bytes_per_ptxt;
}
assert(process_bytes % ele_size == 0);
uint64_t ele_in_chunk = process_bytes / ele_size;//本次处理的element数目;
// Get the coefficients of the elements that will be packed in plaintext i
vector<uint64_t> coefficients(coeff_per_ptxt);
for(uint64_t ele = 0; ele < ele_in_chunk; ele++){
vector<uint64_t> element_coeffs = bytes_to_coeffs(logt, bytes.get() + offset + (ele_size*ele), ele_size);
//将element转化为系数多项式模式;
std::copy(element_coeffs.begin(), element_coeffs.end(), coefficients.begin() + (coefficients_per_element(logt, ele_size) * ele));
}
offset += process_bytes;
uint64_t used = coefficients.size();
assert(used <= coeff_per_ptxt);
// Pad the rest with 1s
for (uint64_t j = 0; j < (pir_params_.slot_count - used); j++) {
coefficients.push_back(1);
//主要针对最后element不满一个plaintext的情况,直接将最后的空位填充1;
}
Plaintext plain;
encoder_->encode(coefficients, plain);
// cout << i << "-th encoded plaintext = " << plain.to_string() << endl;
result->push_back(move(plain));
}
由于最后肯定能生成多个plaintext,所以逐个生成,把一定量的element进行多项式转化,之后塞进plaintext中;
这里的转换操作前面计算coff_per_elemnt的时候说过,当logt<8的时候,把低logt存储在第一个单元,把高logt位存储在第二个单元;
值得注意的是,最后一个plaintext往往无法填充满,所以需要将后续的位置进行填充;
// Pad the rest with 1s
for (uint64_t j = 0; j < (pir_params_.slot_count - used); j++) {
coefficients.push_back(1);
//主要针对最后element不满一个plaintext的情况,直接将最后的空位填充1;
}
之后需要使用encode函数将其装载入明文:
Plaintext plain;
encoder_->encode(coefficients, plain);
所以整体数据库构造为多项式形式完成;
preprocess_database主要目的是转化为NNT域,目前的基础无法理解这部分的含义;
五、客户端生成请求:
随机生成请求,并且计算所在的plaintext的位置和内部偏移:
// Choose an index of an element in the DB
uint64_t ele_index = rd() % number_of_items; // element in DB at random position
uint64_t index = client.get_fv_index(ele_index); // index of FV plaintext 得到查询element位于plaintext的所在索引
uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext 得到所在plaintext的所在偏移;
cout << "Main: element index = " << ele_index << " from [0, " << number_of_items -1 << "]" << endl;
cout << "Main: FV index = " << index << ", FV offset = " << offset << endl;
入口函数为:
PirQuery query = client.generate_query(index);//生成查询query
这里又存在分块的问题,之前论文阐述过;
对于plaintext,查询索引不能大于N(之所以不大于N是因需要用plaintext系数来表达选取位数,因此系数大小为选择的个数);
换句话说,对于大于N个明文,想要进行检索必须要分成多块,每块为N;
因此,对于给定的index,应该按快进行计算,所查询plaintext不在的块将查询值置为0,所在的块将起置为内部的偏置,以便后续的同态扩充计算;
这里注意一下,plaintext为大小N的多项式,每个系数相当于一个元素,如果想选取第i个元素,则为 x i x^i xi的系数为1,因此为plaintext[i]=1;
返回值为PirQuery,为一个装有Ciphertext的二维矩阵,维度取决于hypercube的维度以及分割的块数;
indices_ = compute_indices(desiredIndex, pir_params_.nvec);//所需plaintext在hyper-cube的索引;
PirQuery result(pir_params_.d);//Query查询每个维度一个查询密文;
int N = enc_params_.poly_modulus_degree();
for (uint32_t i = 0; i < indices_.size(); i++) {
//针对于每个维度开始计算;
uint32_t num_ptxts = ceil( (pir_params_.nvec[i] + 0.0) / N);//每个维度需要几个N大小的plaintext块;
// initialize result.
cout << "Client: index " << i + 1 << "/ " << indices_.size() << " = " << indices_[i] << endl;
cout << "Client: number of ctxts needed for query = " << num_ptxts << endl;
for (uint32_t j =0; j < num_ptxts; j++){
pt.set_zero();
if (indices_[i] >= N*j && indices_[i] <= N*(j+1)){
//如果所需的plaintext在该块中;
uint64_t real_index = indices_[i] - N*j; //在该块钟的真实索引;
uint64_t n_i = pir_params_.nvec[i];
uint64_t total = N;
if (j == num_ptxts - 1){
total = n_i % N;
}
uint64_t log_total = ceil(log2(total));
/*
和最初版本不同,以前是直接赋1,现在则是计算了一次xgcd
这点存疑,不知道为什么要这样,是否会影响拓展后的同态计算的正确性;
*/
cout << "Client: Inverting " << pow(2, log_total) << endl;
pt[real_index]=invert_mod(pow(2, log_total), enc_params_.plain_modulus());
}
Ciphertext dest;
if(pir_params_.enable_symmetric){
encryptor_->encrypt_symmetric(pt, dest);
}
else{
encryptor_->encrypt(pt, dest);
}
result[i].push_back(dest);
}
}
上述是针对于每个维度进行计算;
但是值得注意的是和之前的MS原始版本不同,不是采用pt[real_index]=1,而是:
pt[real_index]=invert_mod(pow(2, log_total), enc_params_.plain_modulus());
之后使用加密封装为PirQuery,返回;
这个版本还添加了序列化操作;
即利用stringstream进行封装,入口函数为:
int query_size = client.generate_serialized_query(index, client_stream);//在上述基础加了一步序列化;
六、服务器返回结果:
论文里面的大致流程为:
针对于二维d=2例子,对于两个到达的索引,可以拓充为两个
1
∗
n
1*\sqrt n
1∗n的向量,作为横纵坐标的掩码;
首先利用行掩码找出某一行,在利用列掩码找出某一列,期间找的方式按照同态加和来决定;
入口函数为:
PirReply PIRServer::generate_reply(PirQuery &query, uint32_t client_id, PIRClient& client)
这里注意一个decompstion的问题;
if (i == nvec.size() - 1) {
return intermediateCtxts;
} else {
intermediate_plain.clear();
intermediate_plain.reserve(pir_params_.expansion_ratio * product);
cur = &intermediate_plain;
for (uint64_t rr = 0; rr < product; rr++) {
EncryptionParameters parms;
if(pir_params_.enable_mswitching){
evaluator_->mod_switch_to_inplace(intermediateCtxts[rr], context_->last_parms_id());
parms = context_->last_context_data()->parms();
}
else{
parms = context_->first_context_data()->parms();
}
vector<Plaintext> plains = decompose_to_plaintexts(parms,
intermediateCtxts[rr]);
for (uint32_t jj = 0; jj < plains.size(); jj++) {
intermediate_plain.emplace_back(plains[jj]);
}
}
product = intermediate_plain.size(); // multiply by expansion rate.
}
在同态计算中,可能会存在因为噪音累乘导致噪音过大的问题,因此这里针对于密文做了一个分解操作;
将密文分解为多个coff,并且在下一个维度计算后按照上述方法进行同态掩码乘计算;
对于每个维度,对会将挑选出来的一行或者一列进行扩充,所以最后的单个ciphertext的实际长度和d的个数有关;
最后返回所需要的ciphertext结果;
七、客户端解码解密:
对于服务器端返回PirReply,进行decode和解密;
decode主要是因为Server端进行密文系数扩充,需要缩放回去;
相当于每个维度按照之前的反次数缩放回去;
但是这里值得注意的是,进行compose操作的时候,需要先decrypt,之后由plaintext拼成一个ciphertext,进行下一轮解密;
这块属实没有看懂,不知道为什么decompose和compose为什么可以这样;
for (uint32_t i = 0; i < recursion_level; i++) {
cout << "Client: " << i + 1 << "/ " << recursion_level << "-th decryption layer started." << endl;
vector<Ciphertext> newtemp;
vector<Plaintext> tempplain;
for (uint32_t j = 0; j < temp.size(); j++) {
Plaintext ptxt;
decryptor_->decrypt(temp[j], ptxt);
#ifdef DEBUG
cout << "Client: reply noise budget = " << decryptor_->invariant_noise_budget(temp[j]) << endl;
#endif
//cout << "decoded (and scaled) plaintext = " << ptxt.to_string() << endl;
tempplain.push_back(ptxt);
#ifdef DEBUG
cout << "recursion level : " << i << " noise budget : ";
cout << decryptor_->invariant_noise_budget(temp[j]) << endl;
#endif
if ((j + 1) % (exp_ratio * ciphertext_size) == 0 && j > 0) {
// Combine into one ciphertext.
Ciphertext combined(*context_, parms_id);
compose_to_ciphertext(parms, tempplain, combined);
newtemp.push_back(combined);
tempplain.clear();
// cout << "Client: const term of ciphertext = " << combined[0] << endl;
}
}
cout << "Client: done." << endl;
cout << endl;
if (i == recursion_level - 1) {
assert(temp.size() == 1);
return tempplain[0];
} else {
tempplain.clear();
temp = newtemp;
}
}
之后返回所需要的plaintext整体,由自己持有的offset选取需要哪一个元素;