HEAAN源码(二) (HEAAN source code, part 2)

Scheme.h

#ifndef HEAAN_SCHEME_H_
#define HEAAN_SCHEME_H_

#include <NTL/RR.h>
#include <NTL/ZZ.h>
#include <complex>
#include <cstdint>
#include <map>
#include <string>

#include "BootContext.h"
#include "SecretKey.h"
#include "Ciphertext.h"
#include "Plaintext.h"
#include "Key.h"
#include "EvaluatorUtils.h"
#include "Ring.h"

namespace heaan {
   

// Tags identifying the non-rotation keys in Scheme::keyMap / Scheme::serKeyMap.
// constexpr (instead of mutable `static long`) so that the header does not
// give every translation unit its own writable copy of these constants.
constexpr long ENCRYPTION = 0;
constexpr long MULTIPLICATION  = 1;
constexpr long CONJUGATION = 2;

/**
 * Homomorphic encryption scheme front end.
 *
 * Holds the keys derived from a SecretKey and exposes encoding, encryption,
 * homomorphic arithmetic, rescaling, rotation/conjugation and bootstrapping
 * entry points.
 *
 * Key storage:
 *  - isSerialized == false: generated keys are heap-allocated Key objects
 *    owned by keyMap / leftRotKeyMap and deleted in ~Scheme().
 *  - isSerialized == true: keys are written to files under "serkey/" and only
 *    their paths are kept in serKeyMap / serLeftRotKeyMap.
 *
 * Copying is disabled: the maps own raw Key pointers, so an implicitly
 * generated copy would make two Schemes delete the same keys (double free).
 */
class Scheme {
public:
	/// Ring arithmetic backend. Held by reference, so it must outlive this Scheme.
	Ring& ring;

	/// When true, the add*Key() methods write keys to files under "serkey/" instead of keeping them in memory.
	bool isSerialized;

	/// In-memory keys tagged ENCRYPTION, MULTIPLICATION or CONJUGATION; owned, deleted in ~Scheme().
	std::map<long, Key*> keyMap;
	/// In-memory left-rotation keys indexed by rotation amount r; owned, deleted in ~Scheme().
	std::map<long, Key*> leftRotKeyMap;

	/// File paths of serialized ENCRYPTION / MULTIPLICATION / CONJUGATION keys (isSerialized mode).
	std::map<long, std::string> serKeyMap;
	/// File paths of serialized left-rotation keys, indexed by rotation amount r (isSerialized mode).
	std::map<long, std::string> serLeftRotKeyMap;

	/**
	 * Creates the scheme and immediately generates the encryption and
	 * multiplication keys from secretKey. Other keys (conjugation, rotation,
	 * bootstrapping) must be added explicitly.
	 *
	 * @param secretKey    secret key the keys are derived from
	 * @param ring         ring backend; stored by reference
	 * @param isSerialized store keys as files under "serkey/" instead of in memory
	 */
	Scheme(SecretKey& secretKey, Ring& ring, bool isSerialized = false);

	/// Deletes the in-memory keys held in keyMap and leftRotKeyMap.
	virtual ~Scheme();

	// Non-copyable: keyMap/leftRotKeyMap own raw Key pointers, so a copy
	// would lead to the same keys being deleted twice.
	Scheme(const Scheme&) = delete;
	Scheme& operator=(const Scheme&) = delete;

	//----------------------------------------------------------------------------------
	//   KEYS GENERATION
	//----------------------------------------------------------------------------------

	/// Generates the public encryption key (tag ENCRYPTION).
	void addEncKey(SecretKey& secretKey);

	/// Generates the multiplication key (tag MULTIPLICATION), built from sx * sx.
	void addMultKey(SecretKey& secretKey);

	/// Generates the conjugation key (tag CONJUGATION), built from the conjugated secret.
	void addConjKey(SecretKey& secretKey);

	/// Generates the key for a left rotation by r, stored under index r.
	void addLeftRotKey(SecretKey& secretKey, long r);

	/// Right rotation by r; presumably realised as the left-rotation key for Nh - r (TODO confirm in Scheme.cpp).
	void addRightRotKey(SecretKey& secretKey, long r);

	void addLeftRotKeys(SecretKey& secretKey);

	void addRightRotKeys(SecretKey& secretKey);

	void addBootKey(SecretKey& secretKey, long logl, long logp);

	//----------------------------------------------------------------------------------
	//   ENCODING & DECODING
	//----------------------------------------------------------------------------------

	void encode(Plaintext& plain, std::complex<double>* vals, long n, long logp, long logq);

	void encode(Plaintext& plain, double* vals, long n, long logp, long logq);

	/// NOTE(review): returns a raw heap array; the caller presumably owns it (delete[]) — confirm.
	std::complex<double>* decode(Plaintext& plain);

	void encodeSingle(Plaintext& plain, std::complex<double> val, long logp, long logq);

	void encodeSingle(Plaintext& plain, double val, long logp, long logq);

	std::complex<double> decodeSingle(Plaintext& plain);

	//----------------------------------------------------------------------------------
	//   ENCRYPTION & DECRYPTION
	//----------------------------------------------------------------------------------

	void encryptMsg(Ciphertext& cipher, Plaintext& plain);

	void decryptMsg(Plaintext& plain, SecretKey& secretKey, Ciphertext& cipher);

	void encrypt(Ciphertext& cipher, std::complex<double>* vals, long n, long logp, long logq);

	void encrypt(Ciphertext& cipher, double* vals, long n, long logp, long logq);

	void encryptBySk(Ciphertext& cipher, SecretKey& secretKey, std::complex<double>* vals, long n, long logp, long logq, double=3.2);

	void encryptBySk(Ciphertext& cipher, SecretKey& secretKey, double* vals, long n, long logp, long logq, double=3.2);

	void encryptZeros(Ciphertext& cipher, long n, long logp, long logq);

	/// NOTE(review): returns a raw heap array; the caller presumably owns it (delete[]) — confirm.
	std::complex<double>* decrypt(SecretKey& secretKey, Ciphertext& cipher);

	std::complex<double>* decryptForShare(SecretKey& secretKey, Ciphertext& cipher, long=0);

	void encryptSingle(Ciphertext& cipher, std::complex<double> val, long logp, long logq);

	void encryptSingle(Ciphertext& cipher, double val, long logp, long logq);

	std::complex<double> decryptSingle(SecretKey& secretKey, Ciphertext& cipher);

	//----------------------------------------------------------------------------------
	//   HOMOMORPHIC OPERATIONS
	//----------------------------------------------------------------------------------

	void negate(Ciphertext& res, Ciphertext& cipher);

	void negateAndEqual(Ciphertext& cipher);

	void add(Ciphertext& res, Ciphertext& cipher1, Ciphertext& cipher2);

	void addAndEqual(Ciphertext& cipher1, Ciphertext& cipher2);

	void addConst(Ciphertext& res, Ciphertext& cipher, double cnst, long logp);

	void addConst(Ciphertext& res, Ciphertext& cipher, NTL::RR& cnst, long logp);

	void addConst(Ciphertext& res, Ciphertext& cipher, std::complex<double> cnst, long logp);

	void addConstAndEqual(Ciphertext& cipher, double cnst, long logp);

	void addConstAndEqual(Ciphertext& cipher, NTL::RR& cnst, long logp);

	void addConstAndEqual(Ciphertext& cipher, std::complex<double> cnst, long logp);

	void sub(Ciphertext& res, Ciphertext& cipher1, Ciphertext& cipher2);

	void subAndEqual(Ciphertext& cipher1, Ciphertext& cipher2);

	void subAndEqual2(Ciphertext& cipher1, Ciphertext& cipher2);

	void imult(Ciphertext& res, Ciphertext& cipher);

	void idiv(Ciphertext& res, Ciphertext& cipher);

	void imultAndEqual(Ciphertext& cipher);

	void idivAndEqual(Ciphertext& cipher);

	void mult(Ciphertext& res, Ciphertext& cipher1, Ciphertext& cipher2);

	void multAndEqual(Ciphertext& cipher1, Ciphertext& cipher2);

	void square(Ciphertext& res, Ciphertext& cipher);

	void squareAndEqual(Ciphertext& cipher);

	void multByConst(Ciphertext& res, Ciphertext& cipher, double cnst, long logp);

	void multByConst(Ciphertext& res, Ciphertext& cipher, std::complex<double> cnst, long logp);

	void multByConstVec(Ciphertext& res, Ciphertext& cipher, std::complex<double>* cnstVec, long logp);

	void multByConstVecAndEqual(Ciphertext& cipher, std::complex<double>* cnstVec, long logp);

	void multByConstAndEqual(Ciphertext& cipher, double cnst, long logp);

	void multByConstAndEqual(Ciphertext& cipher, NTL::RR& cnst, long logp);

	void multByConstAndEqual(Ciphertext& cipher, std::complex<double> cnst, long logp);

	void multByPoly(Ciphertext& res, Ciphertext& cipher, NTL::ZZ* poly, long logp);

	void multByPolyNTT(Ciphertext& res, Ciphertext& cipher, uint64_t* rpoly, long bnd, long logp);

	void multByPolyAndEqual(Ciphertext& cipher, NTL::ZZ* poly, long logp);

	void multByPolyNTTAndEqual(Ciphertext& cipher, uint64_t* rpoly, long bnd, long logp);

	void multByMonomial(Ciphertext& res, Ciphertext& cipher, const long degree);

	void multByMonomialAndEqual(Ciphertext& cipher, const long degree);

	void leftShift(Ciphertext& res, Ciphertext& cipher, long bits);

	void leftShiftAndEqual(Ciphertext& cipher, long bits);

	void doubleAndEqual(Ciphertext& cipher);

	void divByPo2(Ciphertext& res, Ciphertext& cipher, long bits);

	void divByPo2AndEqual(Ciphertext& cipher, long bits);

	//----------------------------------------------------------------------------------
	//   RESCALING
	//----------------------------------------------------------------------------------

	void reScaleBy(Ciphertext& res, Ciphertext& cipher, long dlogq);

	void reScaleTo(Ciphertext& res, Ciphertext& cipher, long logq);

	void reScaleByAndEqual(Ciphertext& cipher, long dlogq);

	void reScaleToAndEqual(Ciphertext& cipher, long logq);

	void modDownBy(Ciphertext& res, Ciphertext& cipher, long dlogq);

	void modDownByAndEqual(Ciphertext& cipher, long dlogq);

	void modDownTo(Ciphertext& res, Ciphertext& cipher, long logq);

	void modDownToAndEqual(Ciphertext& cipher, long logq);

	//----------------------------------------------------------------------------------
	//   ROTATIONS & CONJUGATIONS
	//----------------------------------------------------------------------------------

	void leftRotateFast(Ciphertext& res, Ciphertext& cipher, long r);
	void rightRotateFast(Ciphertext& res, Ciphertext& cipher, long r);

	void leftRotateFastAndEqual(Ciphertext& cipher, long r);
	void rightRotateFastAndEqual(Ciphertext& cipher, long r);

	void conjugate(Ciphertext& res, Ciphertext& cipher);
	void conjugateAndEqual(Ciphertext& cipher);

	//----------------------------------------------------------------------------------
	//   BOOTSTRAPPING
	//----------------------------------------------------------------------------------

	void normalizeAndEqual(Ciphertext& cipher);

	void coeffToSlotAndEqual(Ciphertext& cipher);

	void slotToCoeffAndEqual(Ciphertext& cipher);

	void exp2piAndEqual(Ciphertext& cipher, long logp);

	void evalExpAndEqual(Ciphertext& cipher, long logT, long logI = 4);

	void bootstrapAndEqual(Ciphertext& cipher, long logq, long logQ, long logT, long logI = 4);
};

}  // namespace heaan

#endif

Scheme.cpp

/*
* Copyright (c) by CryptoLab inc.
* This program is licensed under a
* Creative Commons Attribution-NonCommercial 3.0 Unported License.
* You should have received a copy of the license along with this
* work.  If not, see <http://creativecommons.org/licenses/by-nc/3.0/>.
*/
#include "Scheme.h"

#include "NTL/BasicThreadPool.h"
#include <string>

#include "StringUtils.h"
#include "SerializationUtils.h"

using namespace std;
using namespace NTL;

namespace heaan {
   

/**
 * Binds the scheme to `ring` and generates the two keys every instance
 * starts with: the encryption key and the multiplication key. Conjugation,
 * rotation and bootstrapping keys are added later via the add*Key methods.
 *
 * @param secretKey    secret key the keys are derived from
 * @param ring         ring backend; kept by reference, so it must outlive this Scheme
 * @param isSerialized when true, keys are written to files under "serkey/"
 *                     instead of being kept in memory
 */
Scheme::Scheme(SecretKey& secretKey, Ring& ring, bool isSerialized)
	: ring(ring)
	, isSerialized(isSerialized)
{
	addEncKey(secretKey);
	addMultKey(secretKey);
}

/**
 * Frees every in-memory Key owned by this Scheme (entries of keyMap and
 * leftRotKeyMap). Serialized keys are stored only as file paths, so nothing
 * is freed for them, and the files under "serkey/" stay on disk.
 */
Scheme::~Scheme() {
	for (std::map<long, Key*>::iterator it = keyMap.begin(); it != keyMap.end(); ++it) {
		delete it->second;
	}
	for (std::map<long, Key*>::iterator it = leftRotKeyMap.begin(); it != leftRotKeyMap.end(); ++it) {
		delete it->second;
	}
}

/**
 * Generates the public encryption key and registers it under ENCRYPTION.
 *
 * Construction:
 *  - A uniform polynomial ax is sampled with logQQ-bit coefficients.
 *  - bx = sx * ax (mod QQ) is computed.
 *  - bx is then passed to Ring::subFromGaussAndEqual, which presumably turns
 *    it into e - sx * ax with a small Gaussian error e (TODO confirm in Ring.cpp).
 *  - Both polynomials are stored in CRT form (Key::rax / Key::rbx) over
 *    `nprimes` primes.
 *
 * Storage: if isSerialized, the key is written to "serkey/ENCRYPTION.txt" and
 * only the path is kept in serKeyMap. Otherwise the Key is owned by keyMap.
 *
 * If an encryption key is already registered in keyMap, the existing entry is
 * kept (std::map::insert never overwrites) and the new Key is deleted instead
 * of being leaked.
 */
void Scheme::addEncKey(SecretKey& secretKey) {
	ZZ* ax = new ZZ[N];
	ZZ* bx = new ZZ[N];

	// Number of primes Ring::mult needs for the exact product sx * ax before
	// reduction mod QQ; presumably each prime has about pbnd bits (TODO confirm).
	long np = ceil((1 + logQQ + logN + 2)/(double)pbnd);
	ring.sampleUniform2(ax, logQQ);
	ring.mult(bx, secretKey.sx, ax, np, QQ);
	ring.subFromGaussAndEqual(bx, QQ);

	Key* key = new Key();
	ring.CRT(key->rax, ax, nprimes);
	ring.CRT(key->rbx, bx, nprimes);
	delete[] ax; delete[] bx;

	if(isSerialized) {
		string path = "serkey/ENCRYPTION.txt";
		SerializationUtils::writeKey(key, path);
		serKeyMap.insert(pair<long, string>(ENCRYPTION, path));
		delete key;
	} else {
		// insert() does not overwrite: if a key is already present, keep it and
		// free the freshly generated one rather than leaking it.
		if (!keyMap.insert(pair<long, Key*>(ENCRYPTION, key)).second) {
			delete key;
		}
	}
}

/**
 * Generates the multiplication key and registers it under MULTIPLICATION.
 *
 * Construction:
 *  - bx starts as in addEncKey: ax is uniform with logQQ bits, and
 *    bx = sx * ax mod QQ, passed through Ring::subFromGaussAndEqual.
 *  - sxsx = sx * sx (mod Q) is computed.
 *  - sxsx is shifted by logQ via Ring::leftShiftAndEqual (presumably
 *    multiplied by 2^logQ) mod QQ.
 *  - The shifted sxsx is added to bx mod QQ.
 *  - ax and bx are stored in CRT form over `nprimes` primes.
 *
 * Storage: if isSerialized, the key is written to "serkey/MULTIPLICATION.txt"
 * and only the path is kept in serKeyMap. Otherwise the Key is owned by keyMap.
 *
 * If a multiplication key is already registered in keyMap, the existing entry
 * is kept (std::map::insert never overwrites) and the new Key is deleted
 * instead of being leaked.
 */
void Scheme::addMultKey(SecretKey& secretKey) {
	ZZ* ax = new ZZ[N];
	ZZ* bx = new ZZ[N];
	ZZ* sxsx = new ZZ[N];

	// Prime count for the product sx * ax before reduction mod QQ.
	long np = ceil((1 + logQQ + logN + 2)/(double)pbnd);
	ring.sampleUniform2(ax, logQQ);
	ring.mult(bx, secretKey.sx, ax, np, QQ);
	ring.subFromGaussAndEqual(bx, QQ);

	// sx * sx only needs a few bits per coefficient before reduction mod Q,
	// hence the much smaller prime count.
	np = ceil((2 + logN + 2)/(double)pbnd);
	ring.mult(sxsx, secretKey.sx, secretKey.sx, np, Q);
	ring.leftShiftAndEqual(sxsx, logQ, QQ);
	ring.addAndEqual(bx, sxsx, QQ);
	delete[] sxsx;

	Key* key = new Key();
	ring.CRT(key->rax, ax, nprimes);
	ring.CRT(key->rbx, bx, nprimes);
	delete[] ax; delete[] bx;

	if(isSerialized) {
		string path = "serkey/MULTIPLICATION.txt";
		SerializationUtils::writeKey(key, path);
		serKeyMap.insert(pair<long, string>(MULTIPLICATION, path));
		delete key;
	} else {
		// insert() does not overwrite: if a key is already present, keep it and
		// free the freshly generated one rather than leaking it.
		if (!keyMap.insert(pair<long, Key*>(MULTIPLICATION, key)).second) {
			delete key;
		}
	}
}

/**
 * Generates the conjugation key and registers it under CONJUGATION.
 *
 * Construction:
 *  - bx starts as in addEncKey: ax is uniform with logQQ bits, and
 *    bx = sx * ax mod QQ, passed through Ring::subFromGaussAndEqual.
 *  - The conjugated secret, computed by Ring::conjugate, is shifted by logQ
 *    (presumably multiplied by 2^logQ) mod QQ and added to bx.
 *  - ax and bx are stored in CRT form over `nprimes` primes.
 *
 * Storage: if isSerialized, the key is written to "serkey/CONJUGATION.txt" and
 * only the path is kept in serKeyMap. Otherwise the Key is owned by keyMap.
 *
 * If a conjugation key is already registered in keyMap, the existing entry is
 * kept (std::map::insert never overwrites) and the new Key is deleted instead
 * of being leaked.
 */
void Scheme::addConjKey(SecretKey& secretKey) {
	ZZ* ax = new ZZ[N];
	ZZ* bx = new ZZ[N];

	// Prime count for the product sx * ax before reduction mod QQ.
	long np = ceil((1 + logQQ + logN + 2)/(double)pbnd);
	ring.sampleUniform2(ax, logQQ);
	ring.mult(bx, secretKey.sx, ax, np, QQ);
	ring.subFromGaussAndEqual(bx, QQ);

	ZZ* sxconj = new ZZ[N];
	ring.conjugate(sxconj, secretKey.sx);
	ring.leftShiftAndEqual(sxconj, logQ, QQ);
	ring.addAndEqual(bx, sxconj, QQ);
	delete[] sxconj;

	Key* key = new Key();
	ring.CRT(key->rax, ax, nprimes);
	ring.CRT(key->rbx, bx, nprimes);
	delete[] ax; delete[] bx;

	if(isSerialized) {
		string path = "serkey/CONJUGATION.txt";
		SerializationUtils::writeKey(key, path);
		serKeyMap.insert(pair<long, string>(CONJUGATION, path));
		delete key;
	} else {
		// insert() does not overwrite: if a key is already present, keep it and
		// free the freshly generated one rather than leaking it.
		if (!keyMap.insert(pair<long, Key*>(CONJUGATION, key)).second) {
			delete key;
		}
	}
}

/**
 * Generates the key for a left rotation by r and registers it under index r.
 *
 * Construction:
 *  - bx starts as in addEncKey: ax is uniform with logQQ bits, and
 *    bx = sx * ax mod QQ, passed through Ring::subFromGaussAndEqual.
 *  - The secret rotated by Ring::leftRotate(…, r) is shifted by logQ
 *    (presumably multiplied by 2^logQ) mod QQ and added to bx.
 *  - ax and bx are stored in CRT form over `nprimes` primes.
 *
 * Storage: if isSerialized, the key is written to "serkey/ROTATION_<r>.txt"
 * and only the path is kept in serLeftRotKeyMap. Otherwise the Key is owned
 * by leftRotKeyMap.
 *
 * If a key for r is already registered in leftRotKeyMap, the existing entry is
 * kept (std::map::insert never overwrites) and the new Key is deleted instead
 * of being leaked.
 */
void Scheme::addLeftRotKey(SecretKey& secretKey, long r) {
	ZZ* ax = new ZZ[N];
	ZZ* bx = new ZZ[N];

	// Prime count for the product sx * ax before reduction mod QQ.
	long np = ceil((1 + logQQ + logN + 2)/(double)pbnd);
	ring.sampleUniform2(ax, logQQ);
	ring.mult(bx, secretKey.sx, ax, np, QQ);
	ring.subFromGaussAndEqual(bx, QQ);

	ZZ* spow = new ZZ[N];
	ring.leftRotate(spow, secretKey.sx, r);
	ring.leftShiftAndEqual(spow, logQ, QQ);
	ring.addAndEqual(bx, spow, QQ);
	delete[] spow;

	Key* key = new Key();
	ring.CRT(key->rax, ax, nprimes);
	ring.CRT(key->rbx, bx, nprimes);
	delete[] ax; delete[] bx;

	if(isSerialized) {
		string path = "serkey/ROTATION_" + to_string(r) + ".txt";
		SerializationUtils::writeKey(key, path);
		serLeftRotKeyMap.insert(pair<long, string>(r, path));
		delete key;
	} else {
		// insert() does not overwrite: if a key for r is already present, keep
		// it and free the freshly generated one rather than leaking it.
		if (!leftRotKeyMap.insert(pair<long, Key*>(r, key)).second) {
			delete key;
		}
	}
}

void Scheme::addRightRotKey(SecretKey& secretKey, long r) {
   
	long idx = Nh - r;
	if(leftRotKeyMap.find(idx) == leftRotKeyMap.end
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值