HEAAN源码(二)
Scheme.h
#ifndef HEAAN_SCHEME_H_
#define HEAAN_SCHEME_H_
#include <complex>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <NTL/RR.h>
#include <NTL/ZZ.h>
#include "BootContext.h"
#include "SecretKey.h"
#include "Ciphertext.h"
#include "Plaintext.h"
#include "Key.h"
#include "EvaluatorUtils.h"
#include "Ring.h"
namespace heaan {
// Fixed indices under which Scheme stores its non-rotation keys in
// Scheme::keyMap (in-memory) or Scheme::serKeyMap (serialized path).
// Declared const: these are identifiers, not state, and a mutable
// header-level static would let each translation unit silently change
// its own private copy.
static const long ENCRYPTION = 0;
static const long MULTIPLICATION = 1;
static const long CONJUGATION = 2;
/**
 * Front-end object that holds the public/evaluation keys and exposes the
 * encode / encrypt / homomorphic-evaluation API.
 *
 * Ownership: in the non-serialized mode the Key objects stored in keyMap and
 * leftRotKeyMap are heap-allocated by the add*Key methods and deleted by the
 * destructor. Because of that, a Scheme must not be copied (a copy would
 * delete the same keys a second time), so copying is disabled below.
 *
 * The section labels inside the class only group methods by name; consult
 * the definitions in Scheme.cpp for the exact semantics of each operation.
 */
class Scheme {
public:
    // Ring used for all polynomial arithmetic; referenced, not owned.
    Ring& ring;

    // When true, generated keys are written to files under "serkey/" and only
    // their paths are kept (serKeyMap / serLeftRotKeyMap); when false, the
    // Key objects themselves are kept in keyMap / leftRotKeyMap.
    bool isSerialized;

    // In-memory keys, indexed by ENCRYPTION / MULTIPLICATION / CONJUGATION.
    std::map<long, Key*> keyMap;
    // In-memory left-rotation keys, indexed by the rotation amount r.
    std::map<long, Key*> leftRotKeyMap;

    // File paths of serialized keys, same indexing as keyMap.
    std::map<long, std::string> serKeyMap;
    // File paths of serialized left-rotation keys, same indexing as leftRotKeyMap.
    std::map<long, std::string> serLeftRotKeyMap;

    // Generates the encryption and multiplication keys for secretKey.
    Scheme(SecretKey& secretKey, Ring& ring, bool isSerialized = false);

    // Deletes every Key held in keyMap and leftRotKeyMap.
    virtual ~Scheme();

    // Non-copyable: the maps own raw Key pointers that the destructor deletes,
    // so an implicit member-wise copy would lead to a double delete.
    Scheme(const Scheme&) = delete;
    Scheme& operator=(const Scheme&) = delete;

    // ---- key generation ----
    void addEncKey(SecretKey& secretKey);
    void addMultKey(SecretKey& secretKey);
    void addConjKey(SecretKey& secretKey);
    void addLeftRotKey(SecretKey& secretKey, long r);
    void addRightRotKey(SecretKey& secretKey, long r);
    void addLeftRotKeys(SecretKey& secretKey);
    void addRightRotKeys(SecretKey& secretKey);
    void addBootKey(SecretKey& secretKey, long logl, long logp);

    // ---- encode / decode ----
    void encode(Plaintext& plain, std::complex<double>* vals, long n, long logp, long logq);
    void encode(Plaintext& plain, double* vals, long n, long logp, long logq);
    std::complex<double>* decode(Plaintext& plain);
    void encodeSingle(Plaintext& plain, std::complex<double> val, long logp, long logq);
    void encodeSingle(Plaintext& plain, double val, long logp, long logq);
    std::complex<double> decodeSingle(Plaintext& plain);

    // ---- encrypt / decrypt ----
    void encryptMsg(Ciphertext& cipher, Plaintext& plain);
    void decryptMsg(Plaintext& plain, SecretKey& secretKey, Ciphertext& cipher);
    void encrypt(Ciphertext& cipher, std::complex<double>* vals, long n, long logp, long logq);
    void encrypt(Ciphertext& cipher, double* vals, long n, long logp, long logq);
    void encryptBySk(Ciphertext& cipher, SecretKey& secretKey, std::complex<double>* vals, long n, long logp, long logq, double=3.2);
    void encryptBySk(Ciphertext& cipher, SecretKey& secretKey, double* vals, long n, long logp, long logq, double=3.2);
    void encryptZeros(Ciphertext& cipher, long n, long logp, long logq);
    std::complex<double>* decrypt(SecretKey& secretKey, Ciphertext& cipher);
    std::complex<double>* decryptForShare(SecretKey& secretKey, Ciphertext& cipher, long=0);
    void encryptSingle(Ciphertext& cipher, std::complex<double> val, long logp, long logq);
    void encryptSingle(Ciphertext& cipher, double val, long logp, long logq);
    std::complex<double> decryptSingle(SecretKey& secretKey, Ciphertext& cipher);

    // ---- negation / addition / subtraction ----
    void negate(Ciphertext& res, Ciphertext& cipher);
    void negateAndEqual(Ciphertext& cipher);
    void add(Ciphertext& res, Ciphertext& cipher1, Ciphertext& cipher2);
    void addAndEqual(Ciphertext& cipher1, Ciphertext& cipher2);
    void addConst(Ciphertext& res, Ciphertext& cipher, double cnst, long logp);
    void addConst(Ciphertext& res, Ciphertext& cipher, NTL::RR& cnst, long logp);
    void addConst(Ciphertext& res, Ciphertext& cipher, std::complex<double> cnst, long logp);
    void addConstAndEqual(Ciphertext& cipher, double cnst, long logp);
    void addConstAndEqual(Ciphertext& cipher, NTL::RR& cnst, long logp);
    void addConstAndEqual(Ciphertext& cipher, std::complex<double> cnst, long logp);
    void sub(Ciphertext& res, Ciphertext& cipher1, Ciphertext& cipher2);
    void subAndEqual(Ciphertext& cipher1, Ciphertext& cipher2);
    void subAndEqual2(Ciphertext& cipher1, Ciphertext& cipher2);

    // ---- multiplication / division by i ----
    void imult(Ciphertext& res, Ciphertext& cipher);
    void idiv(Ciphertext& res, Ciphertext& cipher);
    void imultAndEqual(Ciphertext& cipher);
    void idivAndEqual(Ciphertext& cipher);

    // ---- ciphertext multiplication / squaring ----
    void mult(Ciphertext& res, Ciphertext& cipher1, Ciphertext& cipher2);
    void multAndEqual(Ciphertext& cipher1, Ciphertext& cipher2);
    void square(Ciphertext& res, Ciphertext& cipher);
    void squareAndEqual(Ciphertext& cipher);

    // ---- multiplication by constants / polynomials / monomials ----
    void multByConst(Ciphertext& res, Ciphertext& cipher, double cnst, long logp);
    void multByConst(Ciphertext& res, Ciphertext& cipher, std::complex<double> cnst, long logp);
    void multByConstVec(Ciphertext& res, Ciphertext& cipher, std::complex<double>* cnstVec, long logp);
    void multByConstVecAndEqual(Ciphertext& cipher, std::complex<double>* cnstVec, long logp);
    void multByConstAndEqual(Ciphertext& cipher, double cnst, long logp);
    void multByConstAndEqual(Ciphertext& cipher, NTL::RR& cnst, long logp);
    void multByConstAndEqual(Ciphertext& cipher, std::complex<double> cnst, long logp);
    void multByPoly(Ciphertext& res, Ciphertext& cipher, NTL::ZZ* poly, long logp);
    void multByPolyNTT(Ciphertext& res, Ciphertext& cipher, uint64_t* rpoly, long bnd, long logp);
    void multByPolyAndEqual(Ciphertext& cipher, NTL::ZZ* poly, long logp);
    void multByPolyNTTAndEqual(Ciphertext& cipher, uint64_t* rpoly, long bnd, long logp);
    void multByMonomial(Ciphertext& res, Ciphertext& cipher, const long degree);
    void multByMonomialAndEqual(Ciphertext& cipher, const long degree);

    // ---- shifts and power-of-two division ----
    void leftShift(Ciphertext& res, Ciphertext& cipher, long bits);
    void leftShiftAndEqual(Ciphertext& cipher, long bits);
    void doubleAndEqual(Ciphertext& cipher);
    void divByPo2(Ciphertext& res, Ciphertext& cipher, long bits);
    void divByPo2AndEqual(Ciphertext& cipher, long bits);

    // ---- rescaling and modulus reduction ----
    void reScaleBy(Ciphertext& res, Ciphertext& cipher, long dlogq);
    void reScaleTo(Ciphertext& res, Ciphertext& cipher, long logq);
    void reScaleByAndEqual(Ciphertext& cipher, long dlogq);
    void reScaleToAndEqual(Ciphertext& cipher, long logq);
    void modDownBy(Ciphertext& res, Ciphertext& cipher, long dlogq);
    void modDownByAndEqual(Ciphertext& cipher, long dlogq);
    void modDownTo(Ciphertext& res, Ciphertext& cipher, long logq);
    void modDownToAndEqual(Ciphertext& cipher, long logq);

    // ---- rotation and conjugation ----
    void leftRotateFast(Ciphertext& res, Ciphertext& cipher, long r);
    void rightRotateFast(Ciphertext& res, Ciphertext& cipher, long r);
    void leftRotateFastAndEqual(Ciphertext& cipher, long r);
    void rightRotateFastAndEqual(Ciphertext& cipher, long r);
    void conjugate(Ciphertext& res, Ciphertext& cipher);
    void conjugateAndEqual(Ciphertext& cipher);

    // ---- bootstrapping ----
    void normalizeAndEqual(Ciphertext& cipher);
    void coeffToSlotAndEqual(Ciphertext& cipher);
    void slotToCoeffAndEqual(Ciphertext& cipher);
    void exp2piAndEqual(Ciphertext& cipher, long logp);
    void evalExpAndEqual(Ciphertext& cipher, long logT, long logI = 4);
    void bootstrapAndEqual(Ciphertext& cipher, long logq, long logQ, long logT, long logI = 4);
};
}
#endif
Scheme.cpp
#include "Scheme.h"
#include "NTL/BasicThreadPool.h"
#include <string>
#include "StringUtils.h"
#include "SerializationUtils.h"
using namespace std;
using namespace NTL;
namespace heaan {
/**
 * Constructs a scheme bound to `ring` and immediately generates the
 * encryption key and the multiplication key for `secretKey`.
 *
 * @param secretKey     secret key the two initial keys are derived from
 * @param ring          ring context; stored by reference, so it must outlive this object
 * @param isSerialized  when true, generated keys are written to files and only
 *                      their paths are retained; when false they stay in memory
 */
Scheme::Scheme(SecretKey& secretKey, Ring& ring, bool isSerialized) : ring(ring), isSerialized(isSerialized) {
    addEncKey(secretKey);
    addMultKey(secretKey);
}
/**
 * Frees every Key object held by the two in-memory key maps.
 * The serialized-path maps store only strings and need no manual cleanup.
 */
Scheme::~Scheme() {
    typedef std::map<long, Key*>::iterator KeyIter;

    // Fixed keys (encryption / multiplication / conjugation).
    for (KeyIter entry = keyMap.begin(); entry != keyMap.end(); ++entry) {
        delete entry->second;
    }

    // Left-rotation keys.
    for (KeyIter entry = leftRotKeyMap.begin(); entry != leftRotKeyMap.end(); ++entry) {
        delete entry->second;
    }
}
/**
 * Generates the encryption key for `secretKey` and registers it under the
 * ENCRYPTION index.
 *
 * Steps: sample `ax` with ring.sampleUniform2, set bx = sx * ax (mod QQ),
 * apply ring.subFromGaussAndEqual to bx (presumably bx = e - bx for a freshly
 * sampled Gaussian e -- confirm in Ring), then store the CRT forms of ax and
 * bx in a new Key.
 *
 * If isSerialized is set, the key is written to "serkey/ENCRYPTION.txt" and
 * only the path is recorded in serKeyMap. Otherwise the key is stored in
 * keyMap; a key already present under ENCRYPTION is deleted and replaced, so
 * repeated calls neither leak the new key nor silently ignore it.
 *
 * @param secretKey  secret key whose `sx` polynomial is used
 */
void Scheme::addEncKey(SecretKey& secretKey) {
    // Owning buffers: released on every exit path, including exceptions.
    std::unique_ptr<ZZ[]> ax(new ZZ[N]);
    std::unique_ptr<ZZ[]> bx(new ZZ[N]);

    const long np = ceil((1 + logQQ + logN + 2)/(double)pbnd);
    ring.sampleUniform2(ax.get(), logQQ);
    ring.mult(bx.get(), secretKey.sx, ax.get(), np, QQ);
    ring.subFromGaussAndEqual(bx.get(), QQ);

    std::unique_ptr<Key> key(new Key());
    ring.CRT(key->rax, ax.get(), nprimes);
    ring.CRT(key->rbx, bx.get(), nprimes);

    // The coefficient buffers are no longer needed once the CRT forms exist;
    // free them before the key is serialized or stored.
    ax.reset();
    bx.reset();

    if (isSerialized) {
        const string path = "serkey/ENCRYPTION.txt";
        SerializationUtils::writeKey(key.get(), path);
        serKeyMap.insert(pair<long, string>(ENCRYPTION, path));
        // `key` is destroyed automatically at scope exit.
    } else {
        // operator[] yields a null slot for a new entry; `delete` on null is a no-op.
        Key*& slot = keyMap[ENCRYPTION];
        delete slot;
        slot = key.release();
    }
}
/**
 * Generates the multiplication key for `secretKey` and registers it under the
 * MULTIPLICATION index.
 *
 * Steps: sample `ax` with ring.sampleUniform2, set bx = sx * ax (mod QQ),
 * apply ring.subFromGaussAndEqual to bx (presumably bx = e - bx for a freshly
 * sampled Gaussian e -- confirm in Ring), compute sx * sx (mod Q), pass it
 * through ring.leftShiftAndEqual with logQ (mod QQ) and add it to bx. The CRT
 * forms of ax and bx are stored in a new Key.
 *
 * If isSerialized is set, the key is written to "serkey/MULTIPLICATION.txt"
 * and only the path is recorded in serKeyMap. Otherwise the key is stored in
 * keyMap; a key already present under MULTIPLICATION is deleted and replaced,
 * so repeated calls neither leak the new key nor silently ignore it.
 *
 * @param secretKey  secret key whose `sx` polynomial is used
 */
void Scheme::addMultKey(SecretKey& secretKey) {
    // Owning buffers: released on every exit path, including exceptions.
    std::unique_ptr<ZZ[]> ax(new ZZ[N]);
    std::unique_ptr<ZZ[]> bx(new ZZ[N]);
    std::unique_ptr<ZZ[]> sxsx(new ZZ[N]);

    const long npKey = ceil((1 + logQQ + logN + 2)/(double)pbnd);
    ring.sampleUniform2(ax.get(), logQQ);
    ring.mult(bx.get(), secretKey.sx, ax.get(), npKey, QQ);
    ring.subFromGaussAndEqual(bx.get(), QQ);

    // Fewer primes are requested for the sx * sx product than for sx * ax.
    const long npSquare = ceil((2 + logN + 2)/(double)pbnd);
    ring.mult(sxsx.get(), secretKey.sx, secretKey.sx, npSquare, Q);
    ring.leftShiftAndEqual(sxsx.get(), logQ, QQ);
    ring.addAndEqual(bx.get(), sxsx.get(), QQ);
    sxsx.reset();

    std::unique_ptr<Key> key(new Key());
    ring.CRT(key->rax, ax.get(), nprimes);
    ring.CRT(key->rbx, bx.get(), nprimes);

    // The coefficient buffers are no longer needed once the CRT forms exist.
    ax.reset();
    bx.reset();

    if (isSerialized) {
        const string path = "serkey/MULTIPLICATION.txt";
        SerializationUtils::writeKey(key.get(), path);
        serKeyMap.insert(pair<long, string>(MULTIPLICATION, path));
        // `key` is destroyed automatically at scope exit.
    } else {
        // operator[] yields a null slot for a new entry; `delete` on null is a no-op.
        Key*& slot = keyMap[MULTIPLICATION];
        delete slot;
        slot = key.release();
    }
}
/**
 * Generates the conjugation key for `secretKey` and registers it under the
 * CONJUGATION index.
 *
 * Steps: sample `ax` with ring.sampleUniform2, set bx = sx * ax (mod QQ),
 * apply ring.subFromGaussAndEqual to bx (presumably bx = e - bx for a freshly
 * sampled Gaussian e -- confirm in Ring), compute ring.conjugate(sx), pass it
 * through ring.leftShiftAndEqual with logQ (mod QQ) and add it to bx. The CRT
 * forms of ax and bx are stored in a new Key.
 *
 * If isSerialized is set, the key is written to "serkey/CONJUGATION.txt" and
 * only the path is recorded in serKeyMap. Otherwise the key is stored in
 * keyMap; a key already present under CONJUGATION is deleted and replaced, so
 * repeated calls neither leak the new key nor silently ignore it.
 *
 * @param secretKey  secret key whose `sx` polynomial is used
 */
void Scheme::addConjKey(SecretKey& secretKey) {
    // Owning buffers: released on every exit path, including exceptions.
    std::unique_ptr<ZZ[]> ax(new ZZ[N]);
    std::unique_ptr<ZZ[]> bx(new ZZ[N]);

    const long np = ceil((1 + logQQ + logN + 2)/(double)pbnd);
    ring.sampleUniform2(ax.get(), logQQ);
    ring.mult(bx.get(), secretKey.sx, ax.get(), np, QQ);
    ring.subFromGaussAndEqual(bx.get(), QQ);

    std::unique_ptr<ZZ[]> sxconj(new ZZ[N]);
    ring.conjugate(sxconj.get(), secretKey.sx);
    ring.leftShiftAndEqual(sxconj.get(), logQ, QQ);
    ring.addAndEqual(bx.get(), sxconj.get(), QQ);
    sxconj.reset();

    std::unique_ptr<Key> key(new Key());
    ring.CRT(key->rax, ax.get(), nprimes);
    ring.CRT(key->rbx, bx.get(), nprimes);

    // The coefficient buffers are no longer needed once the CRT forms exist.
    ax.reset();
    bx.reset();

    if (isSerialized) {
        const string path = "serkey/CONJUGATION.txt";
        SerializationUtils::writeKey(key.get(), path);
        serKeyMap.insert(pair<long, string>(CONJUGATION, path));
        // `key` is destroyed automatically at scope exit.
    } else {
        // operator[] yields a null slot for a new entry; `delete` on null is a no-op.
        Key*& slot = keyMap[CONJUGATION];
        delete slot;
        slot = key.release();
    }
}
/**
 * Generates the left-rotation key for rotation amount `r` and registers it
 * under index `r`.
 *
 * Steps: sample `ax` with ring.sampleUniform2, set bx = sx * ax (mod QQ),
 * apply ring.subFromGaussAndEqual to bx (presumably bx = e - bx for a freshly
 * sampled Gaussian e -- confirm in Ring), compute ring.leftRotate(sx, r),
 * pass it through ring.leftShiftAndEqual with logQ (mod QQ) and add it to bx.
 * The CRT forms of ax and bx are stored in a new Key.
 *
 * If isSerialized is set, the key is written to "serkey/ROTATION_<r>.txt" and
 * only the path is recorded in serLeftRotKeyMap. Otherwise the key is stored
 * in leftRotKeyMap; a key already present under `r` is deleted and replaced,
 * so repeated calls neither leak the new key nor silently ignore it.
 *
 * @param secretKey  secret key whose `sx` polynomial is used
 * @param r          rotation amount; also the map index for the generated key
 */
void Scheme::addLeftRotKey(SecretKey& secretKey, long r) {
    // Owning buffers: released on every exit path, including exceptions.
    std::unique_ptr<ZZ[]> ax(new ZZ[N]);
    std::unique_ptr<ZZ[]> bx(new ZZ[N]);

    const long np = ceil((1 + logQQ + logN + 2)/(double)pbnd);
    ring.sampleUniform2(ax.get(), logQQ);
    ring.mult(bx.get(), secretKey.sx, ax.get(), np, QQ);
    ring.subFromGaussAndEqual(bx.get(), QQ);

    std::unique_ptr<ZZ[]> spow(new ZZ[N]);
    ring.leftRotate(spow.get(), secretKey.sx, r);
    ring.leftShiftAndEqual(spow.get(), logQ, QQ);
    ring.addAndEqual(bx.get(), spow.get(), QQ);
    spow.reset();

    std::unique_ptr<Key> key(new Key());
    ring.CRT(key->rax, ax.get(), nprimes);
    ring.CRT(key->rbx, bx.get(), nprimes);

    // The coefficient buffers are no longer needed once the CRT forms exist.
    ax.reset();
    bx.reset();

    if (isSerialized) {
        const string path = "serkey/ROTATION_" + to_string(r) + ".txt";
        SerializationUtils::writeKey(key.get(), path);
        serLeftRotKeyMap.insert(pair<long, string>(r, path));
        // `key` is destroyed automatically at scope exit.
    } else {
        // operator[] yields a null slot for a new entry; `delete` on null is a no-op.
        Key*& slot = leftRotKeyMap[r];
        delete slot;
        slot = key.release();
    }
}
void Scheme::addRightRotKey(SecretKey& secretKey, long r) {
long idx = Nh - r;
if(leftRotKeyMap.find(idx) == leftRotKeyMap.end