I have 2 tensors, A and B:
```python
import torch

A = torch.randn([32, 128, 64, 12], dtype=torch.float64)
B = torch.randn([64, 12, 64, 12], dtype=torch.float64)
C = torch.tensordot(A, B, ([2, 3], [0, 1]))  # shape: [32, 128, 64, 12]
D = C.permute(0, 2, 1, 3)                    # shape: [32, 64, 128, 12]
```
Tensor D comes from the sequence "tensordot -> permute". How can I implement a new operation f() so that a plain tensordot applied after f() gives D directly, like this:
```python
A_2 = f(A)
B_2 = f(B)
D = torch.tensordot(A_2, B_2)
```
Have you considered using `torch.einsum`, which is very flexible? The problem with `tensordot` is that it outputs all remaining (non-contracted) dimensions of `A` before those of `B`, whereas what you are looking for (when permuting) is to "interleave" dimensions from `A` and `B`.
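A minimal sketch of the `einsum` approach, using the shapes from the question. It does not provide the `f()` asked for; instead it replaces "tensordot + permute" with a single call. The index letters are arbitrary labels chosen for this example, and the comparison against the tensordot/permute result is only there to show that the two routes agree.

```python
import torch

# Same shapes as in the question
A = torch.randn([32, 128, 64, 12], dtype=torch.float64)
B = torch.randn([64, 12, 64, 12], dtype=torch.float64)

# Original route: tensordot puts A's free dims first, then B's; permute reorders them
C = torch.tensordot(A, B, ([2, 3], [0, 1]))  # shape [32, 128, 64, 12]
D_ref = C.permute(0, 2, 1, 3)                # shape [32, 64, 128, 12]

# einsum route: label A's dims a,b,c,d and B's dims c,d,e,f.
# The shared labels c,d are summed over; the output string "aebf"
# interleaves A's free dims (a, b) with B's free dims (e, f) directly.
D = torch.einsum("abcd,cdef->aebf", A, B)

print(D.shape)                    # torch.Size([32, 64, 128, 12])
print(torch.allclose(D, D_ref))
```

The output subscript string is what makes this work: any ordering of the free labels can be requested, so no separate permute step is needed.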