2021SC@SDUSC
2021-12-26
第十三周完成事项
工作内容
本周的工作是分析剩余数系统的最后一块代码(RNSTool)。首先回顾一下之前了解学习过的剩余数系统。
RNS——剩余数系统
与冗余数相反,剩余数表示系统(RNS,residue number system)是一种用较少的数表示较多的数的表示系统。RNS可显著提高信号处理应用中某些算法密集型场景的算法速度。此外,RNS也是研究快速算法极限理论的一个工具。
剩余数系统的作用是将原本比较大的数,使用多个相对小的数进行表示,这样在运算的时候,可以不必对原本比较大的数进行运算,对多个相对小的数进行运算可以加快运算速度。
代码分析
上周分析了RNSbase类和BaseConverter类,这周我们来看一下RNS剩余数系统最后一个类——RNSTool类。
首先先分析一下内容组成(内容分析将在后续展开)。
让我们先来看一下类中的公共部分。
class RNSTool
{
public:
/**
@throws std::invalid_argument if poly_modulus_degree is out of range, coeff_modulus is not valid, or pool is
invalid.
@throws std::logic_error if coeff_modulus and extended bases do not support NTT or are not coprime.
*/
RNSTool(
std::size_t poly_modulus_degree, const RNSBase &coeff_modulus, const Modulus &plain_modulus,
MemoryPoolHandle pool);
/**
@param[in] input Must be in RNS form, i.e. coefficient must be less than the associated modulus.
*/
void divide_and_round_q_last_inplace(RNSIter input, MemoryPoolHandle pool) const;
void divide_and_round_q_last_ntt_inplace(
RNSIter input, ConstNTTTablesIter rns_ntt_tables, MemoryPoolHandle pool) const;
/**
Shenoy-Kumaresan conversion from Bsk to q
*/
void fastbconv_sk(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const;
/**
Montgomery reduction mod q; changes base from Bsk U {m_tilde} to Bsk
*/
void sm_mrq(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const;
/**
Divide by q and fast floor from q U Bsk to Bsk
*/
void fast_floor(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const;
/**
Fast base conversion from q to Bsk U {m_tilde}
*/
void fastbconv_m_tilde(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const;
/**
Compute round(t/q * |input|_q) mod t exactly
*/
void decrypt_scale_and_round(ConstRNSIter phase, CoeffIter destination, MemoryPoolHandle pool) const;
公共的第一部分包括了RNSTool类的构造函数以及主要的功能型函数。
SEAL_NODISCARD inline auto inv_q_last_mod_q() const noexcept
{
return inv_q_last_mod_q_.get();
}
SEAL_NODISCARD inline auto base_Bsk_ntt_tables() const noexcept
{
return base_Bsk_ntt_tables_.get();
}
SEAL_NODISCARD inline auto base_q() const noexcept
{
return base_q_.get();
}
SEAL_NODISCARD inline auto base_B() const noexcept
{
return base_B_.get();
}
SEAL_NODISCARD inline auto base_Bsk() const noexcept
{
return base_Bsk_.get();
}
SEAL_NODISCARD inline auto base_Bsk_m_tilde() const noexcept
{
return base_Bsk_m_tilde_.get();
}
SEAL_NODISCARD inline auto base_t_gamma() const noexcept
{
return base_t_gamma_.get();
}
SEAL_NODISCARD inline auto &m_tilde() const noexcept
{
return m_tilde_;
}
SEAL_NODISCARD inline auto &m_sk() const noexcept
{
return m_sk_;
}
SEAL_NODISCARD inline auto &t() const noexcept
{
return t_;
}
SEAL_NODISCARD inline auto &gamma() const noexcept
{
return gamma_;
}
公共的第二部分是对类中部分私有变量的获取。
private:
RNSTool(const RNSTool ©) = delete;
RNSTool(RNSTool &&source) = delete;
RNSTool &operator=(const RNSTool &assign) = delete;
RNSTool &operator=(RNSTool &&assign) = delete;
/**
Generates the pre-computations for the given parameters.
*/
void initialize(std::size_t poly_modulus_degree, const RNSBase &q, const Modulus &t);
接下来分析私有部分的函数和变量,私有的第一部分是复制构造函数和赋值构造函数以及RNSTool类的初始化函数。这里的复制构造函数和赋值构造函数也是一样的,delete删除掉了,防止对其操作过程造成影响。
MemoryPoolHandle pool_;
std::size_t coeff_count_ = 0;
Pointer<RNSBase> base_q_;
Pointer<RNSBase> base_B_;
Pointer<RNSBase> base_Bsk_;
Pointer<RNSBase> base_Bsk_m_tilde_;
Pointer<RNSBase> base_t_gamma_;
// Base converter: q --> B_sk
Pointer<BaseConverter> base_q_to_Bsk_conv_;
// Base converter: q --> {m_tilde}
Pointer<BaseConverter> base_q_to_m_tilde_conv_;
// Base converter: B --> q
Pointer<BaseConverter> base_B_to_q_conv_;
// Base converter: B --> {m_sk}
Pointer<BaseConverter> base_B_to_m_sk_conv_;
// Base converter: q --> {t, gamma}
Pointer<BaseConverter> base_q_to_t_gamma_conv_;
// prod(q)^(-1) mod Bsk
Pointer<MultiplyUIntModOperand> inv_prod_q_mod_Bsk_;
// prod(q)^(-1) mod m_tilde
MultiplyUIntModOperand neg_inv_prod_q_mod_m_tilde_;
// prod(B)^(-1) mod m_sk
MultiplyUIntModOperand inv_prod_B_mod_m_sk_;
// gamma^(-1) mod t
MultiplyUIntModOperand inv_gamma_mod_t_;
// prod(B) mod q
Pointer<std::uint64_t> prod_B_mod_q_;
// m_tilde^(-1) mod Bsk
Pointer<MultiplyUIntModOperand> inv_m_tilde_mod_Bsk_;
// prod(q) mod Bsk
Pointer<std::uint64_t> prod_q_mod_Bsk_;
// -prod(q)^(-1) mod {t, gamma}
Pointer<MultiplyUIntModOperand> neg_inv_q_mod_t_gamma_;
// prod({t, gamma}) mod q
Pointer<MultiplyUIntModOperand> prod_t_gamma_mod_q_;
// q[last]^(-1) mod q[i] for i = 0..last-1
Pointer<MultiplyUIntModOperand> inv_q_last_mod_q_;
// NTTTables for Bsk
Pointer<NTTTables> base_Bsk_ntt_tables_;
Modulus m_tilde_;
Modulus m_sk_;
Modulus t_;
Modulus gamma_;
私有的第二部分是定义了一些类中所用到的变量和指针,在其中我们可以看到之前分析的RNSbase类和BaseConverter类的相对应的指针,由此可以判断RNSTool类相当于一个对外的工具,使用RNSTool类可以使用RNSbase类和BaseConverter类的相关功能。
让我们来看一下最基本的初始化函数,因为篇幅比较长,拆成了多段分析,便于阅读。
void RNSTool::initialize(size_t poly_modulus_degree, const RNSBase &q, const Modulus &t)
{
// Return if q is out of bounds
if (q.size() < SEAL_COEFF_MOD_COUNT_MIN || q.size() > SEAL_COEFF_MOD_COUNT_MAX)
{
throw invalid_argument("rnsbase is invalid");
}
首先判断传进来的RNSbase类的参数q大小是否符合seal系数模的标准,如果不符合则抛出异常,这里的SEAL_COEFF_MOD_COUNT_MIN
和SEAL_COEFF_MOD_COUNT_MAX
均为是事先定义好的常数。
// Return if coeff_count is not a power of two or out of bounds
int coeff_count_power = get_power_of_two(poly_modulus_degree);
if (coeff_count_power < 0 || poly_modulus_degree > SEAL_POLY_MOD_DEGREE_MAX ||
poly_modulus_degree < SEAL_POLY_MOD_DEGREE_MIN)
{
throw invalid_argument("poly_modulus_degree is invalid");
}
t_ = t;
coeff_count_ = poly_modulus_degree;
/**
If the value is a power of two, return the power; otherwise, return -1.
*/
SEAL_NODISCARD inline int get_power_of_two(std::uint64_t value)
{
if (value == 0 || (value & (value - 1)) != 0)
{
return -1;
}
unsigned long result = 0;
SEAL_MSB_INDEX_UINT64(&result, value);
return static_cast<int>(result);
}
第二步判断传进来的多项式模的度数是否符合规范,这里判断了一下这个度数是否为二的幂(get_power_of_two函数展示如上,如果是二的幂,则返回幂;如果不是则返回-1)。然后根据计算出来的幂coeff_count_power和多项式模的度数进行判断,如果度数不为二的幂(返回值为-1)或者多项式模的度数超过了seal多项式模度数的限度,则抛出异常,如果满足要求则将传进来的模参数t赋给类中的私有变量t_,把多项式模的系数赋给类中的私有变量coeff_count_。这里的SEAL_POLY_MOD_DEGREE_MAX
和SEAL_POLY_MOD_DEGREE_MIN
同样也为初始定义的常数。
// Allocate memory for the bases q, B, Bsk, Bsk U m_tilde, t_gamma
size_t base_q_size = q.size();
// In some cases we might need to increase the size of the base B by one, namely we require
// K * n * t * q^2 < q * prod(B) * m_sk, where K takes into account cross terms when larger size ciphertexts
// are used, and n is the "delta factor" for the ring. We reserve 32 bits for K * n. Here the coeff modulus
// primes q_i are bounded to be SEAL_USER_MOD_BIT_COUNT_MAX (60) bits, and all primes in B and m_sk are
// SEAL_INTERNAL_MOD_BIT_COUNT (61) bits.
int total_coeff_bit_count = get_significant_bit_count_uint(q.base_prod(), q.size());
size_t base_B_size = base_q_size;
if (32 + t_.bit_count() + total_coeff_bit_count >=
SEAL_INTERNAL_MOD_BIT_COUNT * safe_cast<int>(base_q_size) + SEAL_INTERNAL_MOD_BIT_COUNT)
{
base_B_size++;
}
size_t base_Bsk_size = add_safe(base_B_size, size_t(1));
size_t base_Bsk_m_tilde_size = add_safe(base_Bsk_size, size_t(1));
size_t base_t_gamma_size = 0;
SEAL_NODISCARD inline int get_significant_bit_count_uint(const std::uint64_t *value, std::size_t uint64_count)
{
#ifdef SEAL_DEBUG
if (!value && uint64_count)
{
throw std::invalid_argument("value");
}
if (!uint64_count)
{
throw std::invalid_argument("uint64_count");
}
#endif
value += uint64_count - 1;
for (; *value == 0 && uint64_count > 1; uint64_count--)
{
value--;
}
return static_cast<int>(uint64_count - 1) * bits_per_uint64 + get_significant_bit_count(*value);
}
template <
typename T, typename S, typename = std::enable_if_t<std::is_arithmetic<T>::value>,
typename = std::enable_if_t<std::is_arithmetic<S>::value>>
SEAL_NODISCARD inline T safe_cast(S value)
{
SEAL_IF_CONSTEXPR(!std::is_same<T, S>::value)
{
if (!fits_in<T>(value))
{
throw std::logic_error("cast failed");
}
}
return static_cast<T>(value);
}
通过RNSbase变量q,定义了一些基础变量,为后续工作做准备。这里针对base_B_size进行了分析,如果32加上之前输入的模数t的bit位加上整个q主要的bit位大于base_q_size的safe_cast加一乘以迭代器模数bit位时,会将base_B_size加一。
对于get_significant_bit_count_uint
本质上就是将原本的value值,以uint64的bit位为单位进行拆分,逐步即可得到原本数据的bit位数。
// Size check
if (!product_fits_in(coeff_count_, base_Bsk_m_tilde_size))
{
throw logic_error("invalid parameters");
}
// Sample primes for B and two more primes: m_sk and gamma
auto baseconv_primes = get_primes(coeff_count_, SEAL_INTERNAL_MOD_BIT_COUNT, base_Bsk_m_tilde_size);
auto baseconv_primes_iter = baseconv_primes.cbegin();
m_sk_ = *baseconv_primes_iter++;
gamma_ = *baseconv_primes_iter++;
vector<Modulus> base_B_primes;
copy_n(baseconv_primes_iter, base_B_size, back_inserter(base_B_primes));
下一步,检查系数数和上一步生成的base_Bsk_m_tilde_size是否符合要求,不符合则抛出异常。通过上述的两个变量以及SEAL_INTERNAL_MOD_BIT_COUNT
得到一个素数,再通过素数baseconv_primes
的cbegin方法得到素数的迭代器,依次为m_sk_和gamma_进行赋值,构造一个模数的向量,调用拷贝函数进行操作。
// Set m_tilde_ to a non-prime value
m_tilde_ = uint64_t(1) << 32;
// Populate the base arrays
base_q_ = allocate<RNSBase>(pool_, q, pool_);
base_B_ = allocate<RNSBase>(pool_, base_B_primes, pool_);
base_Bsk_ = allocate<RNSBase>(pool_, base_B_->extend(m_sk_));
base_Bsk_m_tilde_ = allocate<RNSBase>(pool_, base_Bsk_->extend(m_tilde_));
// Set up t-gamma base if t_ is non-zero (using BFV)
if (!t_.is_zero())
{
base_t_gamma_size = 2;
base_t_gamma_ = allocate<RNSBase>(pool_, vector<Modulus>{ t_, gamma_ }, pool_);
}
// Generate the Bsk NTTTables; these are used for NTT after base extension to Bsk
try
{
CreateNTTTables(
coeff_count_power, vector<Modulus>(base_Bsk_->base(), base_Bsk_->base() + base_Bsk_size),
base_Bsk_ntt_tables_, pool_);
}
catch (const logic_error &)
{
throw logic_error("invalid rns bases");
}
// Set up BaseConverter for q --> Bsk
base_q_to_Bsk_conv_ = allocate<BaseConverter>(pool_, *base_q_, *base_Bsk_, pool_);
// Set up BaseConverter for q --> {m_tilde}
base_q_to_m_tilde_conv_ = allocate<BaseConverter>(pool_, *base_q_, RNSBase({ m_tilde_ }, pool_), pool_);
// Set up BaseConverter for B --> q
base_B_to_q_conv_ = allocate<BaseConverter>(pool_, *base_B_, *base_q_, pool_);
// Set up BaseConverter for B --> {m_sk}
base_B_to_m_sk_conv_ = allocate<BaseConverter>(pool_, *base_B_, RNSBase({ m_sk_ }, pool_), pool_);
if (base_t_gamma_)
{
// Set up BaseConverter for q --> {t, gamma}
base_q_to_t_gamma_conv_ = allocate<BaseConverter>(pool_, *base_q_, *base_t_gamma_, pool_);
}
接下来设置一下下面模计算的参数准备,创造一个ntt(快速数论变换)的表,加快计算。
// Compute prod(B) mod q
prod_B_mod_q_ = allocate_uint(base_q_size, pool_);
SEAL_ITERATE(iter(prod_B_mod_q_, base_q_->base()), base_q_size, [&](auto I) {
get<0>(I) = modulo_uint(base_B_->base_prod(), base_B_size, get<1>(I));
});
uint64_t temp;
// Compute prod(q)^(-1) mod Bsk
inv_prod_q_mod_Bsk_ = allocate<MultiplyUIntModOperand>(base_Bsk_size, pool_);
for (size_t i = 0; i < base_Bsk_size; i++)
{
temp = modulo_uint(base_q_->base_prod(), base_q_size, (*base_Bsk_)[i]);
if (!try_invert_uint_mod(temp, (*base_Bsk_)[i], temp))
{
throw logic_error("invalid rns bases");
}
inv_prod_q_mod_Bsk_[i].set(temp, (*base_Bsk_)[i]);
}
// Compute prod(B)^(-1) mod m_sk
temp = modulo_uint(base_B_->base_prod(), base_B_size, m_sk_);
if (!try_invert_uint_mod(temp, m_sk_, temp))
{
throw logic_error("invalid rns bases");
}
inv_prod_B_mod_m_sk_.set(temp, m_sk_);
// Compute m_tilde^(-1) mod Bsk
inv_m_tilde_mod_Bsk_ = allocate<MultiplyUIntModOperand>(base_Bsk_size, pool_);
SEAL_ITERATE(iter(inv_m_tilde_mod_Bsk_, base_Bsk_->base()), base_Bsk_size, [&](auto I) {
if (!try_invert_uint_mod(barrett_reduce_64(m_tilde_.value(), get<1>(I)), get<1>(I), temp))
{
throw logic_error("invalid rns bases");
}
get<0>(I).set(temp, get<1>(I));
});
// Compute prod(q)^(-1) mod m_tilde
temp = modulo_uint(base_q_->base_prod(), base_q_size, m_tilde_);
if (!try_invert_uint_mod(temp, m_tilde_, temp))
{
throw logic_error("invalid rns bases");
}
neg_inv_prod_q_mod_m_tilde_.set(negate_uint_mod(temp, m_tilde_), m_tilde_);
// Compute prod(q) mod Bsk
prod_q_mod_Bsk_ = allocate_uint(base_Bsk_size, pool_);
SEAL_ITERATE(iter(prod_q_mod_Bsk_, base_Bsk_->base()), base_Bsk_size, [&](auto I) {
get<0>(I) = modulo_uint(base_q_->base_prod(), base_q_size, get<1>(I));
});
if (base_t_gamma_)
{
// Compute gamma^(-1) mod t
if (!try_invert_uint_mod(barrett_reduce_64(gamma_.value(), t_), t_, temp))
{
throw logic_error("invalid rns bases");
}
inv_gamma_mod_t_.set(temp, t_);
// Compute prod({t, gamma}) mod q
prod_t_gamma_mod_q_ = allocate<MultiplyUIntModOperand>(base_q_size, pool_);
SEAL_ITERATE(iter(prod_t_gamma_mod_q_, base_q_->base()), base_q_size, [&](auto I) {
get<0>(I).set(
multiply_uint_mod((*base_t_gamma_)[0].value(), (*base_t_gamma_)[1].value(), get<1>(I)),
get<1>(I));
});
// Compute -prod(q)^(-1) mod {t, gamma}
neg_inv_q_mod_t_gamma_ = allocate<MultiplyUIntModOperand>(base_t_gamma_size, pool_);
SEAL_ITERATE(iter(neg_inv_q_mod_t_gamma_, base_t_gamma_->base()), base_t_gamma_size, [&](auto I) {
get<0>(I).operand = modulo_uint(base_q_->base_prod(), base_q_size, get<1>(I));
if (!try_invert_uint_mod(get<0>(I).operand, get<1>(I), get<0>(I).operand))
{
throw logic_error("invalid rns bases");
}
get<0>(I).set(negate_uint_mod(get<0>(I).operand, get<1>(I)), get<1>(I));
});
}
// Compute q[last]^(-1) mod q[i] for i = 0..last-1
// This is used by modulus switching and rescaling
inv_q_last_mod_q_ = allocate<MultiplyUIntModOperand>(base_q_size - 1, pool_);
SEAL_ITERATE(iter(inv_q_last_mod_q_, base_q_->base()), base_q_size - 1, [&](auto I) {
if (!try_invert_uint_mod((*base_q_)[base_q_size - 1].value(), get<1>(I), temp))
{
throw logic_error("invalid rns bases");
}
get<0>(I).set(temp, get<1>(I));
});
}
最后是一些模的计算,这里不再过多赘述,选取其中进行简单分析。
// Compute prod(B) mod q
prod_B_mod_q_ = allocate_uint(base_q_size, pool_);
SEAL_ITERATE(iter(prod_B_mod_q_, base_q_->base()), base_q_size, [&](auto I) {
get<0>(I) = modulo_uint(base_B_->base_prod(), base_B_size, get<1>(I));
});
SEAL_NODISCARD inline auto allocate_uint(std::size_t uint64_count, MemoryPool &pool)
{
return allocate<std::uint64_t>(uint64_count, pool);
}
先进行内存分配,获得内存指针,通过seal_iterate迭代操作完成模的计算,其中第一个参数是我们需要进行操作的变量,第二个参数是操作时需要使用的变量,第三个为操作的具体执行步骤的公式,在公式里面get<0>(I)
对应的是第一个参数,get<1>(I)
对应的是第二个参数。
至此我们的RNS剩余数系统基本全部分析完成。
总结
时间已经到了这学期课程的尾声,这学期我和两位队友选择了孔老师的SEAL全同态加密开源库的这个方向,从一开始的初出茅庐、一窍不通到现在的脉络清晰,一开始是很难想像自己能够弄清楚SEAL全同态加密的相关内容。从10月1号开始,我保持着每周一篇博客的原则,是可以写的更多的,但是我需要明白,作为一个队长,需要做到的不仅仅是自己的马不停蹄、自己的成绩出色,更重要的是让团队进度加快,让团队中的每个人都能真正投身于SEAL全同态加密开源库的代码分析中去。所以我放缓了自己的脚步,每周统计团队每个人的进度,每周花时间去研究如何正确的分配好每周的任务,每个人的基础和动力是不同的,如果不考虑基础去平均分配任务,可能会增加部分队员的压力,所以需要花费时间研究如何合理分配任务,我认为这也是课程中队长应尽的职责。
整体来说,收获颇丰,学到了密码学相关方向的前沿科技,让我对于加密解密有了一定的了解和兴趣;熟悉了作为团队中的领导者,如何分配任务和调度整个团队的积极性。
最后,感谢孔老师的指导,感谢戴老师和其他审核老师的阅读!