#include<iostream>
#include <cstdlib>
#include <cmath>
using namespace std;
int pow_mod(int a, int b, int p); // modular exponentiation (core routine): a^b mod p
void encryption(int m, int pub_key, int p, int g, int* c1, int* c2); // ElGamal encryption
// m -- plaintext, pub_key -- public key, p -- prime modulus,
// g -- primitive root (generator), c1 / c2 -- output ciphertext pair
int decryption(int c1, int c2, int pub_key, int p, int g); // ElGamal decryption
// NOTE: despite its name, main() passes the PRIVATE key as decryption()'s third argument.
bool is_prime(int p); // primality test
// Reads one int from std::cin. On EOF or non-numeric input it reports the
// problem on std::cerr and returns false so the caller can abort; without this,
// a failed stream makes the prompt loops below spin forever.
static bool read_int(int& out) {
    if (cin >> out)
        return true;
    cerr << "Invalid input: expected an integer." << endl;
    return false;
}

// Interactive test driver for the ElGamal routines in this file.
// Reads the generator g, a prime modulus p and user A's private key, prints
// A's public key, encrypts one plaintext (encryption() also prompts for the
// random exponent k) and decrypts it again with the private key so the round
// trip can be checked by eye.
// Returns EXIT_FAILURE if standard input ends or contains a non-integer.
int main() {
    int g = 0; // generator (primitive root) modulo p
    // NOTE(review): g is not checked to actually be a primitive root mod p.
    cout << "Please input g:";
    if (!read_int(g))
        return EXIT_FAILURE;
    cout << endl;

    int p = 0; // prime modulus
    do {
        cout << "Please enter a prime number: ";
        if (!read_int(p))
            return EXIT_FAILURE;
        // The explicit p < 2 check keeps the modulus valid for pow_mod even
        // if is_prime were to accept 0, 1 or negative values.
    } while (p < 2 || !is_prime(p));
    cout << endl;

    int key_A = 0; // A's private key
    do {
        cout << "Enter the private key of user A: ";
        if (!read_int(key_A))
            return EXIT_FAILURE;
        // pow_mod only supports non-negative exponents; 0 would make the
        // public key 1, so require a positive key.
    } while (key_A < 1);
    cout << endl;

    const int pub = pow_mod(g, key_A, p); // A's public key: g^key_A mod p
    cout << "the public key of user A: " << pub << endl;
    cout << endl;

    int m = 0; // plaintext
    do {
        cout << "Input plaintext(smaller than " << p << "): ";
        if (!read_int(m))
            return EXIT_FAILURE;
        // Decryption yields a residue mod p, so only 0 <= m < p round-trips.
    } while (m < 0 || m >= p);
    cout << endl;

    int c1 = 0; // ciphertext, first component
    int c2 = 0; // ciphertext, second component
    encryption(m, pub, p, g, &c1, &c2);
    cout << "The ciphertext encrypted with the public key is:" << endl;
    cout << "c1= " << c1 << " " << "c2= " << c2 << endl;
    cout << endl;

    // decryption() needs the PRIVATE key, not the public one.
    const int m_ = decryption(c1, c2, key_A, p, g);
    cout << "Plaintext decrypted with private key:" << m_ << endl;
    return 0;
}
// ElGamal encryption.
// Prompts on std::cin for the random exponent k, then writes the ciphertext
//   *c1 = g^k mod p
//   *c2 = m * pub_key^k mod p
// m:       plaintext, expected to lie in [0, p)
// pub_key: recipient's public key (main() builds it as g^key mod p)
// p:       prime modulus
// g:       generator the public key was built from
// c1, c2:  output ciphertext components; must be non-null
// NOTE: k should be fresh for every message and lie in [1, p-2]; this is not
// enforced here. k == 0 gives c1 == 1 and c2 == m, i.e. no encryption at all.
// If k cannot be read (EOF / non-integer), the program terminates with
// EXIT_FAILURE. This interface cannot report errors, and encrypting with an
// unread (indeterminate or zero) k would be undefined or insecure.
void encryption(int m, int pub_key, int p, int g, int* c1, int* c2) {
    int k = 0; // ephemeral random exponent
    cout << "Please input a random data: ";
    if (!(cin >> k)) {
        cerr << "Invalid input: expected an integer." << endl;
        exit(EXIT_FAILURE);
    }
    *c1 = pow_mod(g, k, p);
    // Widen before multiplying: with int operands, m * s overflows once p > 46341.
    const long long shared = pow_mod(pub_key, k, p);
    *c2 = static_cast<int>(static_cast<long long>(m) * shared % p);
}
// ElGamal decryption.
// Recovers m = c2 * (c1^priv_key)^(-1) mod p. The inverse of c1 is computed
// with Fermat's little theorem, c1^(p-2) == c1^(-1) (mod p), which holds only
// because p is prime and assumes c1 is not a multiple of p.
// c1, c2:   ciphertext pair produced by encryption()
// priv_key: the recipient's PRIVATE key (main() passes key_A). The forward
//           declaration calls this parameter pub_key; C++ parameter names do
//           not affect callers.
// p:        prime modulus
// g:        unused; kept so the signature stays unchanged
// Returns the recovered plaintext.
int decryption(int c1, int c2, int priv_key, int p, int g) {
    (void)g; // decryption does not need the generator
    const int c1_inv = pow_mod(c1, p - 2, p);               // c1^(-1) mod p
    const long long mask_inv = pow_mod(c1_inv, priv_key, p); // (c1^x)^(-1) mod p
    // Widen before multiplying: with int operands, c2 * mask_inv overflows once p > 46341.
    return static_cast<int>(static_cast<long long>(c2) * mask_inv % p);
}
// Primality test by trial division.
// Returns true iff p is prime. Values below 2 (0, 1 and all negatives) are
// not prime.
bool is_prime(int p) {
    if (p < 2)
        return false;
    // Test divisors while i <= p / i (i.e. i*i <= p). This is integer-only, so
    // there is no floating-point rounding from sqrt and no i*i overflow near INT_MAX.
    for (int i = 2; i <= p / i; i++) {
        if (p % i == 0)
            return false;
    }
    return true;
}
// Modular exponentiation (core routine): returns a^b mod p, computed by
// right-to-left binary square-and-multiply.
// a: base; may be negative, it is first reduced into [0, p).
// b: exponent; expected to be >= 0. A negative exponent is not supported:
//    the loop body does not run and the result is 1 % p.
// p: modulus; must be > 0.
// Returns a value in [0, p).
// Intermediates are long long so any int modulus up to INT_MAX is safe; with
// int intermediates, base * base overflows (UB) once p > 46341.
int pow_mod(int a, int b, int p) {
    long long result = 1 % p;  // 1 % p makes p == 1 yield 0
    long long base = a % p;
    if (base < 0)
        base += p;  // C++ % keeps the dividend's sign; normalise into [0, p)
    while (b > 0) { // '> 0' (not '!= 0') so a negative b cannot loop forever
        if (b & 1)
            result = result * base % p;
        b >>= 1;
        base = base * base % p;
    }
    return static_cast<int>(result);
}