Now We Can Play!!
Harekaze 2019 - crypto 200
pwn和crypto结合,源码如下:
#!/usr/bin/python3
from Crypto.Util.number import *
from Crypto.Random.random import randint
from keys import flag
def genKey(k):
p = getStrongPrime(k)
g = 2
x = randint(2, p)
h = pow(g, x, p)
return (p, g, h), x
def encrypt(m, pk):
p, g, h = pk
r = randint(2, p)
c1 = pow(g, r, p)
c2 = m * pow(h, r, p) % p
return c1, c2
def decrypt(c1, c2, pk, sk):
p = pk[0]
m = pow(3, randint(2**16, 2**17), p) * c2 * inverse(pow(c1, sk, p), p) % p
return m
def challenge():
pk, sk = genKey(1024)
m = bytes_to_long(flag)
c1, c2 = encrypt(m, pk)
print("Public Key :", pk)
print("Cipher text :", (c1, c2))
while True:
print("---"*10, "\n")
in_c1 = int(input("Input your ciphertext c1 : "))
in_c2 = int(input("Input your ciphertext c2 : "))
dec = decrypt(in_c1, in_c2, pk, sk)
print("Your Decrypted Message :", dec)
if __name__ == "__main__":
challenge()
题目逻辑比较简单,连上端口以后会打印出pk的三个值,还有两个密文c1和c2
需要还原明文 m
,首先分析一下密文 c1
和 c2
如何生成:
x = randint(2, p)
r = randint(2, p)
h = pow(g, x, p)
c1 = pow(g, r, p)
c2 = m * pow(h, r, p) % p
其中公钥 p
, g
, 和 h
都已知,推导一下可知
c2 = m * pow(g, x * r, p) % p
= m * pow(pow(g, r, p), x, p) % p
= m * pow(c1, x, p) % p
再分析一下 decrypt()
函数逻辑,最主要就是进行下面这个运算
rand = randint(2 ** 16, 2 ** 17)
m_ = pow(3, rand, p) * c2 * inverse(pow(c1, sk, p), p) % p
注意到 x == sk
再推导一下上面这个式子
x = sk
c2 = m * pow(c1, sk, p) % p
m = c2 * inverse(pow(c1, sk, p)) % p
m_ = pow(3, rand, p) * m % p
m = m_ * inverse(pow(3, rand, p)) % p
注意到上式只有一个未知量 rand
对其进行(2**16,2**17)
范围爆破即可。用户输入内容的话比较简单 in_c1 = c1
和 in_c2 = c2
,没太搞懂为啥要把输出的c1和c2再输入一遍😰
最终exp:
#python2
from pwn import *
from Crypto.Util.number import long_to_bytes, inverse
from string import printable
def decrypt(conn, c1, c2):
conn.recvuntil("Input your ciphertext c1 : ")
conn.sendline(str(c1))
conn.recvuntil("Input your ciphertext c2 : ")
conn.sendline(str(c2))
conn.recvuntil("('Your Decrypted Message :', ")
m = int(conn.recvline().rstrip("L)\n"))
return m
context.log_level = "DEBUG"
conn = remote("node4.buuoj.cn", 29607)
conn.recvuntil("('Public Key :', (")
pk = conn.recvline().split(", ")
p = int(pk[0].rstrip("L"))
g = int(pk[1])
h = int(pk[2].rstrip("L))\n"))
conn.recvuntil("('Cipher text :', (")
cs = conn.recvline().split(", ")
c1 = int(cs[0].rstrip("L"))
c2 = int(cs[1].rstrip("L))\n"))
m_ = decrypt(conn, c1, c2)
conn.close()
for i in range(2**16, 2**17):
flag = long_to_bytes(m_ * inverse(pow(3, i, p), p) % p)
if all(c in printable for c in flag):
break
print(flag)