I want to calculate D[a,d] = sum over b and c of A[a,b,c] * B[b,c,d], i.e. contract the two indices b and c.
Method I: reshape A[a,b,c] into C1[a,e] and B[b,c,d] into C2[e,d], where the fused index e runs over nb*nc values, then compute D = C1 * C2 with a single dgemm call.
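Written out with explicit sums, Method I works because Fortran stores arrays in column-major order. Fusing (b, c) into e = b + (c-1) n_b turns A and B into matrices without moving any data, and the zero-based storage offsets show that the leading dimensions of those matrix views are n_a and n_e = n_b n_c:

```latex
D_{ad} = \sum_{b=1}^{n_b} \sum_{c=1}^{n_c} A_{abc} B_{bcd}
       = \sum_{e=1}^{n_e} (C_1)_{ae} (C_2)_{ed},
\qquad (C_1)_{ae} = A_{abc}, \quad (C_2)_{ed} = B_{bcd}, \quad e = b + (c-1)\,n_b, \quad n_e = n_b n_c

\operatorname{off}(A_{abc}) = (a-1) + (b-1)\,n_a + (c-1)\,n_a n_b = (a-1) + (e-1)\,n_a
\operatorname{off}(B_{bcd}) = (b-1) + (c-1)\,n_b + (d-1)\,n_b n_c = (e-1) + (d-1)\,n_e
```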
Method II: call dgemm directly on A and B, without reshaping. This fails at run time:
 na, nb, nc, nd ?
2 3 5 7
 Time for reshaping method    2.447600000000000E-002
Intel MKL ERROR: Parameter 10 was incorrect on entry to DGEMM .
 Time for straight method    1.838800000000000E-002
 Difference between result matrices    5.46978468774136
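In the DGEMM argument list (TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC), parameter 10 is LDB. With TRANSB = 'N', LDB must be at least K = ne = nb*nc, which is 15 for this input, but the direct call passes Size( b, Dim = 1 ) = nb = 3. Since the program carries on after the MKL message, d is presumably never written, so the reported difference compares e with uninitialised memory. The intrinsic-only sketch below (sizes fixed to the 2 3 5 7 input; program and variable names are illustrative) checks that Reshape merely relabels column-major storage: C2(e,d) is B(b,c,d) with e = b + (c-1)*nb, so b already is an ne x nd matrix in memory, with leading dimension ne rather than nb.

```fortran
! Sketch: verify that Reshape only relabels column-major storage and that
! the two-index contraction equals one matrix product of the reshaped arrays.
program check_fused_index
  use, intrinsic :: iso_fortran_env, only : wp => real64
  implicit none
  ! Illustrative sizes matching the "2 3 5 7" input
  integer, parameter :: na = 2, nb = 3, nc = 5, nd = 7, ne = nb * nc
  real(wp) :: a(na, nb, nc), b(nb, nc, nd)
  real(wp) :: c1(na, ne), c2(ne, nd), d_loop(na, nd), d_mat(na, nd)
  integer :: i, j, k, l, e

  call random_number(a)
  call random_number(b)
  c1 = reshape(a, shape(c1))
  c2 = reshape(b, shape(c2))

  ! The fused index is e = j + (k-1)*nb, exactly the column-major ordering
  do k = 1, nc
    do j = 1, nb
      e = j + (k - 1) * nb
      if (any(c1(:, e) /= a(:, j, k))) error stop 'c1 mapping mismatch'
      if (any(c2(e, :) /= b(j, k, :))) error stop 'c2 mapping mismatch'
    end do
  end do

  ! Reference contraction over both b and c with explicit loops
  d_loop = 0.0_wp
  do l = 1, nd
    do k = 1, nc
      do j = 1, nb
        do i = 1, na
          d_loop(i, l) = d_loop(i, l) + a(i, j, k) * b(j, k, l)
        end do
      end do
    end do
  end do

  ! Same contraction as a single matrix product of the reshaped arrays
  d_mat = matmul(c1, c2)
  if (maxval(abs(d_loop - d_mat)) > 1.0e-12_wp) error stop 'contraction mismatch'

  write(*, '(a)') 'fused index mapping and contraction agree'
end program check_fused_index
```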
Question: can BLAS contract two indices at once, without reshaping A and B first?
The following works only for one contracted index (see "How to speed up reshape in higher rank tensor contraction by BLAS in Fortran?"):
Program reshape_for_blas

  Use, Intrinsic :: iso_fortran_env, Only : wp => real64, li => int64

  Implicit None

  Real( wp ), Dimension( :, :, : ), Allocatable :: a
  Real( wp ), Dimension( :, :, : ), Allocatable :: b
  Real( wp ), Dimension( :, : ), Allocatable :: c1, c2
  Real( wp ), Dimension( :, : ), Allocatable :: d
  Real( wp ), Dimension( :, : ), Allocatable :: e

  Integer :: na, nb, nc, nd, ne

  Integer( li ) :: start, finish, rate

  Write( *, * ) 'na, nb, nc, nd ?'
  Read( *, * ) na, nb, nc, nd
  ne = nb * nc  ! extent of the fused index e = (b,c)

  Allocate( a ( 1:na, 1:nb, 1:nc ) )
  Allocate( b ( 1:nb, 1:nc, 1:nd ) )
  Allocate( c1( 1:na, 1:ne ) )
  Allocate( c2( 1:ne, 1:nd ) )
  Allocate( d ( 1:na, 1:nd ) )
  Allocate( e ( 1:na, 1:nd ) )

  ! Set up some data
  Call Random_number( a )
  Call Random_number( b )

  ! Method I: reshape to matrices, then one DGEMM
  Call System_clock( start, rate )
  c1 = Reshape( a, Shape( c1 ) )
  c2 = Reshape( b, Shape( c2 ) )
  Call dgemm( 'N', 'N', na, nd, ne, 1.0_wp, c1, Size( c1, Dim = 1 ), &
                                            c2, Size( c2, Dim = 1 ), &
                                    0.0_wp, e , Size( e , Dim = 1 ) )
  Call System_clock( finish, rate )
  Write( *, * ) 'Time for reshaping method ', Real( finish - start, wp ) / rate

  ! Method II: pass the rank-3 arrays straight to DGEMM.
  ! Here LDB = Size( b, Dim = 1 ) = nb, but with 'N' DGEMM needs LDB >= K = ne,
  ! which triggers the MKL "Parameter 10" (LDB) error.
  Call System_clock( start, rate )
  Call dgemm( 'N', 'N', na, nd, ne, 1.0_wp, a , Size( a , Dim = 1 ), &
                                            b , Size( b , Dim = 1 ), &
                                    0.0_wp, d , Size( d , Dim = 1 ) )
  Call System_clock( finish, rate )
  Write( *, * ) 'Time for straight method ', Real( finish - start, wp ) / rate

  Write( *, * ) 'Difference between result matrices ', Maxval( Abs( d - e ) )

End Program reshape_for_blas
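A minimal sketch of Method II with the leading dimensions taken from the column-major layout instead of from Size( ..., Dim = 1 ): LDA = na and LDB = nb*nc. It assumes a standard Fortran 77 style DGEMM (reference BLAS, OpenBLAS or MKL) is linked, for example with gfortran direct_dgemm.f90 -lblas; the program name and fixed sizes are illustrative. The result is checked against the Reshape plus Matmul version.

```fortran
! Sketch: contract indices b and c with one DGEMM call on the rank-3 arrays.
! Assumes a Fortran 77 style DGEMM is linked (e.g. gfortran file.f90 -lblas).
program direct_dgemm
  use, intrinsic :: iso_fortran_env, only : wp => real64
  implicit none
  external :: dgemm
  ! Illustrative sizes matching the "2 3 5 7" input
  integer, parameter :: na = 2, nb = 3, nc = 5, nd = 7, ne = nb * nc
  real(wp) :: a(na, nb, nc), b(nb, nc, nd)
  real(wp) :: d(na, nd), d_ref(na, nd)

  call random_number(a)
  call random_number(b)

  ! a is viewed as an na x ne matrix (LDA = na) and b as an ne x nd matrix
  ! (LDB = ne); column-major storage makes both views exact.
  call dgemm('N', 'N', na, nd, ne, 1.0_wp, a, na, b, ne, 0.0_wp, d, na)

  ! Reference result via explicit reshapes (Method I without BLAS)
  d_ref = matmul(reshape(a, [na, ne]), reshape(b, [ne, nd]))

  if (maxval(abs(d - d_ref)) > 1.0e-12_wp) error stop 'direct DGEMM result differs'
  write(*, '(a)') 'direct DGEMM matches reshape + matmul'
end program direct_dgemm
```

Each rank-3 array reaches DGEMM's implicit interface by sequence association, so no copy is made. This only works because the contracted indices are the trailing indices of A and the leading indices of B, as they are here.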