What is the algorithm behind math.gcd, and why is it faster than the Euclidean algorithm?

112 Views Asked by At

Tests show that Python's math.gcd is an order of magnitude faster than a naive implementation of the Euclidean algorithm:

import math
from timeit import default_timer as timer

def gcd(a, b):
        """Return the greatest common divisor of a and b (Euclid's algorithm).

        The result is always non-negative, like math.gcd; gcd(0, 0) is 0.
        """
        while b != 0:
                # Invariant: gcd(a, b) == gcd(b, a % b).
                a, b = b, a % b
        # Python's % takes the sign of the divisor, so with negative inputs
        # the last non-zero remainder may be negative; normalise the sign.
        return abs(a)

def main():
        """Time the pure-Python gcd against math.gcd on one pair of integers.

        For each implementation, print the result and then the elapsed time
        in seconds as measured by timeit.default_timer.
        """
        a = 28871271685163
        b = 17461204521323

        # Stop the clock before printing: print() performs console I/O, which
        # can cost more than the gcd call itself and would distort the
        # comparison between the two implementations.
        start = timer()
        result = gcd(a, b)
        end = timer()
        print(result)
        print(end - start)

        start = timer()
        result = math.gcd(a, b)
        end = timer()
        print(result)
        print(end - start)


if __name__ == "__main__":
        main()

gives

$ python3 test.py
1
4.816000000573695e-05
1
8.346003596670926e-06

e-05 vs e-06.

I guess there are some optimizations, or is some other algorithm used?

1

There is 1 best solution below

3
On BEST ANSWER

math.gcd() is certainly a Python shim over a library function that is running as machine code (i.e. compiled from "C" code), not a function being run by the Python interpreter. See also: Where are math.py and sys.py?

This should be it (for CPython):

math_gcd(PyObject *module, PyObject * const *args, Py_ssize_t nargs)

in mathmodule.c

and it calls

_PyLong_GCD(PyObject *aarg, PyObject *barg)

in longobject.c

which apparently uses Lehmer's GCD algorithm

The code is smothered in housekeeping operations and the handling of special cases, though, which increases the complexity considerably. Still, it is quite clean.

/* Compute the greatest common divisor of two Python int objects.

   aarg, barg: cast to PyLongObject * below without any type check, so the
   caller is responsible for passing int objects.

   Returns the object holding the gcd, or NULL when one of the helper calls
   used along the way reports failure.

   Strategy, as implemented below:
     - if both operands have at most two digits, jump straight to the
       machine-integer Euclidean loop at `simple`;
     - otherwise repeatedly shrink (a, b): run Lehmer-style steps on the
       leading 2*PyLong_SHIFT bits only, then apply the accumulated
       coefficients A, B, C, D to the full digit arrays; when no such step
       is possible, do one full-precision modulo instead;
     - once a has at most two digits, finish in the Euclidean loop. */
PyObject *
_PyLong_GCD(PyObject *aarg, PyObject *barg)
{
    PyLongObject *a, *b, *c = NULL, *d = NULL, *r;
    stwodigits x, y, q, s, t, c_carry, d_carry;
    stwodigits A, B, C, D, T;
    int nbits, k;
    digit *a_digit, *b_digit, *c_digit, *d_digit, *a_end, *b_end;

    a = (PyLongObject *)aarg;
    b = (PyLongObject *)barg;
    /* Fast path: both operands have at most two digits, so the multi-digit
       reduction is skipped.  The INCREFs balance the DECREFs that `simple`
       performs on a and b. */
    if (_PyLong_DigitCount(a) <= 2 && _PyLong_DigitCount(b) <= 2) {
        Py_INCREF(a);
        Py_INCREF(b);
        goto simple;
    }

    /* Initial reduction: make sure that 0 <= b <= a. */
    a = (PyLongObject *)long_abs(a);
    if (a == NULL)
        return NULL;
    b = (PyLongObject *)long_abs(b);
    if (b == NULL) {
        Py_DECREF(a);
        return NULL;
    }
    /* Swap (using r as the temporary) so that a is the larger operand. */
    if (long_compare(a, b) < 0) {
        r = a;
        a = b;
        b = r;
    }
    /* We now own references to a and b */

    /* size_a / size_b hold the current digit counts.  alloc_a / alloc_b
       remember the digit count each object had when it was obtained --
       presumably its digit capacity, which can exceed the current count
       once a buffer is reused in place (NOTE(review): confirm against
       _PyLong_New). */
    Py_ssize_t size_a, size_b, alloc_a, alloc_b;
    alloc_a = _PyLong_DigitCount(a);
    alloc_b = _PyLong_DigitCount(b);
    /* reduce until a fits into 2 digits */
    while ((size_a = _PyLong_DigitCount(a)) > 2) {
        /* nbits: bit length of a's most significant digit; it determines
           the shifts used to build x and y below. */
        nbits = bit_length_digit(a->long_value.ob_digit[size_a-1]);
        /* extract top 2*PyLong_SHIFT bits of a into x, along with
           corresponding bits of b into y */
        size_b = _PyLong_DigitCount(b);
        assert(size_b <= size_a);
        /* b == 0: the gcd is a.  When a currently has fewer digits than
           alloc_a, a copy is returned instead of a itself -- presumably so
           the caller does not receive an over-allocated, reused buffer
           (TODO confirm). */
        if (size_b == 0) {
            if (size_a < alloc_a) {
                r = (PyLongObject *)_PyLong_Copy(a);
                Py_DECREF(a);
            }
            else
                r = a;
            Py_DECREF(b);
            Py_XDECREF(c);
            Py_XDECREF(d);
            return (PyObject *)r;
        }
        /* size_a > 2 here, so indices size_a-1 .. size_a-3 are valid. */
        x = (((twodigits)a->long_value.ob_digit[size_a-1] << (2*PyLong_SHIFT-nbits)) |
             ((twodigits)a->long_value.ob_digit[size_a-2] << (PyLong_SHIFT-nbits)) |
             (a->long_value.ob_digit[size_a-3] >> nbits));

        /* Same digit positions and shifts as for x; b may be shorter than
           a, in which case the missing high digits contribute 0. */
        y = ((size_b >= size_a - 2 ? b->long_value.ob_digit[size_a-3] >> nbits : 0) |
             (size_b >= size_a - 1 ? (twodigits)b->long_value.ob_digit[size_a-2] << (PyLong_SHIFT-nbits) : 0) |
             (size_b >= size_a ? (twodigits)b->long_value.ob_digit[size_a-1] << (2*PyLong_SHIFT-nbits) : 0));

        /* inner loop of Lehmer's algorithm; A, B, C, D never grow
           larger than PyLong_MASK during the algorithm. */
        /* k counts the completed steps.  The loop stops when y-C is zero
           (the division below would be by zero) or when s > t --
           presumably the point where the quotient derived from the leading
           bits can no longer be trusted. */
        A = 1; B = 0; C = 0; D = 1;
        for (k=0;; k++) {
            if (y-C == 0)
                break;
            q = (x+(A-1))/(y-C);
            s = B+q*D;
            t = x-q*y;
            if (s > t)
                break;
            x = y; y = t;
            t = A+q*C; A = D; B = C; C = s; D = t;
        }

        if (k == 0) {
            /* no progress; do a Euclidean step */
            /* l_mod stores a % b in r; then (a, b) becomes (b, r).  The
               alloc_* bookkeeping follows the objects to their new names. */
            if (l_mod(a, b, &r) < 0)
                goto error;
            Py_SETREF(a, b);
            b = r;
            alloc_a = alloc_b;
            alloc_b = _PyLong_DigitCount(b);
            continue;
        }

        /*
          a, b = A*b-B*a, D*a-C*b if k is odd
          a, b = A*a-B*b, D*b-C*a if k is even
        */
        /* For odd k, swap and negate the coefficients so that the digit
           loops below can always use the even-k form
           (new a = A*a - B*b, new b = D*b - C*a). */
        if (k&1) {
            T = -A; A = -B; B = T;
            T = -C; C = -D; D = T;
        }
        /* Choose the destination c for the new a: reuse the c kept from an
           earlier iteration, else overwrite a in place when this function
           holds its only reference, else allocate a fresh object. */
        if (c != NULL) {
            assert(size_a >= 0);
            _PyLong_SetSignAndDigitCount(c, 1, size_a);
        }
        else if (Py_REFCNT(a) == 1) {
            c = (PyLongObject*)Py_NewRef(a);
        }
        else {
            alloc_a = size_a;
            c = _PyLong_New(size_a);
            if (c == NULL)
                goto error;
        }

        /* Same choice for d, the destination of the new b.  Reusing b in
           place additionally requires size_a <= alloc_b, because size_a
           digits are written into d below. */
        if (d != NULL) {
            assert(size_a >= 0);
            _PyLong_SetSignAndDigitCount(d, 1, size_a);
        }
        else if (Py_REFCNT(b) == 1 && size_a <= alloc_b) {
            d = (PyLongObject*)Py_NewRef(b);
            assert(size_a >= 0);
            _PyLong_SetSignAndDigitCount(d, 1, size_a);
        }
        else {
            alloc_b = size_a;
            d = _PyLong_New(size_a);
            if (d == NULL)
                goto error;
        }

        a_end = a->long_value.ob_digit + size_a;
        b_end = b->long_value.ob_digit + size_b;

        /* compute new a and new b in parallel */
        a_digit = a->long_value.ob_digit;
        b_digit = b->long_value.ob_digit;
        c_digit = c->long_value.ob_digit;
        d_digit = d->long_value.ob_digit;
        c_carry = 0;
        d_carry = 0;
        /* Digits present in both a and b.  Each input digit is read before
           the output digit at the same position is written, which is what
           makes the in-place case (c aliasing a, d aliasing b) work.
           NOTE(review): the carries can go negative because of the
           subtractions, so the >> below presumably relies on an arithmetic
           shift for stwodigits. */
        while (b_digit < b_end) {
            c_carry += (A * *a_digit) - (B * *b_digit);
            d_carry += (D * *b_digit++) - (C * *a_digit++);
            *c_digit++ = (digit)(c_carry & PyLong_MASK);
            *d_digit++ = (digit)(d_carry & PyLong_MASK);
            c_carry >>= PyLong_SHIFT;
            d_carry >>= PyLong_SHIFT;
        }
        /* Remaining high digits of a; b has no digits here, so only the
           A*a and -C*a terms are left. */
        while (a_digit < a_end) {
            c_carry += A * *a_digit;
            d_carry -= C * *a_digit++;
            *c_digit++ = (digit)(c_carry & PyLong_MASK);
            *d_digit++ = (digit)(d_carry & PyLong_MASK);
            c_carry >>= PyLong_SHIFT;
            d_carry >>= PyLong_SHIFT;
        }
        assert(c_carry == 0);
        assert(d_carry == 0);

        /* c and d become the new a and b.  The INCREFs let c and d keep
           references of their own, so they are still available as
           destination buffers in the next iteration; the DECREFs release
           the old a and b.  long_normalize presumably strips leading zero
           digits and returns its argument (TODO confirm). */
        Py_INCREF(c);
        Py_INCREF(d);
        Py_DECREF(a);
        Py_DECREF(b);
        a = long_normalize(c);
        b = long_normalize(d);
    }
    /* The reduction is finished; drop the buffer references. */
    Py_XDECREF(c);
    Py_XDECREF(d);

simple:
    assert(Py_REFCNT(a) > 0);
    assert(Py_REFCNT(b) > 0);
/* Issue #24999: use two shifts instead of ">> 2*PyLong_SHIFT" to avoid
   undefined behaviour when LONG_MAX type is smaller than 60 bits */
/* Pick a native type wide enough for two digits: long if it has more than
   2*PyLong_SHIFT value bits, otherwise long long. */
#if LONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT

    /* a fits into a long, so b must too */
    x = PyLong_AsLong((PyObject *)a);
    y = PyLong_AsLong((PyObject *)b);
#elif LLONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT
    x = PyLong_AsLongLong((PyObject *)a);
    y = PyLong_AsLongLong((PyObject *)b);
#else
# error "_PyLong_GCD"
#endif
    /* The fast path jumps here before long_abs() is applied, so the
       operands may still be negative. */
    x = Py_ABS(x);
    y = Py_ABS(y);
    /* Only the native values are needed from here on. */
    Py_DECREF(a);
    Py_DECREF(b);

    /* usual Euclidean algorithm for longs */
    while (y != 0) {
        t = y;
        y = x % y;
        x = t;
    }
    /* Convert the result back using the same width chosen above. */
#if LONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT
    return PyLong_FromLong(x);
#elif LLONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT
    return PyLong_FromLongLong(x);
#else
# error "_PyLong_GCD"
#endif

    /* Reached only from inside the reduction loop, where a and b are owned
       references; c and d may still be NULL at that point. */
error:
    Py_DECREF(a);
    Py_DECREF(b);
    Py_XDECREF(c);
    Py_XDECREF(d);
    return NULL;
}