Implementation of Pohlig-Hellman unable to solve for large exponents

I'm trying to create an implementation of the Pohlig-Hellman algorithm in order to create a utility to craft / exploit backdoors in implementations of the Diffie-Hellman protocol. This project is inspired by this 2016 white-paper by NCC Group.
I currently have an implementation, here, that works for relatively small exponents, i.e. given a congruence g^x ≡ h (mod n) for some specially crafted modulus n = pq, where p and q are prime, my implementation can solve for values of x smaller than min{ p, q }.
However, if x is larger than the smallest prime factor of n, my implementation gives an incorrect solution. I suspect that the issue may not be with my implementation of Pohlig-Hellman itself, but with the arguments I am passing to it. All the code can be found at the link provided above, but I'll copy the relevant code snippets here:
#
# Implementation of Pohlig-Hellman algorithm
#
# The `crt` function implements the Chinese Remainder Theorem, and the `pollard` function implements
# Pollard's Rho algorithm for discrete logarithms (see /dph/crt.py and /dph/pollard.py).
#
def pohlig(G, H, P, factors):
    g = [pow(G, divexact(P - 1, f), P) for f in factors]
    h = [pow(H, divexact(P - 1, f), P) for f in factors]
    if Config.verbose:
        x = []
        total = len(factors)
        for i, (gi, hi) in enumerate(zip(g, h), start=1):
            print('Solving discrete logarithm {}/{}...'.format(str(i).rjust(len(str(total))), total))
            result = pollard(gi, hi, P)
            x.append(result)
            print(f'x = 0x{result.digits(16)}')
    else:
        x = [pollard(gi, hi, P) for gi, hi in zip(g, h)]
    return crt(x, factors)
Above is my implementation of Pohlig-Hellman, and below is where I call it to exploit a backdoor in some implementation of the Diffie-Hellman protocol.
def _exp(args):
    g = args.g
    h = args.h
    p_factors = list(map(mpz, args.p_factors.split(',')))
    try:
        p_factors.remove(2)
    except ValueError:
        pass
    q_factors = list(map(mpz, args.q_factors.split(',')))
    try:
        q_factors.remove(2)
    except ValueError:
        pass
    p = 2 * _product(*p_factors) + 1
    q = 2 * _product(*q_factors) + 1
    if Config.verbose:
        print(f'p = 0x{p.digits(16)}')
        print(f'q = 0x{q.digits(16)}')
        print()
        print(f'Compute the discrete logarithm modulo `p`')
        print(f'-----------------------------------------')
    px = pohlig(g % p, h % p, p, p_factors)
    if Config.verbose:
        print()
        print(f'Compute the discrete logarithm modulo `q`')
        print(f'-----------------------------------------')
    qx = pohlig(g % q, h % q, q, q_factors)
    if Config.verbose:
        print()
    x = crt([px, qx], [p, q])
    print(f'x = 0x{x.digits(16)}')
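For experimenting without gmpy2 or the repository's `crt`/`pollard` helpers, here is a minimal pure-Python sketch of the same Pohlig-Hellman structure. It is our own illustration, not the project's code: `brute_force_log` is a hypothetical stand-in for Pollard's rho (only practical for small prime factors), and the sketch assumes `factors` are distinct primes dividing P - 1, each of which also divides the order of G.

```python
from math import prod


def brute_force_log(g, h, p, order):
    """Return x in [0, order) with g**x == h (mod p).

    Hypothetical stand-in for Pollard's rho; only practical for small orders.
    """
    value = 1
    for x in range(order):
        if value == h:
            return x
        value = value * g % p
    raise ValueError("no discrete logarithm in this subgroup")


def toy_crt(residues, moduli):
    """Chinese Remainder Theorem for pairwise-coprime moduli (needs Python 3.8+ pow)."""
    modulus = prod(moduli)
    x = 0
    for r, m in zip(residues, moduli):
        partial = modulus // m
        x += r * partial * pow(partial, -1, m)
    return x % modulus


def toy_pohlig(G, H, P, factors):
    """Discrete log of H to base G modulo the prime P.

    Assumes `factors` are distinct primes, each dividing the order of G
    (and hence P - 1). The result is x modulo prod(factors), not modulo P.
    """
    residues = []
    for f in factors:
        e = (P - 1) // f
        # Project G and H into the subgroup of order f.
        gi, hi = pow(G, e, P), pow(H, e, P)
        residues.append(brute_force_log(gi, hi, P, f))
    return toy_crt(residues, factors)
```

With the toy prime P = 211 = 2*3*5*7 + 1 and G = 4 (which has order 105 modulo 211), `toy_pohlig(4, pow(4, 100, 211), 211, [3, 5, 7])` returns 100, while an exponent of 150 comes back as 45 = 150 mod 105: the exponent is only recovered modulo the product of the factors.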
Here is a summary of what I am doing:
1. Choose a prime p = 2 * prod{ p_i } + 1, where p_i denotes a set of primes.
2. Choose a prime q = 2 * prod{ q_j } + 1, where q_j denotes a set of primes.
3. Inject n = pq as the backdoor modulus in some implementation of Diffie-Hellman.
4. Wait for a victim (e.g. Alice computes A = g^a (mod n), and Bob computes B = g^b (mod n)).
5. Solve for Alice's or Bob's secret exponent, a or b, and compute their shared secret key, K = A^b = B^a (mod n).
Step #5 is done by performing the Pohlig-Hellman algorithm twice to solve for x (mod p) and x (mod q), and then the Chinese Remainder Theorem is used to solve for x (mod n).
EDIT
The x that I am referring to in the description of step #5 is either Alice's secret exponent, a, or Bob's secret exponent, b, depending on which we choose to solve for, since only one is needed to compute the shared secret key, K.
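Step 5 can be sketched end to end with toy numbers. This is our own illustration under stated assumptions, not the project's code: the primes 211 = 2*3*5*7 + 1 and 443 = 2*13*17 + 1, the base g = 4 and the secret a are all hypothetical, and a brute-force search stands in for `pohlig`. One relevant fact for the CRT step is that a discrete logarithm modulo p is only determined modulo the order of g in (Z/pZ)*, which divides p - 1. The sketch therefore combines the two partial results once with those orders as moduli and, for comparison, once with p and q themselves.

```python
from math import gcd


def dlog_mod_prime(g, h, p, order):
    """Brute-force x in [0, order) with g**x == h (mod p); toy stand-in for pohlig()."""
    value = 1
    for x in range(order):
        if value == h:
            return x
        value = value * g % p
    raise ValueError("no discrete logarithm found")


def crt_pair(r1, m1, r2, m2):
    """Solve x = r1 (mod m1), x = r2 (mod m2) for coprime m1, m2; result in [0, m1*m2)."""
    if gcd(m1, m2) != 1:
        raise ValueError("moduli must be coprime")
    t = (r2 - r1) * pow(m1, -1, m2) % m2
    return r1 + m1 * t


# Hypothetical backdoored modulus n = p*q with smooth p - 1 and q - 1.
p, q = 211, 443            # p - 1 = 2*3*5*7, q - 1 = 2*13*17
n = p * q
g = 4                      # a square mod p and q, so its order divides (p-1)/2 and (q-1)/2
ord_p, ord_q = (p - 1) // 2, (q - 1) // 2   # 105 and 221 (the orders of 4 here)

a = 20000                  # hypothetical secret exponent, larger than both p and q
A = pow(g, a, n)

px = dlog_mod_prime(g % p, A % p, p, ord_p)  # a mod 105
qx = dlog_mod_prime(g % q, A % q, q, ord_q)  # a mod 221

x_primes = crt_pair(px, p, qx, q)            # moduli p and q
x_orders = crt_pair(px, ord_p, qx, ord_q)    # moduli ord_p and ord_q

print(x_primes, pow(g, x_primes, n) == A)
print(x_orders, pow(g, x_orders, n) == A)
```

In this toy run x_primes does not reproduce A, while x_orders equals a (because a < 105*221). In general the combined exponent is only determined modulo lcm(ord_p, ord_q), which is enough to compute K; if the two orders shared a factor (for example 2), a CRT variant for non-coprime moduli would be needed.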

Related

Concatenation of binary representation of first n positive integers in O(logn) time complexity

I came across this question in a coding competition. Given a number n, concatenate the binary representations of the first n positive integers and return the decimal value of the resulting number. Since the answer can be large, return it modulo 10^9+7.
n can be as large as 10^9.
E.g. n = 4: the number formed is 11011100 (1 = 1, 10 = 2, 11 = 3, 100 = 4), and the decimal value of 11011100 is 220.
I found a Stack Overflow answer to this question, but the problem is that it only contains an O(n) solution.
Link: concatenate binary of first N integers and return decimal value
Since n can be up to 10^9, we need to come up with a solution that is better than O(n).
Here's some Python code that provides a fast solution; it uses the same ideas as in Abhinav Mathur's post. It requires Python >= 3.8, but it doesn't use anything particularly fancy from Python, and could easily be translated into another language. You'd need to write algorithms for modular exponentiation and modular inverse if they're not already available in the target language.
First, for testing purposes, let's define the slow and obvious version:
# Modulus that results are reduced by.
M = 10 ** 9 + 7

def slow_binary_concat(n):
    """
    Concatenate binary representations of 1 through n (inclusive).
    Reinterpret the resulting binary string as an integer.
    """
    # k = 0 only contributes a leading "0" bit, which does not change the value.
    concatenation = "".join(format(k, "b") for k in range(n + 1))
    return int(concatenation, 2) % M
Checking that we get the expected result:
>>> slow_binary_concat(4)
220
>>> slow_binary_concat(10)
462911642
Now we'll write a faster version. First, we split the range [1, n) into subintervals such that within each subinterval, all numbers have the same length in binary. For example, the range [1, 10) would be split into four subintervals: [1, 2), [2, 4), [4, 8) and [8, 10). Here's a function to do that splitting:
def split_by_bit_length(n):
    """
    Split the numbers in [1, n) by bit-length.
    Produces triples (a, b, 2**k). Each triple represents a subinterval
    [a, b) of [1, n), with a < b, all of whose elements have bit-length k.
    """
    a = 1
    while n > a:
        b = 2 * a
        yield (a, min(n, b), b)
        a = b
Example output:
>>> list(split_by_bit_length(10))
[(1, 2, 2), (2, 4, 4), (4, 8, 8), (8, 10, 16)]
Now for each subinterval, the value of the concatenation of all numbers in that subinterval is represented by a fairly simple mathematical sum, which can be computed in exact form. Here's a function to compute that sum modulo M:
def subinterval_concat(a, b, l):
    """
    Concatenation of values in [a, b), all of which have the same bit-length k.
    l is 2**k.
    Equivalently, sum(i * l**(b - 1 - i) for i in range(a, b)) modulo M.
    """
    n = b - a
    inv = pow(l - 1, -1, M)
    q = (pow(l, n, M) - 1) * inv
    return (a * q + (q - n) * inv) % M
I won't go into the evaluation of the sum here: it's a bit off-topic for this site, and it's hard to express without a good way to render formulas. If you want the details, that's a topic for https://math.stackexchange.com, or a page of fairly simple algebra.
Finally, we want to put all the intervals together. Here's a function to do that.
def fast_binary_concat(n):
    """
    Fast version of slow_binary_concat.
    """
    acc = 0
    for a, b, l in split_by_bit_length(n + 1):
        acc = (acc * pow(l, b - a, M) + subinterval_concat(a, b, l)) % M
    return acc
A comparison with the slow version shows that we get the same results:
>>> fast_binary_concat(4)
220
>>> fast_binary_concat(10)
462911642
But the fast version can easily be evaluated for much larger inputs, where using the slow version would be infeasible:
>>> fast_binary_concat(10**9)
827129560
>>> fast_binary_concat(10**18)
945204784
You just have to note a simple pattern. Taking up your example for n=4, let's gradually build the solution starting from n=1.
1 -> 1 #1
2 -> 2^2(1) + 2 #6
3 -> 2^2[2^2(1)+2] + 3 #27
4 -> 2^3{2^2[2^2(1)+2]+3} + 4 #220
If you expand the coefficients of each term for n=4, you'll get the coefficients as:
1 -> (2^3)*(2^2)*(2^2)
2 -> (2^3)*(2^2)
3 -> (2^3)
4 -> (2^0)
Let N be the total number of bits in the binary representation of the required number, and D(x) be the number of bits in x. The coefficients can then be written as
1 -> 2^(N-D(1))
2 -> 2^(N-D(1)-D(2))
3 -> 2^(N-D(1)-D(2)-D(3))
... and so on
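The coefficient formula can be checked directly for small n. This sketch is our own and uses exact Python integers with no modulus, so it is only meant for verification; `int.bit_length()` plays the role of D(x).

```python
def concat_via_coefficients(n):
    """Sum of k * 2**(N - D(1) - ... - D(k)) for k = 1..n, with D(k) = bit length of k."""
    lengths = [k.bit_length() for k in range(1, n + 1)]
    N = sum(lengths)              # total number of bits in the concatenation
    total, used = 0, 0
    for k, d in zip(range(1, n + 1), lengths):
        used += d                 # D(1) + ... + D(k)
        total += k * 2 ** (N - used)
    return total


def concat_via_string(n):
    """Reference value: build the binary string and parse it."""
    return int("".join(format(k, "b") for k in range(1, n + 1)), 2)


print(concat_via_coefficients(4))  # 220
```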
Since D(x) is the same for all x in the range [2^t, 2^(t+1) - 1] for a given t, you can break the problem into such ranges and solve each range using mathematics (not iteration). Since there are only floor(log2(n)) + 1 such ranges for an input n, this should run within the time limit.
As an example, the various ranges become:
1. 1 (D(x) = 1)
2. 2-3 (D(x) = 2)
3. 4-7 (D(x) = 3)
4. 8-15 (D(x) = 4)
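One way to handle each range without iterating over its elements is matrix exponentiation. This is our own substitution (the answer does not say which closed form it had in mind, and the geometric-sum formula in the previous answer is another option). Within a range where every number has d bits, one step maps (acc, k, 1) to (acc * 2^d + k, k + 1, 1), which is a fixed 3x3 matrix, so a whole range of `count` numbers is that matrix raised to the power `count`. The function names are hypothetical.

```python
M = 10 ** 9 + 7


def mat_mul(x, y):
    """Multiply two 3x3 matrices modulo M."""
    return [[sum(x[i][s] * y[s][j] for s in range(3)) % M for j in range(3)]
            for i in range(3)]


def mat_pow(mat, power):
    """Raise a 3x3 matrix to a non-negative integer power modulo M."""
    result = [[int(i == j) for j in range(3)] for i in range(3)]
    while power:
        if power & 1:
            result = mat_mul(result, mat)
        mat = mat_mul(mat, mat)
        power >>= 1
    return result


def concat_by_ranges(n):
    """Binary concatenation of 1..n modulo M, one matrix power per bit-length range."""
    acc, k = 0, 1       # acc = concatenation so far, k = next integer to append
    d = 1               # bit length D(k) shared by every integer in the current range
    while k <= n:
        hi = min(n, 2 ** d - 1)          # last integer of [2^(d-1), 2^d - 1] not above n
        count = hi - k + 1
        step = [[pow(2, d, M), 1, 0],    # acc' = acc * 2^d + k
                [0, 1, 1],               # k'   = k + 1
                [0, 0, 1]]
        t = mat_pow(step, count)
        acc = (t[0][0] * acc + t[0][1] * k + t[0][2]) % M
        k = hi + 1
        d += 1
    return acc


print(concat_by_ranges(4))  # 220
```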

Is O(n^(1/logn)) actually constant?

I came across this time complexity function and according to me, it is actually constant. Please correct me if I am wrong.
n^(1/logn) => (2^m)^(1/log(2^m)) => (2^m)^(1/m) => 2
Since any n can be written as a power of 2, I can do the above simplification and prove that it is constant, right?
Assuming log is the natural log, this equals e, not 2, but either way it's a constant.
First, let:
k = n^(1 / log n)
Then take the log of both sides:
log k = (1 / log n) * log n
So:
log k = 1
Now raise e to the power of each side to get:
e^(log k) = e^(1)
And thus:
k = e.
Here's an alternative proof:
1 / (log n) = (log e) / (log n) = log_n e by the change-of-base identity.
Then n^(1 / log n) = n^(log_n e) = e by the definition of the logarithm as the inverse of exponentiation.
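A quick floating-point sanity check of both results (our own, not from the answers): for any n > 1, n^(1/ln n) should come out as e and n^(1/log2 n) as 2, up to rounding.

```python
import math


def n_to_one_over_log(n, log=math.log):
    """Evaluate n ** (1 / log(n)) for a given logarithm function; requires n > 1."""
    return n ** (1 / log(n))


for n in (10.0, 1e6, 1e100):
    # Natural log gives e; log base 2 (as in the question) gives 2.
    print(n, n_to_one_over_log(n), n_to_one_over_log(n, math.log2))
```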

RSA private key calculate [MADLIB]

How can I calculate the private key (d)?
The problem is that d (the private key) has to be an integer, but I keep getting d = 0.0000152585... Help please?
p=92092076805892533739724722602668675840671093008520241548191914215399824020372076186460768206814914423802230398410980218741906960527104568970225804374404612617736579286959865287226538692911376507934256844456333236362669879347073756238894784951597211105734179388300051579994253565459304743059533646753003894559
q=97846775312392801037224396977012615848433199640105786119757047098757998273009741128821931277074555731813289423891389911801250326299324018557072727051765547115514791337578758859803890173153277252326496062476389498019821358465433398338364421624871010292162533041884897182597065662521825095949253625730631876637
e=65537
n=9010912747277787249738727439840427055736519196538871349093408340706668231808840540195374015916168031416186859836416053338250477003776576736854137538279810042409758765948034443613881324504120707334213544491046703922409406729564516371394804946909037646047891880347940067132730874804943893719672960932378043325067514786209219718314429979032869544980643978919561908707109629612202311323626173343456843249212057093980583352634168733656443959925428846968193413110401346035535595817965624054783296380268863401241570313602685481219583686719199499297832165308522137209299081956650614940546284136240753995440003473611843518083
ϕ(n)=9010912747277787249738727439840427055736519196538871349093408340706668231808840540195374015916168031416186859836416053338250477003776576736854137538279810042409758765948034443613881324504120707334213544491046703922409406729564516371394804946909037646047891880347940067132730874804943893719672960932378043324975422709403327184574705256430200869139972885911041667158917715396802487303254097156996075042397142670178352954223188514914536999398324277997967608735996733417799016531005758767556757687357486893307313469146352244856913807372125743058937380356924926103564902568350563360552030570781449252380469826858839623524
So with the formula:
d = e^(-1) mod ϕ(n)
I keep getting 0.0000152585. Any ideas?
You need to use the modular multiplicative inverse of e modulo ϕ(n), not the ordinary reciprocal 1/e (1/65537 ≈ 0.0000152585 is exactly the number you are getting).
Here is an example in Python using the cryptography module.
from cryptography.hazmat.primitives.asymmetric.rsa import _modinv
p=92092076805892533739724722602668675840671093008520241548191914215399824020372076186460768206814914423802230398410980218741906960527104568970225804374404612617736579286959865287226538692911376507934256844456333236362669879347073756238894784951597211105734179388300051579994253565459304743059533646753003894559
q=97846775312392801037224396977012615848433199640105786119757047098757998273009741128821931277074555731813289423891389911801250326299324018557072727051765547115514791337578758859803890173153277252326496062476389498019821358465433398338364421624871010292162533041884897182597065662521825095949253625730631876637
e=65537
phi = (p-1) * (q-1)
d = _modinv(e, phi)
print(d) # 1405046269503207469140791548403639533127416416214210694972085079171787580463776820425965898174272870486015739516125786182821637006600742140682552321645503743280670839819078749092730110549881891271317396450158021688253989767145578723458252769465545504142139663476747479225923933192421405464414574786272963741656223941750084051228611576708609346787101088759062724389874160693008783334605903142528824559223515203978707969795087506678894006628296743079886244349469131831225757926844843554897638786146036869572653204735650843186722732736888918789379054050122205253165705085538743651258400390580971043144644984654914856729
print((e * d) % phi) # 1
You can find the implementation of _modinv() here.
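`_modinv` is a private (underscore-prefixed) helper of the cryptography package, not part of its public API, so it is not guaranteed to stay available. Since Python 3.8, the built-in `pow` accepts a negative exponent together with a modulus and returns the modular inverse directly. A minimal sketch with the small textbook values p = 61, q = 53, e = 17 (not the question's numbers):

```python
p, q, e = 61, 53, 17
phi = (p - 1) * (q - 1)        # 3120
d = pow(e, -1, phi)            # modular inverse; raises ValueError if gcd(e, phi) != 1
print(d)                       # 2753
print(e * d % phi)             # 1
```

The same call works unchanged on the question's p, q and e, since Python integers have arbitrary precision.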
By the FIPS 186-4 standard, you have to use λ, not φ, in RSA:
λ(n) = lcm(p − 1, q − 1)
To calculate d = e^(-1) mod λ(n), you can use the extended Euclidean algorithm.
Find x and y that satisfy Bézout's identity (the gcd is 1 because e must be coprime to λ(n)):
e x + λ(n) y = gcd(e, λ(n)) = 1
e x − 1 = (−y) λ(n)
Taking this mod λ(n):
e x ≡ 1 mod λ(n), so d = x mod λ(n).
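The steps above can be sketched in Python with the same small textbook values (p = 61, q = 53, e = 17). This is our own illustration: `extended_gcd` and `private_exponent` are hypothetical names, and the sketch ignores any further size checks the standard places on d.

```python
from math import gcd


def extended_gcd(a, b):
    """Return (g, x, y) such that a*x + b*y == g == gcd(a, b)."""
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_x, x = x, old_x - quotient * x
        old_y, y = y, old_y - quotient * y
    return old_r, old_x, old_y


def private_exponent(e, p, q):
    """d = e^(-1) mod lambda(n), with lambda(n) = lcm(p - 1, q - 1)."""
    lam = (p - 1) * (q - 1) // gcd(p - 1, q - 1)
    g, x, _ = extended_gcd(e, lam)
    if g != 1:
        raise ValueError("e must be coprime to lambda(n)")
    return x % lam               # e*x + lam*y == 1, so e*x == 1 (mod lam)


d = private_exponent(17, 61, 53)
print(d)                         # 413
```

With λ(n) = lcm(60, 52) = 780, d = 413 is smaller than the φ-based value 2753 but decrypts just as well.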

How to set a square root to only be whole

I can't seem to find any kind of answer to this: if I have an expression like the square root of (x^2 - 4n), where 4n is a constant, how can I choose x so that the expression gives a whole number?
I know setting x to n+1 works, but I'm looking for an algorithm that would generate all solutions.
So, the problem is to find all pairs of integers (x, m) such that:
sqrt(x^2 - 4n) = m
We have:
x^2 - 4n = m^2
or
x^2 - m^2 = 4n
so
(x + m)(x - m) = 4n
Now, 4n is even, so at least one of (x + m) and (x - m) is even. Since the two factors differ by 2m, they have the same parity, so both are even. Thus a := (x+m)/2 and b := (x-m)/2 are both integers. Therefore
a*b = n
So, it is just a matter of factoring n as a*b in all possible ways and recover x and m from the equations above:
x = a + b.
m = a - b.
Your solution x = n+1 corresponds to the trivial factorization n = n*1 where a=n and b=1.
UPDATE
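Here is an algorithm that prints all pairs (x, m) with x > 0 and m ≥ 0:
1. [Initialize] a := n.
2. [Check] If n % a = 0, then
   b := n / a.
   print(a + b), print(a - b)
3. [Decrement] a := a - 1.
4. [End?] If a * a >= n, go to Step 2.
The >= in the last step makes sure a = b = sqrt(n) is also checked when n is a perfect square.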
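The same search can start from the smaller factor instead and stop at sqrt(n), which takes about sqrt(n) steps rather than about n. A Python sketch (our own; `math.isqrt` needs Python 3.8+, and `square_root_solutions` is a hypothetical name):

```python
from math import isqrt


def square_root_solutions(n):
    """All (x, m) with x > 0, m >= 0 and x*x - 4*n == m*m, for a positive integer n.

    Each factorization n = a*b with a >= b gives x = a + b and m = a - b.
    """
    pairs = []
    for b in range(1, isqrt(n) + 1):   # b is the smaller factor, so b <= sqrt(n)
        if n % b == 0:
            a = n // b
            pairs.append((a + b, a - b))
    return pairs


print(square_root_solutions(6))   # [(7, 5), (5, 1)]
```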

How to implement c=m^e mod n for enormous numbers?

I'm trying to figure out how to implement RSA crypto from scratch (just for the intellectual exercise), and I'm stuck on this point:
For encryption, c = m^e mod n.
Now, e is normally 65537. m and n are 1024-bit integers (e.g. 128-byte arrays). This is obviously too big for standard methods. How would you implement this?
I've been reading a bit about exponentiation here but it just isn't clicking for me:
Wikipedia-Exponentiation by squaring
This Chapter (see section 14.85)
Thanks.
Edit: I also found this. Is this more what I should be looking at? Wikipedia - Modular Exponentiation
Exponentiation by squaring:
Let's take an example. You want to find 17^23. Note that 23 is 10111 in binary. Let's try to build it up from left to right.
// a exponent in binary
a = 17 //17^1 1
a = a * a //17^2 10
a = a * a //17^4 100
a = a * 17 //17^5 101
a = a * a //17^10 1010
a = a * 17 //17^11 1011
a = a * a //17^22 10110
a = a * 17 //17^23 10111
When you square, you double the exponent (shift left by 1 bit). When you multiply by the base (17 here, m in general), you add 1 to the exponent.
If you want to reduce modulo n, you can do it after each multiplication (rather than leaving it to the end, which would make the numbers get very large).
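65537 is 10000000000000001 in binary (that is, 2^16 + 1), which makes all of this pretty easy. It's basically:
a = m
for _ in range(16):   # repeat 16 times: a becomes m^(2^16) mod n
    a = a * a
    a = a % n
a = a * m             # one more multiply: m^(2^16 + 1) = m^65537
a = a % n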
where of course a, n and m are "big integers". a needs to be at least 2048 bits, as it can get as large as (n-1)^2.
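The left-to-right procedure from the 17^23 example, with a reduction after every multiplication, can be sketched in Python as follows. The function name is ours; Python's built-in three-argument `pow` does the same job and is used only to check the result.

```python
def pow_left_to_right(m, e, n):
    """m**e mod n, scanning the bits of e from the most significant end."""
    a = 1 % n
    for bit in bin(e)[2:]:       # binary digits of e, most significant first
        a = a * a % n            # squaring doubles the exponent built so far
        if bit == "1":
            a = a * m % n        # multiplying by m adds 1 to the exponent
    return a


print(pow_left_to_right(17, 23, 1000) == pow(17, 23, 1000))   # True
```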
For an efficient algorithm you need to combine the exponentiation by squaring with repeated application of mod after each step.
For odd e this holds:
m^e mod n = m ⋅ m^(e-1) mod n
For even e:
m^e mod n = (m^(e/2) mod n)^2 mod n
With m^1 = m as a base case, this defines a recursive way to do efficient modular exponentiation.
But even with an algorithm like this, because m and n will be very large, you will still need to use a type/library that can handle integers of such sizes.
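The two recurrences translate almost directly into a recursive Python function. This is a sketch of our own; note that the recursion depth grows with about twice the bit length of e, so for very large exponents in Python an iterative loop is the safer choice.

```python
def pow_recursive(m, e, n):
    """m**e mod n using the odd/even recurrences; e must be >= 1."""
    if e == 1:
        return m % n                                  # base case m^1 = m
    if e % 2 == 1:
        return m * pow_recursive(m, e - 1, n) % n     # odd e: m * m^(e-1)
    half = pow_recursive(m, e // 2, n)                # even e: (m^(e/2))^2
    return half * half % n
```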
def modexp(m, e, n):
    result = 1
    while e > 0:
        if (e & 1) != 0:
            result = result * m
            result = result % n
        m = m * m
        m = m % n
        e = e >> 1
    return result
This checks bits in the exponent starting with the least significant bit. Each time we move up a bit it corresponds to doubling the power of m - hence we shift e and square m. The result only gets the power of m multiplied in if the exponent has a 1 bit in that position. All multiplications need to be reduced mod n.
As an example, consider m^13. 13 = 1101 in binary, so this is the same as m^8 * m^4 * m. Notice the powers 8, 4, (not 2), 1, which correspond to the bits 1101. And then recall that m^8 = (m^4)^2 and m^4 = (m^2)^2.
If g(x) = x mod 2^k is faster to calculate for your bignum library than f(x) = x mod N for N not divisible by 2, then consider using Montgomery multiplication. When used with modular exponentiation, it avoids having to calculate modulo N at each step, you just need to do the "Montgomeryization" / "un-Montgomeryization" at the beginning and end.
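A compact Python sketch of Montgomery exponentiation, under our own choices: R = 2^k with k the bit length of the odd modulus n, and a right-to-left square-and-multiply loop. Inside the loop every reduction is a mask (mod R), a shift (division by R) and at most one subtraction; an ordinary `% n` appears only when converting into Montgomery form. Python's arbitrary-precision integers stand in for a bignum library here, so this shows the structure rather than any speedup.

```python
def montgomery_pow(m, e, n):
    """m**e mod n via Montgomery multiplication; n must be odd."""
    if n % 2 == 0:
        raise ValueError("Montgomery reduction needs an odd modulus")
    k = n.bit_length()
    R = 1 << k                     # R = 2**k > n, and gcd(R, n) == 1 since n is odd
    mask = R - 1                   # x & mask == x mod R
    n_prime = -pow(n, -1, R) % R   # n * n_prime == -1 (mod R)

    def redc(t):
        """Return t * R^-1 mod n for 0 <= t < n * R, using only mod/div by R."""
        u = (t & mask) * n_prime & mask
        t = (t + u * n) >> k       # t + u*n is divisible by R
        return t - n if t >= n else t

    a = (m % n) * R % n            # "Montgomeryize" the base
    result = R % n                 # Montgomery form of 1
    while e:
        if e & 1:
            result = redc(result * a)
        a = redc(a * a)
        e >>= 1
    return redc(result)            # "un-Montgomeryize" the result
```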
