软件工程与应用第十三篇报告

2021SC@SDUSC

2021-12-26

第十三周完成事项

工作内容

本周的工作是分析剩余数系统的最后一块代码(RNSTool)。首先回顾一下之前了解学习过的剩余数系统。

RNS——剩余数系统

与冗余数相反,剩余数表示系统(RNS,residue number system)是一种用较少的数表示较多的数的表示系统。RNS可显著提高信号处理应用中某些算法密集型场景的算法速度。此外,RNS也是研究快速算法极限理论的一个工具。

剩余数系统的作用是将原本比较大的数,使用多个相对小的数进行表示,这样在运算的时候,可以不必对原本比较大的数进行运算,对多个相对小的数进行运算可以加快运算速度。

代码分析

上周分析了RNSbase类和BaseConverter类,这周我们来看一下RNS剩余数系统最后一个类——RNSTool类。

首先先分析一下内容组成(内容分析将在后续展开)。

让我们先来看一下类中的公共部分。

class RNSTool
{
public:
	/**
	@throws std::invalid_argument if poly_modulus_degree is out of range, coeff_modulus is not valid, or pool is
	invalid.
	@throws std::logic_error if coeff_modulus and extended bases do not support NTT or are not coprime.
	*/
	RNSTool(
		std::size_t poly_modulus_degree, const RNSBase &coeff_modulus, const Modulus &plain_modulus,
		MemoryPoolHandle pool);

    /**
    @param[in] input Must be in RNS form, i.e. coefficient must be less than the associated modulus.
    */
    void divide_and_round_q_last_inplace(RNSIter input, MemoryPoolHandle pool) const;

    void divide_and_round_q_last_ntt_inplace(
		RNSIter input, ConstNTTTablesIter rns_ntt_tables, MemoryPoolHandle pool) const;

    /**
    Shenoy-Kumaresan conversion from Bsk to q
    */
    void fastbconv_sk(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const;

    /**
    Montgomery reduction mod q; changes base from Bsk U {m_tilde} to Bsk
    */
	void sm_mrq(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const;

    /**
    Divide by q and fast floor from q U Bsk to Bsk
    */
    void fast_floor(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const;

    /**
    Fast base conversion from q to Bsk U {m_tilde}
    */
    void fastbconv_m_tilde(ConstRNSIter input, RNSIter destination, MemoryPoolHandle pool) const;

    /**
    Compute round(t/q * |input|_q) mod t exactly
    */
    void decrypt_scale_and_round(ConstRNSIter phase, CoeffIter destination, MemoryPoolHandle pool) const;

公共的第一部分包括了RNSTool类的构造函数以及主要的功能型函数。

SEAL_NODISCARD inline auto inv_q_last_mod_q() const noexcept
{
	return inv_q_last_mod_q_.get();
}

SEAL_NODISCARD inline auto base_Bsk_ntt_tables() const noexcept
{
    return base_Bsk_ntt_tables_.get();
}

SEAL_NODISCARD inline auto base_q() const noexcept
{
    return base_q_.get();
}

SEAL_NODISCARD inline auto base_B() const noexcept
{
    return base_B_.get();
}

SEAL_NODISCARD inline auto base_Bsk() const noexcept
{
    return base_Bsk_.get();
}

SEAL_NODISCARD inline auto base_Bsk_m_tilde() const noexcept
{
    return base_Bsk_m_tilde_.get();
}

SEAL_NODISCARD inline auto base_t_gamma() const noexcept
{
	return base_t_gamma_.get();
}

SEAL_NODISCARD inline auto &m_tilde() const noexcept
{
    return m_tilde_;
}

SEAL_NODISCARD inline auto &m_sk() const noexcept
{
    return m_sk_;
}

SEAL_NODISCARD inline auto &t() const noexcept
{
    return t_;
}

SEAL_NODISCARD inline auto &gamma() const noexcept
{
    return gamma_;
}

公共的第二部分是对类中部分私有变量的获取。

private:
    RNSTool(const RNSTool &copy) = delete;

    RNSTool(RNSTool &&source) = delete;

    RNSTool &operator=(const RNSTool &assign) = delete;

    RNSTool &operator=(RNSTool &&assign) = delete;

    /**
    Generates the pre-computations for the given parameters.
    */
    void initialize(std::size_t poly_modulus_degree, const RNSBase &q, const Modulus &t);

接下来分析私有部分的函数和变量,私有的第一部分是复制构造函数和赋值构造函数以及RNSTool类的初始化函数。这里的复制构造函数和赋值构造函数也是一样的,delete删除掉了,防止对其操作过程造成影响。

MemoryPoolHandle pool_;

std::size_t coeff_count_ = 0;

Pointer<RNSBase> base_q_;

Pointer<RNSBase> base_B_;

Pointer<RNSBase> base_Bsk_;

Pointer<RNSBase> base_Bsk_m_tilde_;

Pointer<RNSBase> base_t_gamma_;

// Base converter: q --> B_sk
Pointer<BaseConverter> base_q_to_Bsk_conv_;

// Base converter: q --> {m_tilde}
Pointer<BaseConverter> base_q_to_m_tilde_conv_;

// Base converter: B --> q
Pointer<BaseConverter> base_B_to_q_conv_;

// Base converter: B --> {m_sk}
Pointer<BaseConverter> base_B_to_m_sk_conv_;

// Base converter: q --> {t, gamma}
Pointer<BaseConverter> base_q_to_t_gamma_conv_;

// prod(q)^(-1) mod Bsk
Pointer<MultiplyUIntModOperand> inv_prod_q_mod_Bsk_;

// prod(q)^(-1) mod m_tilde
MultiplyUIntModOperand neg_inv_prod_q_mod_m_tilde_;

// prod(B)^(-1) mod m_sk
MultiplyUIntModOperand inv_prod_B_mod_m_sk_;

// gamma^(-1) mod t
MultiplyUIntModOperand inv_gamma_mod_t_;

// prod(B) mod q
Pointer<std::uint64_t> prod_B_mod_q_;

// m_tilde^(-1) mod Bsk
Pointer<MultiplyUIntModOperand> inv_m_tilde_mod_Bsk_;

// prod(q) mod Bsk
Pointer<std::uint64_t> prod_q_mod_Bsk_;

// -prod(q)^(-1) mod {t, gamma}
Pointer<MultiplyUIntModOperand> neg_inv_q_mod_t_gamma_;

// prod({t, gamma}) mod q
Pointer<MultiplyUIntModOperand> prod_t_gamma_mod_q_;

// q[last]^(-1) mod q[i] for i = 0..last-1
Pointer<MultiplyUIntModOperand> inv_q_last_mod_q_;

// NTTTables for Bsk
Pointer<NTTTables> base_Bsk_ntt_tables_;

Modulus m_tilde_;

Modulus m_sk_;

Modulus t_;

Modulus gamma_;

私有的第二部分是定义了一些类中所用到的变量和指针,在其中我们可以看到之前分析的RNSbase类和BaseConverter类的相对应的指针,由此可以判断RNSTool类相当于一个对外的工具,使用RNSTool类可以使用RNSbase类和BaseConverter类的相关功能。

让我们来看一下最基本的初始化函数,因为篇幅比较长,拆成了多段分析,便于阅读。

void RNSTool::initialize(size_t poly_modulus_degree, const RNSBase &q, const Modulus &t)
{
	// Return if q is out of bounds
    if (q.size() < SEAL_COEFF_MOD_COUNT_MIN || q.size() > SEAL_COEFF_MOD_COUNT_MAX)
    {
		throw invalid_argument("rnsbase is invalid");
    }

首先判断传进来的RNSbase类的参数q大小是否符合seal系数模的标准,如果不符合则抛出异常,这里的SEAL_COEFF_MOD_COUNT_MINSEAL_COEFF_MOD_COUNT_MAX均为是事先定义好的常数。

// Return if coeff_count is not a power of two or out of bounds
int coeff_count_power = get_power_of_two(poly_modulus_degree);
if (coeff_count_power < 0 || poly_modulus_degree > SEAL_POLY_MOD_DEGREE_MAX ||
	poly_modulus_degree < SEAL_POLY_MOD_DEGREE_MIN)
{
    throw invalid_argument("poly_modulus_degree is invalid");
}

t_ = t;
coeff_count_ = poly_modulus_degree;
/**
If the value is a power of two, return the power; otherwise, return -1.
*/
SEAL_NODISCARD inline int get_power_of_two(std::uint64_t value)
{
	if (value == 0 || (value & (value - 1)) != 0)
    {
    	return -1;
    }

    unsigned long result = 0;
    SEAL_MSB_INDEX_UINT64(&result, value);
    return static_cast<int>(result);
}

第二步判断传进来的多项式模的度数是否符合规范,这里判断了一下这个度数是否为二的幂(get_power_of_two函数展示如上,如果是二的幂,则返回幂;如果不是则返回-1)。然后根据计算出来的幂coeff_count_power和多项式模的度数进行判断,如果度数不为二的幂(返回值为-1)或者多项式模的度数超过了seal多项式模度数的限度,则抛出异常,如果满足要求则将传进来的模参数t赋给类中的私有变量t_,把多项式模的系数赋给类中的私有变量coeff_count_。这里的SEAL_POLY_MOD_DEGREE_MAXSEAL_POLY_MOD_DEGREE_MIN同样也为初始定义的常数。

// Allocate memory for the bases q, B, Bsk, Bsk U m_tilde, t_gamma
size_t base_q_size = q.size();

// In some cases we might need to increase the size of the base B by one, namely we require
// K * n * t * q^2 < q * prod(B) * m_sk, where K takes into account cross terms when larger size ciphertexts
// are used, and n is the "delta factor" for the ring. We reserve 32 bits for K * n. Here the coeff modulus
// primes q_i are bounded to be SEAL_USER_MOD_BIT_COUNT_MAX (60) bits, and all primes in B and m_sk are
// SEAL_INTERNAL_MOD_BIT_COUNT (61) bits.
int total_coeff_bit_count = get_significant_bit_count_uint(q.base_prod(), q.size());

size_t base_B_size = base_q_size;
if (32 + t_.bit_count() + total_coeff_bit_count >=
	SEAL_INTERNAL_MOD_BIT_COUNT * safe_cast<int>(base_q_size) + SEAL_INTERNAL_MOD_BIT_COUNT)
{
    base_B_size++;
}

size_t base_Bsk_size = add_safe(base_B_size, size_t(1));
size_t base_Bsk_m_tilde_size = add_safe(base_Bsk_size, size_t(1));

size_t base_t_gamma_size = 0;
SEAL_NODISCARD inline int get_significant_bit_count_uint(const std::uint64_t *value, std::size_t uint64_count)
{
#ifdef SEAL_DEBUG
	if (!value && uint64_count)
    {
    	throw std::invalid_argument("value");
    }
    if (!uint64_count)
    {
        throw std::invalid_argument("uint64_count");
    }
#endif
    value += uint64_count - 1;
    for (; *value == 0 && uint64_count > 1; uint64_count--)
    {
        value--;
    }

    return static_cast<int>(uint64_count - 1) * bits_per_uint64 + get_significant_bit_count(*value);
}
template <
	typename T, typename S, typename = std::enable_if_t<std::is_arithmetic<T>::value>,
    typename = std::enable_if_t<std::is_arithmetic<S>::value>>
SEAL_NODISCARD inline T safe_cast(S value)
{
	SEAL_IF_CONSTEXPR(!std::is_same<T, S>::value)
    {
    	if (!fits_in<T>(value))
        {
        	throw std::logic_error("cast failed");
        }
    }
    return static_cast<T>(value);
}

通过RNSbase变量q,定义了一些基础变量,为后续工作做准备。这里针对base_B_size进行了分析,如果32加上之前输入的模数t的bit位加上整个q主要的bit位大于base_q_size的safe_cast加一乘以迭代器模数bit位时,会将base_B_size加一。

对于get_significant_bit_count_uint本质上就是将原本的value值,以uint64的bit位为单位进行拆分,逐步即可得到原本数据的bit位数。

// Size check
if (!product_fits_in(coeff_count_, base_Bsk_m_tilde_size))
{
	throw logic_error("invalid parameters");
}

// Sample primes for B and two more primes: m_sk and gamma
auto baseconv_primes = get_primes(coeff_count_, SEAL_INTERNAL_MOD_BIT_COUNT, base_Bsk_m_tilde_size);
auto baseconv_primes_iter = baseconv_primes.cbegin();
m_sk_ = *baseconv_primes_iter++;
gamma_ = *baseconv_primes_iter++;
vector<Modulus> base_B_primes;
copy_n(baseconv_primes_iter, base_B_size, back_inserter(base_B_primes));

下一步,检查系数数和上一步生成的base_Bsk_m_tilde_size是否符合要求,不符合则抛出异常。通过上述的两个变量以及SEAL_INTERNAL_MOD_BIT_COUNT得到一个素数,再通过素数baseconv_primes的cbegin方法得到素数的迭代器,依次为m_sk_和gamma_进行赋值,构造一个模数的向量,调用拷贝函数进行操作。

// Set m_tilde_ to a non-prime value
m_tilde_ = uint64_t(1) << 32;

// Populate the base arrays
base_q_ = allocate<RNSBase>(pool_, q, pool_);
base_B_ = allocate<RNSBase>(pool_, base_B_primes, pool_);
base_Bsk_ = allocate<RNSBase>(pool_, base_B_->extend(m_sk_));
base_Bsk_m_tilde_ = allocate<RNSBase>(pool_, base_Bsk_->extend(m_tilde_));

// Set up t-gamma base if t_ is non-zero (using BFV)
if (!t_.is_zero())
{
	base_t_gamma_size = 2;
    base_t_gamma_ = allocate<RNSBase>(pool_, vector<Modulus>{ t_, gamma_ }, pool_);
}

// Generate the Bsk NTTTables; these are used for NTT after base extension to Bsk
try
{
	CreateNTTTables(
		coeff_count_power, vector<Modulus>(base_Bsk_->base(), base_Bsk_->base() + base_Bsk_size),
        base_Bsk_ntt_tables_, pool_);
}
catch (const logic_error &)
{
    throw logic_error("invalid rns bases");
}

// Set up BaseConverter for q --> Bsk
base_q_to_Bsk_conv_ = allocate<BaseConverter>(pool_, *base_q_, *base_Bsk_, pool_);

// Set up BaseConverter for q --> {m_tilde}
base_q_to_m_tilde_conv_ = allocate<BaseConverter>(pool_, *base_q_, RNSBase({ m_tilde_ }, pool_), pool_);

// Set up BaseConverter for B --> q
base_B_to_q_conv_ = allocate<BaseConverter>(pool_, *base_B_, *base_q_, pool_);

// Set up BaseConverter for B --> {m_sk}
base_B_to_m_sk_conv_ = allocate<BaseConverter>(pool_, *base_B_, RNSBase({ m_sk_ }, pool_), pool_);

if (base_t_gamma_)
{
	// Set up BaseConverter for q --> {t, gamma}
    base_q_to_t_gamma_conv_ = allocate<BaseConverter>(pool_, *base_q_, *base_t_gamma_, pool_);
}

接下来设置一下下面模计算的参数准备,创造一个ntt(快速数论变换)的表,加快计算。

    // Compute prod(B) mod q
    prod_B_mod_q_ = allocate_uint(base_q_size, pool_);
    SEAL_ITERATE(iter(prod_B_mod_q_, base_q_->base()), base_q_size, [&](auto I) {
    	get<0>(I) = modulo_uint(base_B_->base_prod(), base_B_size, get<1>(I));
    });

    uint64_t temp;

    // Compute prod(q)^(-1) mod Bsk
    inv_prod_q_mod_Bsk_ = allocate<MultiplyUIntModOperand>(base_Bsk_size, pool_);
    for (size_t i = 0; i < base_Bsk_size; i++)
    {
		temp = modulo_uint(base_q_->base_prod(), base_q_size, (*base_Bsk_)[i]);
        if (!try_invert_uint_mod(temp, (*base_Bsk_)[i], temp))
        {
        	throw logic_error("invalid rns bases");
        }
        inv_prod_q_mod_Bsk_[i].set(temp, (*base_Bsk_)[i]);
    }

    // Compute prod(B)^(-1) mod m_sk
    temp = modulo_uint(base_B_->base_prod(), base_B_size, m_sk_);
    if (!try_invert_uint_mod(temp, m_sk_, temp))
    {
    	throw logic_error("invalid rns bases");
    }
    inv_prod_B_mod_m_sk_.set(temp, m_sk_);

    // Compute m_tilde^(-1) mod Bsk
    inv_m_tilde_mod_Bsk_ = allocate<MultiplyUIntModOperand>(base_Bsk_size, pool_);
    SEAL_ITERATE(iter(inv_m_tilde_mod_Bsk_, base_Bsk_->base()), base_Bsk_size, [&](auto I) {
    	if (!try_invert_uint_mod(barrett_reduce_64(m_tilde_.value(), get<1>(I)), get<1>(I), temp))
        {
        	throw logic_error("invalid rns bases");
        }
        get<0>(I).set(temp, get<1>(I));
    });

	// Compute prod(q)^(-1) mod m_tilde
    temp = modulo_uint(base_q_->base_prod(), base_q_size, m_tilde_);
    if (!try_invert_uint_mod(temp, m_tilde_, temp))
    {
		throw logic_error("invalid rns bases");
    }
    neg_inv_prod_q_mod_m_tilde_.set(negate_uint_mod(temp, m_tilde_), m_tilde_);

    // Compute prod(q) mod Bsk
    prod_q_mod_Bsk_ = allocate_uint(base_Bsk_size, pool_);
    SEAL_ITERATE(iter(prod_q_mod_Bsk_, base_Bsk_->base()), base_Bsk_size, [&](auto I) {
    	get<0>(I) = modulo_uint(base_q_->base_prod(), base_q_size, get<1>(I));
    });

    if (base_t_gamma_)
    {
        // Compute gamma^(-1) mod t
        if (!try_invert_uint_mod(barrett_reduce_64(gamma_.value(), t_), t_, temp))
        {
			throw logic_error("invalid rns bases");
        }
        inv_gamma_mod_t_.set(temp, t_);

        // Compute prod({t, gamma}) mod q
        prod_t_gamma_mod_q_ = allocate<MultiplyUIntModOperand>(base_q_size, pool_);
        SEAL_ITERATE(iter(prod_t_gamma_mod_q_, base_q_->base()), base_q_size, [&](auto I) {
        	get<0>(I).set(
            	multiply_uint_mod((*base_t_gamma_)[0].value(), (*base_t_gamma_)[1].value(), get<1>(I)),
                	get<1>(I));
        });

        // Compute -prod(q)^(-1) mod {t, gamma}
        neg_inv_q_mod_t_gamma_ = allocate<MultiplyUIntModOperand>(base_t_gamma_size, pool_);
        SEAL_ITERATE(iter(neg_inv_q_mod_t_gamma_, base_t_gamma_->base()), base_t_gamma_size, [&](auto I) {
        	get<0>(I).operand = modulo_uint(base_q_->base_prod(), base_q_size, get<1>(I));
            if (!try_invert_uint_mod(get<0>(I).operand, get<1>(I), get<0>(I).operand))
            {
            	throw logic_error("invalid rns bases");
            }
            get<0>(I).set(negate_uint_mod(get<0>(I).operand, get<1>(I)), get<1>(I));
        });
	}

    // Compute q[last]^(-1) mod q[i] for i = 0..last-1
    // This is used by modulus switching and rescaling
    inv_q_last_mod_q_ = allocate<MultiplyUIntModOperand>(base_q_size - 1, pool_);
    SEAL_ITERATE(iter(inv_q_last_mod_q_, base_q_->base()), base_q_size - 1, [&](auto I) {
    	if (!try_invert_uint_mod((*base_q_)[base_q_size - 1].value(), get<1>(I), temp))
        {
        	throw logic_error("invalid rns bases");
        }
        get<0>(I).set(temp, get<1>(I));
    });
}

最后是一些模的计算,这里不再过多赘述,选取其中进行简单分析。

// Compute prod(B) mod q
prod_B_mod_q_ = allocate_uint(base_q_size, pool_);
SEAL_ITERATE(iter(prod_B_mod_q_, base_q_->base()), base_q_size, [&](auto I) {
	get<0>(I) = modulo_uint(base_B_->base_prod(), base_B_size, get<1>(I));
});
SEAL_NODISCARD inline auto allocate_uint(std::size_t uint64_count, MemoryPool &pool)
{
    return allocate<std::uint64_t>(uint64_count, pool);
}

先进行内存分配,获得内存指针,通过seal_iterate迭代操作完成模的计算,其中第一个参数是我们需要进行操作的变量,第二个参数是操作时需要使用的变量,第三个为操作的具体执行步骤的公式,在公式里面get<0>(I)对应的是第一个参数,get<1>(I)对应的是第二个参数。

至此我们的RNS剩余数系统基本全部分析完成。

总结

时间已经到了这学期课程的尾声,这学期我和两位队友选择了孔老师的SEAL全同态加密开源库的这个方向,从一开始的初出茅庐、一窍不通到现在的脉络清晰,一开始是很难想像自己能够弄清楚SEAL全同态加密的相关内容。从10月1号开始,我保持着每周一篇博客的原则,是可以写的更多的,但是我需要明白,作为一个队长,需要做到的不仅仅是自己的马不停蹄、自己的成绩出色,更重要的是让团队进度加快,让团队中的每个人都能真正投身于SEAL全同态加密开源库的代码分析中去。所以我放缓了自己的脚步,每周统计团队每个人的进度,每周花时间去研究如何正确的分配好每周的任务,每个人的基础和动力是不同的,如果不考虑基础去平均分配任务,可能会增加部分队员的压力,所以需要花费时间研究如何合理分配任务,我认为这也是课程中队长应尽的职责。

整体来说,收获颇丰,学到了密码学相关方向的前沿科技,让我对于加密解密有了一定的了解和兴趣;熟悉了作为团队中的领导者,如何分配任务和调度整个团队的积极性。

最后,感谢孔老师的指导,感谢戴老师和其他审核老师的阅读!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值