RISC-V: Implement a discrete function without branching


Consider the discrete-valued function f defined on integers in the set {-3, -2, -1, 0, 1, 2, 3}. Definition:

f(-3) = 6
f(-2) = 61
f(-1) = 17
f(0) = -38
f(1) = 19
f(2) = 42
f(3) = 5

I am trying to implement the function f in RISC-V so that it is evaluated without using any branches or jumps. a0 holds the value at which to evaluate f (i.e. f(a0)), and a1 holds the address of the output array (see the code below).

My code so far:

.globl f

.data
neg3:   .asciiz "f(-3) should be 6, and it is: "
neg2:   .asciiz "f(-2) should be 61, and it is: "
neg1:   .asciiz "f(-1) should be 17, and it is: "
zero:   .asciiz "f(0) should be -38, and it is: "
pos1:   .asciiz "f(1) should be 19, and it is: "
pos2:   .asciiz "f(2) should be 42, and it is: "
pos3:   .asciiz "f(3) should be 5, and it is: "

output: .word   6, 61, 17, -38, 19, 42, 5
.text
main:
    la a0, neg3
    jal print_str
    li a0, -3
    la a1, output
    jal f               # evaluate f(-3); should be 6
    jal print_int
    jal print_newline

    la a0, neg2
    jal print_str
    li a0, -2
    la a1, output
    jal f               # evaluate f(-2); should be 61
    jal print_int
    jal print_newline

    la a0, neg1
    jal print_str
    li a0, -1
    la a1, output
    jal f               # evaluate f(-1); should be 17
    jal print_int
    jal print_newline

    la a0, zero
    jal print_str
    li a0, 0
    la a1, output
    jal f               # evaluate f(0); should be -38
    jal print_int
    jal print_newline

    la a0, pos1
    jal print_str
    li a0, 1
    la a1, output
    jal f               # evaluate f(1); should be 19
    jal print_int
    jal print_newline

    la a0, pos2
    jal print_str
    li a0, 2
    la a1, output
    jal f               # evaluate f(2); should be 42
    jal print_int
    jal print_newline

    la a0, pos3
    jal print_str
    li a0, 3
    la a1, output
    jal f               # evaluate f(3); should be 5
    jal print_int
    jal print_newline

    li a0, 10
    ecall

# f takes in two arguments:
# a0 is the value we want to evaluate f at
# a1 is the address of the "output" array (defined above).
# 
f:
    
    #store the value of a0 in temp register
    add t0, a0, x0
    
    #store numerical value to get correct output index
    addi t1, x0, 3
    
    #add the numerical value 3 to the argument value 
    #to get the correct output array index value
    add t2, t1, t0
    
    #store size of int
    addi t3, x0, 4
    
    #index the output array
    mul t4, t2, t3
    add t4, t4, t2
    
    lw ra, 0(t4)
    

    jr ra            

print_int:
    mv a1, a0
    li a0, 1
    ecall
    jr    ra

print_str:
    mv a1, a0
    li a0, 4
    ecall
    jr    ra

print_newline:
    li a1, '\n'
    li a0, 11
    ecall
    jr    ra

It is clear that adding 3 to the passed-in argument a0 gives the correct index into the output array, but I am unsure whether I implemented this correctly in my code and would like some advice.
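A quick way to convince yourself that the "+ 3" mapping is right is a small C model of the lookup. This is only a sketch and not part of the original program; `index_of`, `f_lookup` and `check_index_range` are illustrative names, and `output` simply mirrors the question's `.data` array.

```c
#include <assert.h>
#include <stdint.h>

/* Same values as the "output" array in the question's .data section. */
const int32_t output[7] = {6, 61, 17, -38, 19, 42, 5};

/* Map an argument x in [-3, 3] to an array index in [0, 6]:
   -3 -> 0, 0 -> 3, 3 -> 6. */
int32_t index_of(int32_t x) {
    return x + 3;
}

/* Lookup without any branch: f(x) is just output[x + 3]. */
int32_t f_lookup(int32_t x) {
    return output[index_of(x)];
}

/* Debug check: every argument in the domain lands inside the array. */
void check_index_range(void) {
    for (int32_t x = -3; x <= 3; x++) {
        assert(index_of(x) >= 0 && index_of(x) < 7);
    }
}
```

Calling `f_lookup(0)` reads `output[3]`, which is -38, matching the definition f(0) = -38.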


Answer by Lukas

I can spot two bugs in your implementation of f:

  1. add t4, t4, t2: Here you are adding the byte offset into the array, t4 = (a0 + 3) * 4, to the index itself (t2 = a0 + 3), when instead you want to add the offset to the array pointer (a1):
    add t4, t4, a1.
  2. lw ra, 0(t4): Here you are loading the value from the array into the return address register ra, which holds the address of the instruction to jump back (return) to in the calling function and is used by jr ra. Overwriting it means jr ra jumps to whatever value was loaded instead of returning to main.
    Instead you want to load the value into a0, which is the canonical return value register:
    lw a0, 0(t4).
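To make the effect of bug 1 concrete, here is a small C model of both address computations. It is a sketch: `buggy_address`, `fixed_address` and `check_fixed` are hypothetical helpers, `base` stands for the value in a1, and addresses are modeled as plain 32-bit integers. Bug 2 concerns control flow and is not modeled here.

```c
#include <assert.h>
#include <stdint.h>

/* Address computed by the question's f; x is a0, base is a1. */
uint32_t buggy_address(uint32_t base, int32_t x) {
    uint32_t t2 = (uint32_t)(x + 3);  /* add t2, t1, t0: index        */
    uint32_t t4 = t2 * 4u;            /* mul t4, t2, t3: byte offset  */
    t4 = t4 + t2;                     /* add t4, t4, t2  (bug 1)      */
    (void)base;                       /* a1 is never read             */
    return t4;
}

/* Address with bug 1 fixed: offset plus array pointer. */
uint32_t fixed_address(uint32_t base, int32_t x) {
    uint32_t t2 = (uint32_t)(x + 3);
    uint32_t t4 = t2 * 4u;
    return t4 + base;                 /* add t4, t4, a1 (fixed)       */
}

/* The fixed address always falls inside the 7-word array at base. */
void check_fixed(uint32_t base) {
    for (int32_t x = -3; x <= 3; x++) {
        uint32_t addr = fixed_address(base, x);
        assert(addr >= base && addr + 4u <= base + 7u * 4u);
    }
}
```

With the buggy version the result does not depend on a1 at all: for x = -2 it is 5, a small address that is not even word-aligned, whereas the fixed version yields a1 + 4.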

There are some possible improvements:

    #store the value of a0 in temp register
    add t0, a0, x0

    #store numerical value to get correct output index
    addi t1, x0, 3
    
    #add the numerical value 3 to the argument value 
    #to get the correct output array index value
    add t2, t1, t0
  1. There is no reason to copy a0 to t0; just use a0 directly in the last instruction.
  2. There's also no reason to load the immediate value 3 into t1; just turn the last instruction into an addi.

Combined, the replacement for the above three instructions is this one instruction:

    # add 3 to argument value
    addi t2, a0, 3

There is more:

    #store size of int
    addi t3, x0, 4
    
    #index the output array
    mul t4, t2, t3
    add t4, t4, a1     # t2 replaced by a1, see above

You may have noticed that there is no muli instruction, so the immediate value 4 must be loaded into register t3 before multiplying.
Furthermore, mul is part of the M extension (not supported by all processors), and on many implementations a multiplication takes more cycles than a simple shift.

Therefore I propose to replace the multiplication by 4 with a left shift by 2, which exploits the binary representation of integers and produces the same result.
In two's complement the low 32 bits of x * 4 and x << 2 are identical even for negative x, and in any case the computed index here is always >= 0, so this is safe.

There is a left shift immediate instruction slli, which is part of all base RISC-V ISAs, and shifting is usually faster than multiplication.
The three instructions above can be replaced by just two:

    # index the output array
    # (multiply by 4 / left shift by 2, then add array address)
    slli t4, t2, 2
    add  t4, t4, a1     # t2 replaced by a1, see above
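The interchangeability of the shift and the multiplication can be checked with a short C sketch that works on 32-bit unsigned values, which is how register contents behave on RV32. The function names are illustrative. Unsigned types are used on purpose: left-shifting a negative signed int is undefined behavior in C, even though slli itself simply shifts the bit pattern.

```c
#include <assert.h>
#include <stdint.h>

/* Byte offset of element `index` in an array of 4-byte words, via mul. */
uint32_t offset_mul(uint32_t index) {
    return index * 4u;
}

/* The same offset via a left shift by 2 (what slli t4, t2, 2 does). */
uint32_t offset_shift(uint32_t index) {
    return index << 2;
}

/* Both forms agree on every index f can produce (0..6). */
void check_offsets(void) {
    for (uint32_t i = 0; i <= 6; i++) {
        assert(offset_mul(i) == offset_shift(i));
    }
}
```

Even for the bit pattern of -3 (0xFFFFFFFD), both functions return 0xFFFFFFF4, the bit pattern of -12.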

Now combining all bug fixes and improvements, the whole function takes just five instructions instead of eight:

# a0: argument value ; a1: array address
f:
    # add 3 to argument value
    addi t2, a0, 3

    # index the output array
    # (multiply by 4 / left shift by 2, then add array address)
    slli t4, t2, 2
    add  t4, t4, a1     # t2 replaced by a1, see above

    # load value from output array
    lw a0, 0(t4)        # ra replaced by a0, see above

    # return to calling function
    jr ra
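For illustration, here is a step-by-step C model of this five-instruction version, one statement per instruction, running against a toy byte-addressed memory. It is only a sketch: `mem`, `DATA_BASE` (an arbitrary address chosen for the array), `store_output`, `lw` and `f_model` are hypothetical names, not part of any simulator's API.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Toy byte-addressed memory; DATA_BASE is an arbitrary, hypothetical
   address for the "output" array. */
#define DATA_BASE 0x100u
uint8_t mem[0x200];

/* Copy the table from the question's .data section into the toy memory. */
void store_output(void) {
    const int32_t output[7] = {6, 61, 17, -38, 19, 42, 5};
    memcpy(&mem[DATA_BASE], output, sizeof output);
}

/* lw: load a 32-bit word. memcpy is used for both store and load,
   so the host byte order cancels out. */
int32_t lw(uint32_t addr) {
    assert(addr % 4u == 0 && addr + 4u <= sizeof mem);  /* aligned, in range */
    int32_t value;
    memcpy(&value, &mem[addr], sizeof value);
    return value;
}

/* The five-instruction f; a0 is the argument, a1 the array address. */
int32_t f_model(int32_t a0, uint32_t a1) {
    uint32_t t2 = (uint32_t)a0 + 3u;  /* addi t2, a0, 3  */
    uint32_t t4 = t2 << 2;            /* slli t4, t2, 2  */
    t4 = t4 + a1;                     /* add  t4, t4, a1 */
    return lw(t4);                    /* lw a0, 0(t4); jr ra */
}
```

After `store_output()`, `f_model(-1, DATA_BASE)` computes t2 = 2 and t4 = 0x108 and returns 17.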

Finally, let's look at what a modern compiler (clang 11 with optimizations enabled, on godbolt.org) produces when we give it the same function implemented in C:

#include <stdint.h>
int32_t f(int32_t a0, int32_t a1[]) {
    return a1[a0 + 3];
}

Here's the result with some comments I added:

f:
    # multiply argument value by 4
    slli a0, a0, 2

    # add array address
    add  a0, a0, a1

    # load array value at three places (= 12 bytes) further
    lw   a0, 12(a0)

    # return to calling function
    ret

This is only four instructions. Let's go through the differences:

  1. ret is a pseudo-instruction, just shorthand for jr ra, so it is the same as what we have.
  2. Register a0 is reused by all instructions, which is perfectly legal.
    A benefit of using a0 is that many compressed instructions of the C extension (for example c.lw) can only encode registers x8 to x15, which include a0 to a5 but none of the tN registers.
  3. The compiler used the same trick I proposed: addi t3, x0, 4 followed by mul is replaced by a single slli.
  4. The initial addi t2, a0, 3 is missing.
    Instead, the compiler internally multiplied the constants 3 and 4 and used the resulting value 12 as the immediate value in lw a0, 12(a0).
    The original address calculation was (a0 + 3) * 4 + a1. The compiler transformed this into a0 * 4 + a1 + 12, which is equivalent. It then moved the addition of 12 into the address generation of the lw instruction (its immediate offset), thereby saving the addi instruction.
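The equivalence of the two address formulas can be checked with a small C model. This is a sketch with illustrative names, modeling addresses as 32-bit integers. Integers rather than pointers are used on purpose: for a0 = -3 the intermediate value a0 * 4 + a1 equals a1 - 12, which lies before the array. That is harmless in a register, since lw adds the 12 back, but forming such a pointer in C would be undefined behavior.

```c
#include <assert.h>
#include <stdint.h>

/* Hand-written version: (a0 + 3) * 4 + a1. */
uint32_t addr_handwritten(int32_t a0, uint32_t a1) {
    uint32_t t2 = (uint32_t)a0 + 3u;          /* addi t2, a0, 3          */
    return (t2 << 2) + a1;                    /* slli t4, t2, 2; add     */
}

/* clang version: a0 * 4 + a1, with the constant 12 folded into lw. */
uint32_t addr_clang(int32_t a0, uint32_t a1) {
    uint32_t reg = ((uint32_t)a0 << 2) + a1;  /* slli a0, a0, 2; add a0, a0, a1 */
    return reg + 12u;                         /* offset of lw a0, 12(a0)        */
}

/* Both forms agree for every argument in the domain of f. */
void check_equivalent(uint32_t a1) {
    for (int32_t a0 = -3; a0 <= 3; a0++) {
        assert(addr_handwritten(a0, a1) == addr_clang(a0, a1));
    }
}
```

For a1 = 0x1000, both functions return 0x1000 for a0 = -3 and 0x1018 for a0 = 3, i.e. the first and last words of the array.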