Sorting 3 numbers without branching

In C# or C++ how can I implement a branch-free sort of three (integer) numbers?

Is this possible?

Lee Louviere (accepted answer, score 5)

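No conditionals: max and min are computed arithmetically from a branch-free abs, and the result comes out in descending order (a >= b >= c). Note that the intermediate sums can overflow when the inputs are large.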

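#include <climits>  // CHAR_BIT

// Branch-free absolute value: b is 1 if a is negative, 0 otherwise.
// (Named iabs/imax/imin to avoid clashing with the standard abs/max/min.)
int iabs(int a)
{
    int b = (a >> (sizeof(int) * CHAR_BIT - 1)) & 1;
    return a - 2 * b * a;   // a when b == 0, -a when b == 1
}
int imax(int a, int b) { return (a + b + iabs(a - b)) / 2; }
int imin(int a, int b) { return (a + b - iabs(a - b)) / 2; }

// Sorts into descending order: afterwards a >= b >= c.
void sort (int & a, int & b, int & c)
{
   int maxnum = imax(imax(a,b), c);
   int minnum = imin(imin(a,b), c);
   int middlenum = a + b + c - maxnum - minnum;
   a = maxnum;
   b = middlenum;
   c = minnum;
}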
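If the inputs can be large, the sums in the (a + b ± |a - b|) / 2 formulas can overflow. One way around that, sketched here under the assumption that int is 32 bits and std::int64_t is available, is to do the intermediate arithmetic in 64 bits. The names abs64, wide_max, wide_min and sort_wide are made up for this sketch:

```cpp
#include <cstdint>

// Branch-free abs on a 64-bit value. Assumes an arithmetic right shift of
// negative values (what mainstream compilers do; guaranteed since C++20).
static std::int64_t abs64(std::int64_t v) {
    const std::int64_t mask = v >> 63;  // -1 if v < 0, else 0
    return (v ^ mask) - mask;
}

// max/min via (a + b +/- |a - b|) / 2, evaluated in 64 bits so the
// intermediate sums cannot overflow for any pair of 32-bit ints.
static int wide_max(int a, int b) {
    const std::int64_t x = a, y = b;
    return static_cast<int>((x + y + abs64(x - y)) / 2);
}
static int wide_min(int a, int b) {
    const std::int64_t x = a, y = b;
    return static_cast<int>((x + y - abs64(x - y)) / 2);
}

// Hypothetical overflow-safe variant, descending order: a >= b >= c.
void sort_wide(int& a, int& b, int& c) {
    const int hi = wide_max(wide_max(a, b), c);
    const int lo = wide_min(wide_min(a, b), c);
    // The sum of all three is taken in 64 bits before the extremes are removed.
    const int mid = static_cast<int>(std::int64_t{a} + b + c - hi - lo);
    a = hi;
    b = mid;
    c = lo;
}
```

Because a + b + |a - b| is always exactly twice the larger value, the division by 2 never truncates.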

Flexo (score 5)

You can do this in C++ with:

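#include <cstdlib>   // std::abs
#include <iostream>

// Compare-exchange in[0] and in[1] without branching: afterwards
// in[0] holds the larger value and in[1] the smaller (descending).
void sort(int *in) {
  const int sum = in[0]+in[1];
  const int diff = std::abs(in[1]-in[0]);
  in[0] = (sum + diff) / 2;
  in[1] = (sum - diff) / 2;
}

int main() {
  // Compare-exchanging pairs (0,1), (1,2), (0,1) sorts three elements.
  int a[] = {3,4,1};
  sort(a);
  sort(a+1);
  sort(a);
  std::cout << a[0] << "," << a[1] << "," << a[2] << std::endl;

  int b[] = {1,2,3};
  sort(b);
  sort(b+1);
  sort(b);
  std::cout << b[0] << "," << b[1] << "," << b[2] << std::endl;
}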

The trick is to express min and max as arithmetic operations rather than branches, and then to call sort on adjacent pairs enough times to "bubble sort" them.
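Why the arithmetic min/max work: for any two integers, max(a,b) = (a + b + |a - b|) / 2 and min(a,b) = (a + b - |a - b|) / 2. Checking both cases:

```latex
\begin{aligned}
a \ge b \implies |a-b| = a-b:\quad & \frac{a+b+(a-b)}{2} = a = \max(a,b), & \frac{a+b-(a-b)}{2} = b = \min(a,b)\\
a < b \implies |a-b| = b-a:\quad & \frac{a+b+(b-a)}{2} = b = \max(a,b), & \frac{a+b-(b-a)}{2} = a = \min(a,b)
\end{aligned}
```

Each numerator is twice an integer, so integer division by 2 is exact, as long as the sums themselves do not overflow.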


I've made a fully generic version that uses template metaprogramming to call sort the right number of times. With gcc 4.7.0 on my x86 box it all gets inlined exactly as you'd hope (although calls are unconditional on x86 anyway). I've also implemented an abs function that avoids branches on x86. It is based on gcc's __builtin_abs implementation for x86, but it assumes two's-complement integers and an arithmetic right shift, which makes it less portable:

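#include <iostream>
#include <climits>

// Branch-free abs, modelled on gcc's x86 __builtin_abs sequence.
// Assumes two's complement and an arithmetic right shift of negative ints.
void myabs(int& in) {
  const int tmp = in >> ((sizeof(int) * CHAR_BIT) - 1); // -1 if negative, else 0
  in ^= tmp;   // ~in when negative, unchanged otherwise
  in -= tmp;   // ~in + 1 == -in when negative
}

// Compare-exchange the pair (I-1, I), then move on to the next pair.
// When C becomes true the pass is over and the next, shorter pass starts.
template <int N, int I=1, bool C=false>
struct sorter {
  static void sort(int *in) {
    const int sum = in[I]+in[I-1];
    int diff = in[I-1]-in[I];
    myabs(diff);
    in[I-1] = (sum + diff) / 2;   // larger of the pair
    in[I]   = (sum - diff) / 2;   // smaller of the pair
    sorter<N, I+1, (I+1 >= N)>::sort(in);
  }
};

// End of a pass: sort the first N-1 elements.
template <int N,int I>
struct sorter<N,I,true> {
  static void sort(int *in) {
    sorter<N-1>::sort(in);
  }
};

// A single element is already sorted: stop the recursion.
template <int I, bool C>
struct sorter<1,I,C> {
  static void sort(int *) {
  }
};

int main() {
  int a[] = {3,4,1};
  sorter<3>::sort(a);   // descending order: prints 4,3,1
  std::cout << a[0] << "," << a[1] << "," << a[2] << std::endl;
}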
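For comparison, here is a sketch of the same bubble-sort network without template metaprogramming. The loop bounds depend only on the template parameter N, never on the data, so every input runs the same sequence of compare-exchanges and the compiler is free to unroll the loops. Whether it actually does, and whether std::min/std::max become conditional moves rather than jumps, is up to the compiler, so check the generated assembly. The names cmpx and network_sort are hypothetical:

```cpp
#include <algorithm>
#include <array>
#include <cstddef>

// Compare-exchange: afterwards lo <= hi. Using std::min/std::max instead of
// (sum +/- |diff|) / 2 avoids overflow; branch-freeness depends on codegen.
inline void cmpx(int& lo, int& hi) {
    const int smaller = std::min(lo, hi);
    const int larger = std::max(lo, hi);
    lo = smaller;
    hi = larger;
}

// Bubble-sort network for a fixed N (ascending order). Each pass bubbles the
// largest remaining value to the end; the next pass is one element shorter.
template <std::size_t N>
void network_sort(std::array<int, N>& v) {
    for (std::size_t pass = N; pass > 1; --pass)
        for (std::size_t i = 0; i + 1 < pass; ++i)
            cmpx(v[i], v[i + 1]);
}
```

For N = 3 this performs the compare-exchanges (0,1), (1,2), (0,1), the same sequence as the three sort calls in the first example of this answer, but it sorts into ascending order.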

Sarfaraz Nawaz (score 6)

You can write branch-free max, min and swap functions. Once you have them, you can write the sort function as:

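// max, min and swap are defined below; declare them first.
int max(int a, int b, int c);
int min(int a, int b, int c);
void swap(int &a, int &b);

// Sorts into descending order: afterwards a >= b >= c.
void sort(int &a, int &b, int &c)
{
    int m1 = max(a,b,c);        // largest
    int m2 = min(a,b,c);        // smallest
    b = a + b + c - m1 - m2;    // middle (can overflow for very large inputs)
    swap(m1, a);
    swap(m2, c);
}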

And here are the helper functions:

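void swap(int &a, int &b)
{
   int tmp = a; a = b; b = tmp;
}

// The comparison result (0 or 1) is used as an array index instead of a branch.
int max( int a, int b, int c ) {
   int l1[] = { a, b };
   int l2[] = { l1[ a<b ], c };     // larger of a and b, then c
   return l2[ l2[0] < c ];          // larger of that and c
}
int min( int a, int b, int c ) {
   int l1[] = { a, b };
   int l2[] = { l1[ a>b ], c };     // smaller of a and b, then c
   return l2[ l2[0] > c ];          // smaller of that and c
}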
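Indexing an array with a comparison result works, but there is also a well-known bit trick (it appears in Sean Anderson's "Bit Twiddling Hacks" collection) that selects between two values with a mask. It involves no additions, so it cannot overflow. A sketch, assuming two's-complement ints; mask_min, mask_max, mask_max3 and mask_min3 are hypothetical names:

```cpp
// -(x < y) is all ones (-1) when x < y and 0 otherwise, so the mask
// either keeps y unchanged or XORs it into x (and vice versa for max).
inline int mask_min(int x, int y) { return y ^ ((x ^ y) & -(x < y)); }
inline int mask_max(int x, int y) { return x ^ ((x ^ y) & -(x < y)); }

// Three-argument versions with the same meaning as max/min above.
inline int mask_max3(int a, int b, int c) { return mask_max(mask_max(a, b), c); }
inline int mask_min3(int a, int b, int c) { return mask_min(mask_min(a, b), c); }
```

As with the array version, the x < y comparison usually compiles to a flag-setting instruction rather than a jump, but the language does not guarantee that.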

Test code:

#include <iostream>

int main() {
        int a,b,c;
        std::cin >> a >> b >> c;
        sort(a,b,c);
        std::cout << a <<"," << b << "," << c << std::endl;
        return 0;
}

Input:

21 242 434

Output (descending order):

434,242,21

Demo: http://ideone.com/3ZOzc

I took the implementation of max from an answer by @David and implemented min with a small twist.

tutizeri (score 0)

The problem with using addition is that it can overflow.

You can use XOR instead, which cannot overflow.

This works because XOR is associative and commutative, and every value is its own inverse:

a = (a ^ b) ^ b
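Applied to three values: write x <= y <= z for a, b, c in sorted order, so all = x ^ y ^ z, min = x and max = z. Then:

```latex
\mathit{all} \oplus \min \oplus \max
  = (x \oplus y \oplus z) \oplus x \oplus z
  = (x \oplus x) \oplus y \oplus (z \oplus z)
  = 0 \oplus y \oplus 0
  = y
```

This also holds when values repeat, because the rearrangement uses only associativity, commutativity and x ^ x = 0.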

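using System;

static class Sort3
{
    public static (int min, int middle, int max) Sort_3_numbers(int a, int b, int c)
    {
        // XOR of all values; it does not depend on min/max,
        // so an out-of-order CPU can compute it in parallel
        int all = a ^ b ^ c;

        // Find the minimum and maximum values (Math.Min/Math.Max used here;
        // substitute a branch-free min/max if branches must be avoided)
        int _min = Math.Min(Math.Min(a, b), c);
        int _max = Math.Max(Math.Max(a, b), c);

        // Extract the middle value: _min and _max cancel out of the XOR
        int middle = all ^ _min ^ _max;

        return (_min, middle, _max);
    }
}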
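The same idea in C++, as a sketch: std::tuple stands in for the C# tuple, and the min/max calls use the XOR-mask bit trick (assuming two's-complement ints) so that no step adds values together. xmin, xmax and sort3_xor are hypothetical names:

```cpp
#include <tuple>

// Overflow-free min/max: -(x < y) is an all-ones mask when x < y, else 0.
inline int xmin(int x, int y) { return y ^ ((x ^ y) & -(x < y)); }
inline int xmax(int x, int y) { return x ^ ((x ^ y) & -(x < y)); }

// Returns (min, middle, max) without any additions, so nothing can overflow.
std::tuple<int, int, int> sort3_xor(int a, int b, int c) {
    const int all = a ^ b ^ c;           // independent of the min/max below
    const int lo = xmin(xmin(a, b), c);
    const int hi = xmax(xmax(a, b), c);
    const int middle = all ^ lo ^ hi;    // lo and hi cancel, leaving the middle
    return std::make_tuple(lo, middle, hi);
}
```

For example, sort3_xor(5, 5, -1) yields (-1, 5, 5): the duplicated 5 is handled correctly because only one copy of it cancels against hi.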