How to generate a sequence of indexes of a nested loop that has cache locality?


Imagine you have two arrays and you want to iterate a nested loop of them, such as the following:

#include <stdio.h>

int main(void) {
  int sizes[2] = {4, 4};
  for (int i = 0 ; i < sizes[0]; i++) {
    for (int j = 0 ; j < sizes[1]; j++) {
      printf("%d%d\n", i, j);
    }
  }
  return 0;
}

This produces the following pattern of indexes. The first number of each line is i and the second number is j.

00
01
02
03
10
11
12
13
20
21
22
23
30
31
32
33

The problem with this pattern is that it is bad for caches. On every iteration of i, j runs from 0 to 3 and then jumps back to 0, so with large arrays the data for j is loaded into the cache and thrown away before it is needed again on the next pass. I don't want the indexes to go back to 0. Here is a sequence of the kind I want to produce:

00
10
30
20
21
31
11
01
03
13
33
23
22
32
12
02

Notice that the indexes never cycle back to 0. At each step only one index changes, and it moves by at most 2 positions, so memory that is already in a cache line can be reused.

I produced this sequence with a reflected binary code, also known as a Gray code, inspired by this Hacker News comment.

Here is the code I used to produce this sequence:

a = [1, 2, 3, 4]
b = [2, 4, 8, 16]
indexes = set()
correct = set()

print("graycode loop indexes")
for index in range(0, len(a) * len(b)):
  code = index ^ (index >> 1)
  left = code & 0x0003
  right = code >> 0x0002 & 0x0003

  print("{}{}".format(left, right))
  indexes.add((left, right))

assert len(indexes) == 16

print("regular nested loop indexes")

for x in range(0, len(a)):
  for y in range(0, len(b)):
    correct.add((x,y))
    print("{}{}".format(x, y))


assert correct == indexes
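
As a quick check of the properties this ordering has, here is a minimal sketch that wraps the same Gray code computation in a generator and measures how far the indexes move between consecutive steps. The name gray_pairs is made up for this illustration, and the sketch assumes both arrays have the same power-of-two length.

```python
def gray_pairs(bits):
    """Yield (left, right) index pairs for two arrays of length 2**bits,
    ordered by the reflected binary Gray code of the combined index."""
    mask = (1 << bits) - 1
    for index in range(1 << (2 * bits)):
        code = index ^ (index >> 1)  # consecutive codes differ in one bit
        yield code & mask, (code >> bits) & mask


pairs = list(gray_pairs(2))  # the 4 x 4 case from the question
steps = [(abs(a - c), abs(b - d)) for (a, b), (c, d) in zip(pairs, pairs[1:])]

# Exactly one of the two indexes changes at every step...
assert all((dl == 0) != (dr == 0) for dl, dr in steps)
# ...but the change can be 2 positions, for example (1, 0) -> (3, 0).
assert max(max(s) for s in steps) == 2
```

With longer power-of-two arrays a single bit flip can move an index by up to half the array length, and the bit masks do not fit lengths such as 5 at all, which is why this approach needs generalizing.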

How do I generalize this sequence for varied length arrays?

For example, imagine I have 3 lists of sizes 5, 8, and 16. I want to iterate over them in a nested loop without blowing the cache, visiting every combination of indexes while keeping as much as possible in the cache.

Said another way, for each i I want to go through every j, but in an order that does not make j cycle back to its starting value.

Answer by Samuel Squire

It seems cache lines work with both backward and forward locality. Jumping around memory randomly is inefficient, but spatially local access is efficient in either direction.

I produce indexes that maximize the use of what is already in the cache line by iterating the inner loops in reverse directions, alternating as the outer loops advance.

This means the data at consecutive index positions should be near each other, ideally on the same cache line.

a = [0, 1, 2, 3, 4]
b = [0, 1, 2, 3, 4]
c = [0, 1, 2, 3, 4]

efficient = set()
correct = set()
alternating = [False, False, False]
for i in range(0, len(a)):
  for j in range(0, len(b)):
    for k in range(0, len(c)):
      correct.add((i, j, k))

 
for i in range(0, len(a)):
  if alternating[1]:
    jstart = (len(b) - 1)
    jend = -1
    jstep = -1
  else:
    jstart = 0
    jend = len(b)
    jstep = 1
  for j in range(jstart, jend, jstep):
    if alternating[2]:
      kstart = (len(c) -1)
      kend = -1
      kstep = -1
    else:
      kstart = 0
      kend = len(c)
      kstep = 1
    for k in range(kstart, kend, kstep):
      
      new_item = (i, j, k)
      print(tuple(new_item))
      efficient.add(new_item)
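    alternating[2] = not alternating[2]  # reverse the direction of k for the next j
  alternating[1] = True  # j runs in reverse from the second i onward (set once, never toggled back)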

print(len(correct))

print(len(efficient))



print("missing")
print(correct - efficient)
print("surplus")
print(efficient - correct)
assert correct == efficient

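This produces the following indexes. The direction of k reverses every time j changes. The direction of j reverses only once, because alternating[1] is set to True rather than toggled, so the sequence still jumps from (1, 0, 0) to (2, 4, 0) and again at each later change of i. Replacing that assignment with alternating[1] = not alternating[1] would remove those jumps. The final lines are the sizes of the two sets followed by the (empty) missing and surplus sets.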

(0, 0, 0)
(0, 0, 1)
(0, 0, 2)
(0, 0, 3)
(0, 0, 4)
(0, 1, 4)
(0, 1, 3)
(0, 1, 2)
(0, 1, 1)
(0, 1, 0)
(0, 2, 0)
(0, 2, 1)
(0, 2, 2)
(0, 2, 3)
(0, 2, 4)
(0, 3, 4)
(0, 3, 3)
(0, 3, 2)
(0, 3, 1)
(0, 3, 0)
(0, 4, 0)
(0, 4, 1)
(0, 4, 2)
(0, 4, 3)
(0, 4, 4)
(1, 4, 4)
(1, 4, 3)
(1, 4, 2)
(1, 4, 1)
(1, 4, 0)
(1, 3, 0)
(1, 3, 1)
(1, 3, 2)
(1, 3, 3)
(1, 3, 4)
(1, 2, 4)
(1, 2, 3)
(1, 2, 2)
(1, 2, 1)
(1, 2, 0)
(1, 1, 0)
(1, 1, 1)
(1, 1, 2)
(1, 1, 3)
(1, 1, 4)
(1, 0, 4)
(1, 0, 3)
(1, 0, 2)
(1, 0, 1)
(1, 0, 0)
(2, 4, 0)
(2, 4, 1)
(2, 4, 2)
(2, 4, 3)
(2, 4, 4)
(2, 3, 4)
(2, 3, 3)
(2, 3, 2)
(2, 3, 1)
(2, 3, 0)
(2, 2, 0)
(2, 2, 1)
(2, 2, 2)
(2, 2, 3)
(2, 2, 4)
(2, 1, 4)
(2, 1, 3)
(2, 1, 2)
(2, 1, 1)
(2, 1, 0)
(2, 0, 0)
(2, 0, 1)
(2, 0, 2)
(2, 0, 3)
(2, 0, 4)
(3, 4, 4)
(3, 4, 3)
(3, 4, 2)
(3, 4, 1)
(3, 4, 0)
(3, 3, 0)
(3, 3, 1)
(3, 3, 2)
(3, 3, 3)
(3, 3, 4)
(3, 2, 4)
(3, 2, 3)
(3, 2, 2)
(3, 2, 1)
(3, 2, 0)
(3, 1, 0)
(3, 1, 1)
(3, 1, 2)
(3, 1, 3)
(3, 1, 4)
(3, 0, 4)
(3, 0, 3)
(3, 0, 2)
(3, 0, 1)
(3, 0, 0)
(4, 4, 0)
(4, 4, 1)
(4, 4, 2)
(4, 4, 3)
(4, 4, 4)
(4, 3, 4)
(4, 3, 3)
(4, 3, 2)
(4, 3, 1)
(4, 3, 0)
(4, 2, 0)
(4, 2, 1)
(4, 2, 2)
(4, 2, 3)
(4, 2, 4)
(4, 1, 4)
(4, 1, 3)
(4, 1, 2)
(4, 1, 1)
(4, 1, 0)
(4, 0, 0)
(4, 0, 1)
(4, 0, 2)
(4, 0, 3)
(4, 0, 4)
125
125
missing
set()
surplus
set()
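
The same alternating idea can be sketched for any number of loops and any sizes, such as the 5, 8, and 16 from the question. The generator below is a hypothetical helper (snake_indexes is not a name from the original code). Instead of resetting an inner index to 0 when it reaches an end, it reverses that index's direction and advances the next outer index by one step.

```python
from itertools import product


def snake_indexes(sizes):
    """Yield index tuples for nested loops over `sizes` in alternating order.

    Each yielded tuple differs from the previous one in exactly one
    position, and by exactly 1 in that position.
    """
    if any(n <= 0 for n in sizes):
        return  # an empty dimension means there is nothing to visit
    idx = [0] * len(sizes)
    step = [1] * len(sizes)  # current direction of travel per dimension
    while True:
        yield tuple(idx)
        # Find the innermost dimension that can still move in its direction.
        d = len(sizes) - 1
        while d >= 0 and not 0 <= idx[d] + step[d] < sizes[d]:
            step[d] = -step[d]  # hit an end: turn around instead of resetting
            d -= 1
        if d < 0:
            return  # no dimension can move: every tuple has been visited
        idx[d] += step[d]


sizes = (5, 8, 16)
order = list(snake_indexes(sizes))

# Same set of indexes as the plain nested loop, each visited once.
assert len(order) == 5 * 8 * 16
assert set(order) == set(product(*(range(n) for n in sizes)))

print(list(snake_indexes((2, 3))))
# [(0, 0), (0, 1), (0, 2), (1, 2), (1, 1), (1, 0)]
```

Unlike the listing above, every level toggles its direction, so for three lists of size 5 this order moves from (1, 0, 0) to (2, 0, 0) rather than to (2, 4, 0). Whether it is actually faster than the plain nested loop depends on the data layout and cache sizes, so it is worth measuring.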