unroll numpy einsum to get indices

I have a vector with shape [2, 2, 2, 2, 2] and I need to get the indices "from" and "to" for this numpy einsum operation:

```python
np.einsum(vector, [0, 1, 2, 3, 4], np.conj(vector),
          [0, 1, 2, 8, 9], [3, 4, 8, 9])
```

The result has shape [2, 2, 2, 2].
Can someone unroll this into loops, or explain how I can find which indices of vector contribute to each index of the result?
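The integer lists in this call are NumPy's sublist form of einsum: each integer labels an axis, the same way letters do in a subscript string. Relabeling 0, 1, 2, 3, 4, 8, 9 as a, b, c, d, e, i, j gives the equivalent string `abcde,abcij->deij`, which makes it easier to see that d, e, i, j become the result axes and a, b, c are summed over. A minimal check, assuming a random complex `vector` as stand-in data:

```python
import numpy as np

# Stand-in data (assumption): a random complex tensor with the question's shape.
rng = np.random.default_rng(0)
vector = rng.standard_normal((2, 2, 2, 2, 2)) + 1j * rng.standard_normal((2, 2, 2, 2, 2))

# Sublist form from the question: integers label the axes.
res_sublist = np.einsum(vector, [0, 1, 2, 3, 4], np.conj(vector),
                        [0, 1, 2, 8, 9], [3, 4, 8, 9])

# Same contraction with letters: 0->a, 1->b, 2->c, 3->d, 4->e, 8->i, 9->j.
res_string = np.einsum("abcde,abcij->deij", vector, np.conj(vector))

print(res_sublist.shape)                      # (2, 2, 2, 2)
print(np.allclose(res_sublist, res_string))   # True
```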

Answer by Luis ALberto

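Looking at the numpy einsum source code, I found that this is a tensor contraction written in einsum's sublist format, which uses integers instead of letters to label the axes. So it is just a sum of products: the labels in the output list (3, 4, 8, 9) index the result, and the labels that appear only in the inputs (0, 1, 2) are summed over. I made a function that builds the loops and takes the same index lists as the einsum call. It generates the code as a string and runs it with Python's exec. The function does not return the indices, but they are available as the loop variables (x0, x1, ...) inside the generated code. Here is the code: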

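```python
import numpy as np

def myeinsum(sv, indexs1, indexs2, indexs3):
    """Loop-based equivalent of
    np.einsum(sv, indexs1, np.conj(sv), indexs2, indexs3)
    for tensors whose axes all have length 2."""
    svc = np.conj(sv)

    res_shape = [2] * len(indexs3)
    res = np.zeros(res_shape, dtype=np.complex128)

    # One loop variable per distinct label (label 3 -> "x3"): first the labels
    # of the first operand, then any new labels from the second operand.
    idxsloops = []
    for label in list(indexs1) + list(indexs2):
        name = "x" + str(label)
        if name not in idxsloops:
            idxsloops.append(name)

    # Build the nested for-loops, indenting one level per loop.
    strcode = ""
    for depth, name in enumerate(idxsloops):
        strcode += "    " * depth + "for " + name + " in range(2):\n"

    # Index strings such as "[x3][x4][x8][x9]" for the result and both operands.
    stridxres = "".join("[x" + str(i) + "]" for i in indexs3)
    strfactor1 = "".join("[x" + str(i) + "]" for i in indexs1)
    strfactor2 = "".join("[x" + str(i) + "]" for i in indexs2)

    # Innermost statement: accumulate the product into the result element.
    strcode += ("    " * len(idxsloops) + "res" + stridxres
                + " += sv" + strfactor1 + " * svc" + strfactor2 + "\n")

    # For indexs1=[0,1,2,3,4], indexs2=[0,1,2,8,9], indexs3=[3,4,8,9] this prints
    # seven nested loops over x0, x1, x2, x3, x4, x8, x9 ending in
    # res[x3][x4][x8][x9] += sv[x0][x1][x2][x3][x4] * svc[x0][x1][x2][x8][x9]
    print(strcode)

    # Run the generated code in an explicit namespace; res is updated in place.
    exec(strcode, {"sv": sv, "svc": svc, "res": res})

    return res
```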
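If you want the indices themselves and not only the result, the same loops can be written without exec. The sketch below is a swapped-in approach, not the function above: `einsum_with_indices` is a hypothetical helper that iterates over every assignment of values to the labels with `itertools.product` (equivalent to the nested loops) and records, for each result index, which (vector index, conj(vector) index) pairs are multiplied and summed into it. It assumes two operands, an explicit output list, and that a repeated label always has the same axis size.

```python
import itertools
import numpy as np

def einsum_with_indices(a, labels_a, b, labels_b, labels_out):
    """Hypothetical helper: two-operand einsum over integer labels.

    Returns the result and a dict mapping each output index tuple to the
    list of (index_in_a, index_in_b) pairs whose products are summed into it.
    Assumes every label in labels_out appears in labels_a or labels_b.
    """
    # Record each label once, in first-seen order, with its axis size.
    sizes = {}
    for labels, arr in ((labels_a, a), (labels_b, b)):
        for label, dim in zip(labels, arr.shape):
            sizes.setdefault(label, dim)
    all_labels = list(sizes)

    res = np.zeros(tuple(sizes[label] for label in labels_out),
                   dtype=np.result_type(a, b))
    contributions = {}

    # One iteration per combination of label values (the unrolled loop nest).
    for values in itertools.product(*(range(sizes[label]) for label in all_labels)):
        env = dict(zip(all_labels, values))
        idx_a = tuple(env[label] for label in labels_a)
        idx_b = tuple(env[label] for label in labels_b)
        idx_out = tuple(env[label] for label in labels_out)
        res[idx_out] += a[idx_a] * b[idx_b]
        contributions.setdefault(idx_out, []).append((idx_a, idx_b))
    return res, contributions

# Stand-in data (assumption): a random complex vector of shape [2, 2, 2, 2, 2].
rng = np.random.default_rng(2)
vector = rng.standard_normal((2,) * 5) + 1j * rng.standard_normal((2,) * 5)
res, contrib = einsum_with_indices(vector, [0, 1, 2, 3, 4], np.conj(vector),
                                   [0, 1, 2, 8, 9], [3, 4, 8, 9])
expected = np.einsum(vector, [0, 1, 2, 3, 4], np.conj(vector),
                     [0, 1, 2, 8, 9], [3, 4, 8, 9])
print(np.allclose(res, expected))        # True
print(len(contrib[(0, 1, 1, 0)]))        # 8
print(contrib[(0, 1, 1, 0)][:2])
# [((0, 0, 0, 0, 1), (0, 0, 0, 1, 0)), ((0, 0, 1, 0, 1), (0, 0, 1, 1, 0))]
```

For this call each of the 16 result elements is a sum of 2 × 2 × 2 = 8 products, one for each combination of the summed labels 0, 1, 2, while the result labels 3, 4, 8, 9 are shared between the "from" indices in vector and the "to" index in the result.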