Understanding this way of computing (x^e mod n)


I'm working on RSA encryption in C++ and found that the pow() function in cmath was giving me incorrect results.

After looking online I came across some code that performs the above computation for me, but I am having difficulty understanding it.

This is the code:

long long int modpow(long long int base, long long int exp, long long int modulus) {
    base %= modulus;
    long long int result = 1;
    while (exp > 0) {
        if (exp & 1) {
            result = (result * base) % modulus;
        }
        base = (base * base) % modulus;
        exp >>= 1;
    }
    return result;
}

(this code isn't the original)

I am struggling with understanding this function.

I know that exp >>= 1; is a right shift by one bit and that (exp & 1) returns 1 or 0 based on the least significant bit, but what I don't understand is how that contributes to the final answer.

For example:

if (exp & 1) {
    result = (result * base) % modulus;
}

What is the purpose of (result * base) % modulus when exp is odd?

I'm hoping someone can explain this function to me, as I don't want to just copy it over without understanding it.


The code was written to be "clever" instead of clear. This cryptic style usually should not be done outside core libraries (where performance is crucial), and even when it is used, it would be nice to have comments explaining what is going on. Here is an annotated version of the code.

long long int modpow(long long int base, long long int exp, long long int modulus)
{
  base %= modulus;                          // Reduce the base modulo the modulus; keeps intermediate results smaller
  long long int result = 1;                 // Start with the multiplicative unit (a.k.a. one)
                                            // Throughout, we are calculating:
                                            // `(result*base^exp)%modulus`
  while (exp > 0) {                         // While the exponent has not been exhausted.
    if (exp & 1) {                          // If the exponent is odd
      result = (result * base) % modulus;   // Consume one application of the base, logically
                                            // (but not actually) reducing the exponent by one.
                                            // That is: result * base^exp == (result*base)*base^(exp-1)
    }
    base = (base * base) % modulus;         // The exponent is logically even. Apply B^(2n) == (B^2)^n.
    exp >>= 1;                              // Halve the exponent; the odd bit, if any, was consumed above.
  }
  return result;
}

Does that make more sense now?


For those wondering how this cryptic code could be made clearer, I present the following version. There are three main improvements.

First, bitwise operations have been replaced by the equivalent arithmetic operations. If one were to prove that the algorithm works, one would use arithmetic, not bitwise operators. In fact, the algorithm works regardless of how the numbers are represented – there need not be a concept of "bits", much less of "bitwise operators". Thus the natural way to implement this algorithm is with arithmetic. Using bitwise operators sacrifices clarity for almost no benefit. Compilers are smart enough to produce identical machine code, with one exception: since exp was declared long long int instead of long long unsigned, there is an extra step when calculating exp /= 2 compared to exp >>= 1. (I do not know why exp is signed; the function is both conceptually meaningless and technically incorrect for negative exponents.)

Second, I created a helper function for improved readability. While the improvement is minor, it comes at no performance cost. I'd expect the function to be inlined by any compiler worth its salt.

// Wrapper for detecting if an integer is odd.
bool is_odd(long long int n)
{
    return n % 2 != 0; 
}

Third, comments have been added to explain what is going on. While some people (not I) might think that "the standard right-to-left modular binary exponentiation algorithm" is required knowledge for every C++ coder, I prefer to make fewer assumptions about the people who might read my code in the future. Especially if that person is me, coming back to the code after years away from it.

And now, the code, as I would prefer to see the current functionality written:

// Returns `(base**exp) % modulus`, where `**` denotes exponentiation.
// Assumes `exp` is non-negative.
// Assumes `modulus` is non-zero.
// If `exp` is zero, assumes `modulus` is neither 1 nor -1.
long long int modpow(long long int base, long long int exp, long long int modulus)
{
    // NOTE: This algorithm is known as the "right-to-left binary method" of
    // "modular exponentiation".

    // Throughout, we'll keep numbers smallish by using `(A*B) % C == ((A%C)*B) % C`.
    // The first application of this principle is to the base.
    base %= modulus;
    // Intermediate results will be stored modulo `modulus`.
    long long int result = 1;

    // Loop invariant:
    //     The value to return is `(result * base**exp) % modulus`.
    // Loop goal:
    //     Reduce `exp` to the point where `base**exp` is 1.
    while (exp > 0) {
        if ( is_odd(exp) ) {
            // Shift one factor of `base` to `result`:
            // `result * base^exp == (result*base) * base^(exp-1)`
            result = (result * base) % modulus;
            //--exp;  // logically happens, but optimized out.
            // We are now in the "`exp` is even" case.
        }
        // Reduce the exponent by increasing the base: `B**(2n) == (B**2)**n`.
        base = (base * base) % modulus;
        exp /= 2;
    }

    return result;
}

The resulting machine code is almost identical. If performance really is critical, I could see going back to exp >>= 1, but only if changing the type of exp is not allowed.