Using einsum to multiply each row with every row of a 3x3x3 array


Could someone please help me figure out how to use np.einsum to produce the result of the code below? I have a (3,3,3) tensor and currently get this result with two nested for loops; the code is below. I am not familiar with np.einsum and would like to reproduce the same result with it. Ideally I would also like to sum each of the resulting rows to get nine values.


Result of the code below:
[1 1 1]
[2 2 2]
[1 1 1]
[2 2 2]
[4 4 4]
[2 2 2]
[1 1 1]
[2 2 2]
[1 1 1]
[1 1 1]

3
6
3
9
12
6
15
18
9
6
12
6
18
24
12
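```python
import numpy as np

# All (x, y) index pairs in row-major order: (0,0), (0,1), ..., (2,2)
bb = []
for x in range(3):
    for y in range(3):
        bb.append((x, y))

a = np.array([[[1,2,1],[3,4,2],[5,6,3]],
              [[1,2,1],[3,4,2],[5,6,3]],
              [[1,2,1],[3,4,2],[5,6,3]]])
# b is identical to a and is not used below
b = np.array([[[1,2,1],[3,4,2],[5,6,3]],
              [[1,2,1],[3,4,2],[5,6,3]],
              [[1,2,1],[3,4,2],[5,6,3]]])

for z in range(9):
    llAI = bb[z]
    aal = a[:, llAI[0], llAI[1]]      # vector along axis 0 at position bb[z]
    for f in range(9):
        mmAI = bb[f]
        aam = a[:, mmAI[0], mmAI[1]]  # vector along axis 0 at position bb[f]
        # dot product of the two vectors; 9 x 9 = 81 values in total
        print(np.sum(aal * aam))
```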
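A minimal sketch, assuming the goal is the same 81 numbers arranged as a 9×9 array rather than printed one per line. It is the loop above with `print` replaced by an assignment into `R` (an illustrative name, not from the question), so `R[z, f]` is the value printed for the pair `bb[z]`, `bb[f]`. Its first row and the start of its second row match the 15 printed sums shown above.

```python
import numpy as np

a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]],
              [[1, 2, 1], [3, 4, 2], [5, 6, 3]],
              [[1, 2, 1], [3, 4, 2], [5, 6, 3]]])

# Same (x, y) pairs as bb in the question, in row-major order
bb = [(x, y) for x in range(3) for y in range(3)]

# R is a hypothetical name: R[z, f] holds the value the question's loop prints
R = np.zeros((9, 9), dtype=a.dtype)
for z, (x1, y1) in enumerate(bb):
    aal = a[:, x1, y1]
    for f, (x2, y2) in enumerate(bb):
        aam = a[:, x2, y2]
        R[z, f] = np.sum(aal * aam)

print(R[0].tolist())      # [3, 6, 3, 9, 12, 6, 15, 18, 9]
print(R[1, :6].tolist())  # [6, 12, 6, 18, 24, 12]
```

Having the loop's output as one array gives a concrete target to compare any einsum expression against.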

Best answer:

It took a bit to figure out what you are doing.

Take the first three values of z: bb gives (0,0), (0,1), (0,2), so aal is successively a[:,0,0], a[:,0,1], a[:,0,2].

Or done all at once:

In [178]: aaL = a[:,0,:]; aaL
Out[178]: 
array([[1, 2, 1],
       [1, 2, 1],
       [1, 2, 1]])
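To check that `a[:,0,:]` really collects those three `aal` vectors at once: each column of `aaL` should be one of `a[:,0,0]`, `a[:,0,1]`, `a[:,0,2]`. A small sketch confirming this with `np.column_stack` (`cols` is an illustrative name):

```python
import numpy as np

a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]],
              [[1, 2, 1], [3, 4, 2], [5, 6, 3]],
              [[1, 2, 1], [3, 4, 2], [5, 6, 3]]])

aaL = a[:, 0, :]
# Stack the three vectors the loop visits first (z = 0, 1, 2) as columns
cols = np.column_stack([a[:, 0, 0], a[:, 0, 1], a[:, 0, 2]])
print(np.array_equal(aaL, cols))  # True
```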

aam steps through the same three vectors, so the sums of their products for these 3×3 pairs can be computed with matmul (@) or dot:

In [179]: aaL.T@aaL
Out[179]: 
array([[ 3,  6,  3],
       [ 6, 12,  6],
       [ 3,  6,  3]])

Or in einsum:

In [180]: np.einsum('ji,jk->ik',aaL,aaL)
Out[180]: 
array([[ 3,  6,  3],
       [ 6, 12,  6],
       [ 3,  6,  3]])
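As a cross-check, a minimal sketch that runs the question's double loop only over the first three pairs of `bb` and compares it with both forms above. In the subscript string `'ji,jk->ik'`, `j` appears in both inputs but not in the output, so einsum sums over it; `i` and `k` survive as the row and column of the result. The names `loop`, `via_matmul` and `via_einsum` are illustrative.

```python
import numpy as np

a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]],
              [[1, 2, 1], [3, 4, 2], [5, 6, 3]],
              [[1, 2, 1], [3, 4, 2], [5, 6, 3]]])
aaL = a[:, 0, :]

# Question's loop restricted to z, f in range(3), i.e. pairs (0,0), (0,1), (0,2)
loop = np.array([[np.sum(a[:, 0, z] * a[:, 0, f]) for f in range(3)]
                 for z in range(3)])

via_matmul = aaL.T @ aaL
via_einsum = np.einsum('ji,jk->ik', aaL, aaL)  # sum over the shared index j
print(np.array_equal(loop, via_matmul), np.array_equal(loop, via_einsum))  # True True
```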

Your indexing array:

In [183]: bb
Out[183]: [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]

In [185]: np.array(bb)[:3,:]
Out[185]: 
array([[0, 0],
       [0, 1],
       [0, 2]])
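The ordering of `bb` is row-major, so `bb[z]` is just `divmod(z, 3)`. That is what lets a reshape stand in for the index list: column `z` of `a.reshape(3, 9)` is exactly `a[:, bb[z][0], bb[z][1]]`. A small sketch checking both facts (`A`, `same_order` and `same_vectors` are illustrative names):

```python
import numpy as np

a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]],
              [[1, 2, 1], [3, 4, 2], [5, 6, 3]],
              [[1, 2, 1], [3, 4, 2], [5, 6, 3]]])
bb = [(x, y) for x in range(3) for y in range(3)]

# bb[z] == divmod(z, 3) for every z
same_order = all(bb[z] == divmod(z, 3) for z in range(9))

# Flatten the last two axes (C order): column z lines up with bb[z]
A = a.reshape(3, 9)
same_vectors = all(np.array_equal(A[:, z], a[:, x, y])
                   for z, (x, y) in enumerate(bb))
print(same_order, same_vectors)  # True True
```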

Generalizing to the other two groups of bb (first index 1 and 2), i.e. the pairs that share a first index:

In [192]: for i in range(3):
     ...:     aaL = a[:,i]
     ...:     print(aaL.T@aaL)
     ...:     
[[ 3  6  3]
 [ 6 12  6]
 [ 3  6  3]]
[[27 36 18]
 [36 48 24]
 [18 24 12]]
[[ 75  90  45]
 [ 90 108  54]
 [ 45  54  27]]

Adding that index as an extra dimension in the einsum gives all three blocks at once:

In [195]: np.einsum('jmi,jmk->mik', a,a)
Out[195]: 
array([[[  3,   6,   3],
        [  6,  12,   6],
        [  3,   6,   3]],

       [[ 27,  36,  18],
        [ 36,  48,  24],
        [ 18,  24,  12]],

       [[ 75,  90,  45],
        [ 90, 108,  54],
        [ 45,  54,  27]]])
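The `'jmi,jmk->mik'` result covers only pairs whose first index matches, i.e. the three 3×3 blocks on the diagonal of the question's 9×9 set of values (for example, z = 0 with f = 3 gives 9 in the question's output but is not in any of these blocks). A sketch of one way to get all 81 values, assuming the 9×9 layout `R[z, f]` with `z` and `f` indexing `bb`: give the second operand its own index letters so only `j` is shared and summed, then reshape. It also takes the nine row sums the question asks for, assuming "rows" means the rows of that 9×9 array. `R`, `R_loop`, `R_matmul`, `row_sums` and `row_sums_direct` are illustrative names.

```python
import numpy as np

a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]],
              [[1, 2, 1], [3, 4, 2], [5, 6, 3]],
              [[1, 2, 1], [3, 4, 2], [5, 6, 3]]])

# Loop reference: R_loop[z, f] is the value the question prints for pair (z, f)
bb = [(x, y) for x in range(3) for y in range(3)]
R_loop = np.array([[np.sum(a[:, x1, y1] * a[:, x2, y2]) for (x2, y2) in bb]
                   for (x1, y1) in bb])

# Only j (axis 0) is shared and summed; (m, i) locate the first vector, (n, k) the second
R = np.einsum('jmi,jnk->mink', a, a).reshape(9, 9)

# Equivalent matmul form: flatten the last two axes, then A.T @ A
A = a.reshape(3, 9)
R_matmul = A.T @ A

# Nine row sums of the 9x9 result
row_sums = R.sum(axis=1)
# Same sums without building R: sum_f A[j, f] first, then contract over j
row_sums_direct = np.einsum('jz,j->z', A, A.sum(axis=1))

print(np.array_equal(R, R_loop), np.array_equal(R, R_matmul))  # True True
print(row_sums.tolist())  # [81, 162, 81, 243, 324, 162, 405, 486, 243]
```

The diagonal 3×3 blocks of `R` are exactly the three matrices returned by `'jmi,jmk->mik'`; the off-diagonal blocks are the cross pairs that the per-block version leaves out.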