Signed integer comparison without comparison operators or widening

86 Views Asked by At

I am looking for a signed integer comparison function cmp(x: Int, y: Int) -> Int which does not use any comparison operators (<, <=, >, >=, <=>, etc.), does not use widening to a larger size of integer, and ideally only uses addition, subtraction, bitwise operators and the equality and inequality operators (==, !=).

I have tested the comparison operators here, here and here, but all of them fail for certain input values (here is a test program in Rust). The first two seem the most promising, and they fail for the same values, probably because of underflow in the subtraction.

2

There are 2 best solutions below

0
Falk Hüffner On

You didn't specify what cmp is actually supposed to do. I'm assuming it should calculate x > y as in the test code you link to, that is, this is not a three-way comparison as in some of the other questions you link to.

You can then use

(((x ^ y) >> 1) - ((x ^ y) & x)) < 0

(with arithmetic shift and - wrapping).

0
njuffa On

Here is a complete set of signed comparison primitives for int that I backported from corresponding SIMD primitives. Some inefficiencies are therefore likely; however, gcc and clang optimize the code quite well, the latter better than the former. As needed in a SIMD context, there are cmp* functions that deliver a mask of all zeros or all ones, as well as set* functions that deliver a boolean result.

The code here is based on standard bit-twiddling techniques, in particular isolation of the sign bit, separate sum and carry bit vectors in addition, Mycroft's null-byte detection algorithm, and Montgomery's algorithm for averaging without overflow.

The ISO-C99 code below was developed with the Intel Classic compiler on an x86-64 platform and passed all tests incorporated in the code. I also compiled the code on Compiler Explorer with gcc and clang for x86-64, ARM32 and ARM64, and no issues were identified.

I noticed belatedly that the asker appears to be actually interested in three-way comparison. Assuming a platform that uses two's-complement representation for int, this can easily be implemented as: int cmp (int a, int b) { return (int)(cmpne (a, b) & (cmples (a, b) | 1)); }

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <limits.h>

/* Number of random operand pairs checked per primitive by the self-test. */
#define MAX_TEST        (10000000000LL)

/* Implementation selectors for cmples() / cmplts(): 0 derives the mask from
   the boolean set* result, any other value expands the sign bit directly. */
#define CMPLES_VARIANT  (1)
#define CMPLTS_VARIANT  (1)

/* Bit index of the most significant (sign) bit of an int. */
#define SIGN_SHIFT      ((int)(sizeof (int) * CHAR_BIT) - 1)
/* Unsigned mask with only the sign bit set. */
#define SIGN_MASK       (1U << SIGN_SHIFT)
/* The two mask values: every bit set, and no bit set. */
#define ALL_ONES        (~0U)
#define ALL_ZERO        (0U)

/* Replicate the most significant bit of a into every bit position.
   Returns all ones when the MSB of a is set, zero otherwise. */
unsigned int sign_to_mask (unsigned int a)
{
    unsigned int msb = a & SIGN_MASK;       /* keep only the sign bit */
    unsigned int lsb = msb >> SIGN_SHIFT;   /* the same bit moved to position 0 */

    /* msb + msb wraps around to zero, so this is 0 - lsb:
       all ones if the sign bit was set, zero if it was clear. */
    return (msb + msb) - lsb;
}

/* Move the most significant bit of a down to bit 0.
   Returns 1 when the MSB of a is set, 0 otherwise. */
unsigned int sign_to_bool (unsigned int a)
{
    unsigned int msb_as_bit0 = a >> SIGN_SHIFT;
    return msb_as_bit0;
}

/* Turn a boolean into a mask: 1 becomes all ones, 0 stays zero. */
unsigned int bool_to_mask (unsigned int a)
{
    /* Shifting by SIGN_SHIFT and then by one more position pushes bit 0
       out of the word, so for a boolean input this term is zero. */
    unsigned int pushed_out = (a << SIGN_SHIFT) << 1;

    /* 0 - 1 wraps to all ones; 0 - 0 stays zero. */
    return pushed_out - a;
}

/* "signed less than" == (OF != SF) */
/* Core of the signed "a < b" test. Only the most significant bit of the
   result carries the predicate; all other bits are don't-care. */
unsigned int lts_core (unsigned int a, unsigned int b)
{
    const unsigned int low_bits = ~SIGN_MASK;

    /* Add the non-sign bits of ~a and b. The carry lands in the MSB of sum
       and is set exactly when the low bits of b exceed the low bits of a. */
    unsigned int sum = (~a & low_bits) + (b & low_bits);

    /* Equal signs: keep the MSB of sum. Different signs: the XOR below
       replaces it with the sign of a (a negative means a < b). */
    unsigned int signs_differ = a ^ b;
    unsigned int fixup = signs_differ & (a ^ sum);

    return sum ^ fixup;
}

/* Unsigned halving add: floor ((a + b) / 2) computed without overflowing. */
unsigned int haddu (unsigned int a, unsigned int b)
{
    /* Peter L. Montgomery's observation (newsgroup comp.arch, 2000/02/11,
       https://groups.google.com/d/msg/comp.arch/gXFuGZtZKag/_5yrz2zDbe4J):
       (A+B)/2 = (A AND B) + (A XOR B)/2.
    */
    unsigned int shared = a & b;                            /* bits set in both */
    unsigned int half_diff = ((a ^ b) >> 1) & ~SIGN_MASK;   /* half of the differing bits */

    return shared + half_diff;
}

/* Signed "a <= b" as a boolean: returns 1 when true, 0 when false. */
unsigned int setles (unsigned int a, unsigned int b)
{
    unsigned int not_b = ~b;

    /* The MSB of the average of a and ~b is the carry out of a + ~b, which
       is set when a > b as unsigned values. */
    unsigned int avg = haddu (a, not_b);

    /* The MSB of a ^ ~b is set when both operands have the same sign; the
       XOR then inverts the unsigned "greater than" into signed "a <= b". */
    unsigned int pred = avg ^ (a ^ not_b);

    return sign_to_bool (pred);
}

/* Signed "a >= b" as a boolean: returns 1 when true, 0 when false. */
unsigned int setges (unsigned int a, unsigned int b)
{
    /* a >= b is the same predicate as b <= a. */
    unsigned int b_le_a = setles (b, a);
    return b_le_a;
}

/* Signed "a < b" as a boolean: returns 1 when true, 0 when false. */
unsigned int setlts (unsigned int a, unsigned int b)
{
    /* lts_core leaves the predicate in the most significant bit. */
    unsigned int pred_in_msb = lts_core (a, b);
    return sign_to_bool (pred_in_msb);
}

/* Signed "a > b" as a boolean: returns 1 when true, 0 when false. */
unsigned int setgts (unsigned int a, unsigned int b)
{
    /* a > b is the same predicate as b < a. */
    unsigned int b_lt_a = setlts (b, a);
    return b_lt_a;
}

/* "a == b" as a boolean: returns 1 when the operands are equal, else 0. */
unsigned int seteq (unsigned int a, unsigned int b)
{
    // inspired by Alan Mycroft's null-byte detection algorithm
    // (newsgroup comp.lang.c, 1987/04/08,
    // https://groups.google.com/forum/#!original/comp.lang.c/2HtQXvg7iKc/xOJeipH6KLMJ):
    // null_byte(x) = ((x - 0x01010101) & (~x & 0x80808080))
    const unsigned int diff = a ^ b;                // zero iff a == b
    const unsigned int forced = diff | SIGN_MASK;   // diff with the MSB forced on
    const unsigned int msb_clear = diff ^ forced;   // SIGN_MASK iff MSB of diff is 0
    const unsigned int dec = forced - 1;            // MSB drops iff low bits of diff are 0

    // MSB set only when both the MSB and the low bits of diff are zero
    return sign_to_bool (msb_clear & ~dec);
}

/* "a != b" as a boolean: returns 1 when the operands differ, else 0. */
unsigned int setne (unsigned int a, unsigned int b)
{
    // inspired by Alan Mycroft's null-byte detection algorithm
    // (newsgroup comp.lang.c, 1987/04/08,
    // https://groups.google.com/forum/#!original/comp.lang.c/2HtQXvg7iKc/xOJeipH6KLMJ):
    // null_byte(x) = ((x - 0x01010101) & (~x & 0x80808080))
    const unsigned int diff = a ^ b;                  // zero iff a == b
    const unsigned int dec = (diff | SIGN_MASK) - 1;  // MSB stays set iff low bits of diff != 0

    // MSB set when either the MSB or any low bit of diff is set
    return sign_to_bool (diff | dec);
}

/* Signed "a <= b" as a mask: all ones when true, all zeros when false.
   CMPLES_VARIANT selects between two equivalent implementations. */
unsigned int cmples (unsigned int a, unsigned int b)
{
#if (CMPLES_VARIANT != 0)
    /* Same predicate expression as setles(), with the sign bit expanded
       to a full mask instead of being reduced to a boolean. */
    unsigned int not_b = ~b;
    return sign_to_mask (haddu (a, not_b) ^ (a ^ not_b));
#else
    /* Compute the boolean first, then widen it to a mask. */
    return bool_to_mask (setles (a, b));
#endif // CMPLES_VARIANT
}

/* Signed "a >= b" as a mask: all ones when true, all zeros when false. */
unsigned int cmpges (unsigned int a, unsigned int b)
{
    /* a >= b is the same predicate as b <= a. */
    unsigned int b_le_a = cmples (b, a);
    return b_le_a;
}

/* Signed "a < b" as a mask: all ones when true, all zeros when false.
   CMPLTS_VARIANT selects between two equivalent implementations. */
unsigned int cmplts (unsigned int a, unsigned int b)
{
#if (CMPLTS_VARIANT != 0)
    /* Expand the predicate bit that lts_core leaves in the MSB. */
    return sign_to_mask (lts_core (a, b));
#else
    /* Compute the boolean first, then widen it to a mask. */
    return bool_to_mask (setlts (a, b));
#endif // CMPLTS_VARIANT
}

/* Signed "a > b" as a mask: all ones when true, all zeros when false. */
unsigned int cmpgts (unsigned int a, unsigned int b)
{
    /* a > b is the same predicate as b < a. */
    unsigned int b_lt_a = cmplts (b, a);
    return b_lt_a;
}

/* "a == b" as a mask: all ones when the operands are equal, else zero. */
unsigned int cmpeq (unsigned int a, unsigned int b)
{
    // inspired by Alan Mycroft's null-byte detection algorithm:
    // null_byte(x) = ((x - 0x01010101) & (~x & 0x80808080))
    const unsigned int diff = a ^ b;                // zero iff a == b
    const unsigned int forced = diff | SIGN_MASK;   // diff with the MSB forced on
    const unsigned int msb_clear = diff ^ forced;   // SIGN_MASK iff MSB of diff is 0
    const unsigned int dec = forced - 1;            // MSB drops iff low bits of diff are 0

    // MSB set only when both the MSB and the low bits of diff are zero
    return sign_to_mask (msb_clear & ~dec);
}

/* "a != b" as a mask: all ones when the operands differ, else zero. */
unsigned int cmpne (unsigned int a, unsigned int b)
{
    // inspired by Alan Mycroft's null-byte detection algorithm:
    // null_byte(x) = ((x - 0x01010101) & (~x & 0x80808080))
    const unsigned int diff = a ^ b;                  // zero iff a == b
    const unsigned int dec = (diff | SIGN_MASK) - 1;  // MSB stays set iff low bits of diff != 0

    // MSB set when either the MSB or any low bit of diff is set
    return sign_to_mask (diff | dec);
}

/* Three-way signed comparison built from the mask primitives.
   Returns 0 when a == b, -1 when a < b and 1 when a > b (the all-ones
   mask is read back as -1, which assumes two's-complement int). */
int cmp (int a, int b)
{
    unsigned int differs = cmpne (a, b);            /* all ones unless a == b */
    unsigned int sign_or_one = cmples (a, b) | 1;   /* all ones if a <= b, else 1 */

    return (int)(differs & sign_or_one);
}

/* Reference for setles: 1 when a <= b by the native comparison, else 0. */
int setles_ref (int a, int b)
{
    return !(a > b);
}
/* Reference for setlts: 1 when a < b by the native comparison, else 0. */
int setlts_ref (int a, int b)
{
    return !(a >= b);
}
/* Reference for setges: 1 when a >= b by the native comparison, else 0. */
int setges_ref (int a, int b)
{
    return !(a < b);
}
/* Reference for setgts: 1 when a > b by the native comparison, else 0. */
int setgts_ref (int a, int b)
{
    return !(a <= b);
}
/* Reference for seteq: 1 when a == b, else 0. */
int seteq_ref (int a, int b)
{
    return !(a != b);
}
/* Reference for setne: 1 when a != b, else 0. */
int setne_ref (int a, int b)
{
    return !(a == b);
}
/* Reference for cmples: ALL_ONES when a <= b by the native comparison,
   ALL_ZERO otherwise (converted to int on return). */
int cmples_ref (int a, int b)
{
    if (a <= b) {
        return ALL_ONES;
    }
    return ALL_ZERO;
}
/* Reference for cmplts: ALL_ONES when a < b by the native comparison,
   ALL_ZERO otherwise (converted to int on return). */
int cmplts_ref (int a, int b)
{
    if (a < b) {
        return ALL_ONES;
    }
    return ALL_ZERO;
}
/* Reference for cmpges: ALL_ONES when a >= b by the native comparison,
   ALL_ZERO otherwise (converted to int on return). */
int cmpges_ref (int a, int b)
{
    if (a >= b) {
        return ALL_ONES;
    }
    return ALL_ZERO;
}
/* Reference for cmpgts: ALL_ONES when a > b by the native comparison,
   ALL_ZERO otherwise (converted to int on return). */
int cmpgts_ref (int a, int b)
{
    if (a > b) {
        return ALL_ONES;
    }
    return ALL_ZERO;
}
/* Reference for cmpeq: ALL_ONES when a == b, ALL_ZERO otherwise
   (converted to int on return). */
int cmpeq_ref (int a, int b)
{
    if (a == b) {
        return ALL_ONES;
    }
    return ALL_ZERO;
}
/* Reference for cmpne: ALL_ONES when a != b, ALL_ZERO otherwise
   (converted to int on return). */
int cmpne_ref (int a, int b)
{
    if (a != b) {
        return ALL_ONES;
    }
    return ALL_ZERO;
}
/* Reference three-way comparison using the native operators:
   0 when a == b, -1 when a < b, 1 when a > b. */
int cmp_ref (int a, int b)
{
    if (a == b) {
        return 0;
    }
    if (a < b) {
        return -1;
    }
    return 1;
}


// George Marsaglia's KISS PRNG, period 2**123. Newsgroup sci.math, 21 Jan 1999
// Bug fix: Greg Rose, "KISS: A Bit Too Simple" http://eprint.iacr.org/2011/007
//
// The generator keeps four 32-bit state words at file scope. Every macro
// below updates its state word(s) as a side effect of being evaluated, so
// each expansion of KISS advances the whole generator by one step.

// Seeds for the two halves of MWC (kiss_z, kiss_w), for SHR3 (kiss_jsr)
// and for CONG (kiss_jcong).
static uint32_t kiss_z=362436069, kiss_w=521288629;
static uint32_t kiss_jsr=123456789, kiss_jcong=380116160;
// znew / wnew: multiply the low 16 bits of the state word by a constant,
// add the high 16 bits, and store the result back into the state word.
#define znew (kiss_z=36969*(kiss_z&65535)+(kiss_z>>16))
#define wnew (kiss_w=18000*(kiss_w&65535)+(kiss_w>>16))
// MWC: combine both updates, znew shifted into the upper half plus wnew.
#define MWC  ((znew<<16)+wnew )
// SHR3: three xor-shift steps on kiss_jsr (left 13, right 17, left 5);
// the value of the macro is the final state.
#define SHR3 (kiss_jsr^=(kiss_jsr<<13),kiss_jsr^=(kiss_jsr>>17), \
              kiss_jsr^=(kiss_jsr<<5))
// CONG: linear congruential step kiss_jcong = 69069 * kiss_jcong + 1234567.
#define CONG (kiss_jcong=69069*kiss_jcong+1234567)
// KISS: one 32-bit output per evaluation, mixing all three generators.
#define KISS ((MWC^CONG)+SHR3)

int main (void)
{
    int a, b, res, ref;
    uint64_t test;
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = cmples (a, b);
        ref = cmples_ref (a, b);
        if (res != ref) {
            printf ("cmples FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("cmples PASSED\n");
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = cmplts (a, b);
        ref = cmplts_ref (a, b);
        if (res != ref) {
            printf ("cmplts FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("cmplts PASSED\n");
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = cmpges (a, b);
        ref = cmpges_ref (a, b);
        if (res != ref) {
            printf ("cmpges FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("cmpges PASSED\n");
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = cmpgts (a, b);
        ref = cmpgts_ref (a, b);
        if (res != ref) {
            printf ("cmpgts FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("cmpgts PASSED\n");
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = cmpeq (a, b);
        ref = cmpeq_ref (a, b);
        if (res != ref) {
            printf ("cmpeq  FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("cmpeq  PASSED\n");
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = cmpne (a, b);
        ref = cmpne_ref (a, b);
        if (res != ref) {
            printf ("cmpne FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("cmpne  PASSED\n");
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = cmp (a, b);
        ref = cmp_ref (a, b);
        if (res != ref) {
            printf ("cmp FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("cmp    PASSED\n");

    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = setles (a, b);
        ref = setles_ref (a, b);
        if (res != ref) {
            printf ("setles FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("setles PASSED\n");
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = setlts (a, b);
        ref = setlts_ref (a, b);
        if (res != ref) {
            printf ("setlts FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("setlts PASSED\n");
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = setges (a, b);
        ref = setges_ref (a, b);
        if (res != ref) {
            printf ("setges FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("setges PASSED\n");
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = setgts (a, b);
        ref = setgts_ref (a, b);
        if (res != ref) {
            printf ("setgts FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("setgts PASSED\n");
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = seteq (a, b);
        ref = seteq_ref (a, b);
        if (res != ref) {
            printf ("seteq  FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("seteq  PASSED\n");
    test = 0;
    do {
        a = KISS; b = KISS; test++;
        res = setne (a, b);
        ref = setne_ref (a, b);
        if (res != ref) {
            printf ("setne FAILED: a=%08x b=%08x res=%08x ref=%08x\n", 
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
    } while (test < MAX_TEST);
    printf ("setne  PASSED\n");
    return EXIT_SUCCESS;
}