BLAS tensor contractions for two indexes together

I want to calculate D[a,d] = sum over b and c of A[a,b,c] * B[b,c,d], i.e. contract A and B over the two indices b and c at once.
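Written out as plain loops, the contraction gives a reference that either method can be checked against. This is a minimal sketch, assuming double precision (`real64`) as in the program below. The module and function names `contraction_ref` and `contract_loops` are hypothetical and not part of any library.

```fortran
module contraction_ref
  use, intrinsic :: iso_fortran_env, only: wp => real64
  implicit none
contains
  ! Hypothetical reference routine:
  ! D(a,d) = sum over b and c of A(a,b,c) * B(b,c,d), using explicit loops.
  ! Slow, but obviously correct; useful for validating faster versions.
  function contract_loops(a, b) result(d)
    real(wp), intent(in) :: a(:, :, :), b(:, :, :)
    real(wp) :: d(size(a, 1), size(b, 3))
    integer :: ia, ib, ic, id

    d = 0.0_wp
    ! Innermost loop runs over the first index, matching column-major storage.
    do id = 1, size(b, 3)
      do ic = 1, size(a, 3)
        do ib = 1, size(a, 2)
          do ia = 1, size(a, 1)
            d(ia, id) = d(ia, id) + a(ia, ib, ic) * b(ib, ic, id)
          end do
        end do
      end do
    end do
  end function contract_loops
end module contraction_ref
```

For example, `Maxval( Abs( e - contract_loops( a, b ) ) )` in the program below should be at roundoff level if the reshaping method is correct.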

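Method I: reshape A[a,b,c] => C1[a,e] and B[b,c,d] => C2[e,d], where the fused index e combines b and c (so it runs over nb*nc values), then multiply C1 and C2 with dgemm. This works.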
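Method I is valid because Fortran stores arrays in column-major order: A(a,b,c) occupies the same position in memory as C1(a,e), and B(b,c,d) the same position as C2(e,d), when the fused index is e = b + nb*(c-1). The double sum over b and c then becomes a single sum over e, which is an ordinary matrix product:

$$
D_{ad} = \sum_{b=1}^{n_b} \sum_{c=1}^{n_c} A_{abc}\, B_{bcd}
       = \sum_{e=1}^{n_b n_c} C^{(1)}_{ae}\, C^{(2)}_{ed},
\qquad e = b + n_b\,(c-1)
$$

Since the element order is unchanged, Reshape here only copies the data without reordering it.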

Method II: call dgemm directly on the 3-D arrays A and B, without reshaping. This fails at run time with the following output:

na, nb, nc, nd ?
2 3 5 7
Time for reshaping method 2.447600000000000E-002

Intel MKL ERROR: Parameter 10 was incorrect on entry to DGEMM .
Time for straight  method 1.838800000000000E-002
Difference between result matrices 5.46978468774136

Question: can two indices be contracted together with a single BLAS call?

The related question "How to speed up reshape in higher rank tensor contraction by BLAS in Fortran?" only covers contracting one index. The program below times both methods:

Program reshape_for_blas

  Use, Intrinsic :: iso_fortran_env, Only :  wp => real64, li => int64

  Implicit None

  Real( wp ), Dimension( :, :, : ), Allocatable :: a
  Real( wp ), Dimension( :, :, : ), Allocatable :: b
  Real( wp ), Dimension( :, :    ), Allocatable :: c1, c2
  Real( wp ), Dimension( :, :    ), Allocatable :: d
  Real( wp ), Dimension( :, :    ), Allocatable :: e

  Integer :: na, nb, nc, nd, ne
  
  Integer( li ) :: start, finish, rate

  Write( *, * ) 'na, nb, nc, nd ?'
  Read( *, * ) na, nb, nc, nd
  ne = nb * nc
  Allocate( a ( 1:na, 1:nb, 1:nc ) ) 
  Allocate( b ( 1:nb, 1:nc, 1:nd ) ) 
  Allocate( c1( 1:na, 1:ne ) ) 
  Allocate( c2( 1:ne, 1:nd ) ) 
  Allocate( d ( 1:na, 1:nd ) ) 
  Allocate( e ( 1:na, 1:nd ) ) 

  ! Set up some data
  Call Random_number( a )
  Call Random_number( b )

  ! Method I: copy a and b into 2-D arrays with Reshape, then one dgemm call
  Call System_clock( start, rate )
  c1 = Reshape( a, Shape( c1 ) )
  c2 = Reshape( b, Shape( c2 ) )
  Call dgemm( 'N', 'N', na, nd, ne, 1.0_wp, c1, Size( c1, Dim = 1 ), &
                                            c2, Size( c2, Dim = 1 ), &
                                    0.0_wp, e, Size( e, Dim = 1 ) )
  Call System_clock( finish, rate )
  Write( *, * ) 'Time for reshaping method ', Real( finish - start, wp ) / rate
  
  ! Method II: pass the 3-D arrays straight to dgemm (this is the call MKL rejects)
  Call System_clock( start, rate )
  Call dgemm( 'N', 'N', na, nd, ne, 1.0_wp, a , Size( a , Dim = 1 ), &
                                            b , Size( b , Dim = 1 ), &
                                    0.0_wp, d , Size( d , Dim = 1 ) )
  Call System_clock( finish, rate )
  Write( *, * ) 'Time for straight  method ', Real( finish - start, wp ) / rate

  Write( *, * ) 'Difference between result matrices ', Maxval( Abs( d - e ) )

End Program reshape_for_blas
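Parameter 10 of DGEMM is LDB, the leading dimension of B, which must be at least K (here `ne = nb * nc`) when B is not transposed. In the direct call, `Size( b , Dim = 1 )` is only `nb`, so the library rejects the call before computing anything; `d` is never written, and the reported difference compares `e` with uninitialized memory. Because `b(nb,nc,nd)` has the same column-major layout as an `ne x nd` matrix (and `a(na,nb,nc)` the same as an `na x ne` matrix, for which LDA = `na` is already valid), passing `ne` as LDB should let a single DGEMM call contract both indices with no reshape. A minimal sketch, assuming a BLAS library such as MKL or reference BLAS is linked with the default 32-bit integer interface; the module and subroutine names are hypothetical.

```fortran
module direct_contraction
  use, intrinsic :: iso_fortran_env, only: wp => real64
  implicit none
contains
  ! Hypothetical helper: D(a,d) = sum over b, c of A(a,b,c) * B(b,c,d)
  ! in one DGEMM call. The 3-D arrays are passed as they are; DGEMM sees
  ! them as na x ne and ne x nd matrices (sequence association).
  subroutine contract_bc(a, b, d)
    real(wp), contiguous, intent(in)  :: a(:, :, :)
    real(wp), contiguous, intent(in)  :: b(:, :, :)
    real(wp), contiguous, intent(out) :: d(:, :)
    external :: dgemm                 ! provided by the linked BLAS library
    integer :: na, ne, nd

    na = size(a, 1)
    ne = size(a, 2) * size(a, 3)      ! fused index e = b + nb*(c-1)
    nd = size(b, 3)
    if (size(b, 1) * size(b, 2) /= ne) error stop 'contracted extents of a and b differ'
    if (size(d, 1) /= na .or. size(d, 2) /= nd) error stop 'd must be size(a,1) x size(b,3)'

    ! LDA = na, LDB = ne (not size(b, 1)), LDC = na
    call dgemm('N', 'N', na, nd, ne, 1.0_wp, a, na, b, ne, 0.0_wp, d, na)
  end subroutine contract_bc
end module direct_contraction
```

In the question's program the corresponding change would be to pass `ne` instead of `Size( b , Dim = 1 )` as the tenth argument of the direct call. With gfortran 10 or later, a single file that calls `dgemm` with both 2-D and 3-D actual arguments, as the question's program does, may be rejected as an argument mismatch unless compiled with `-fallow-argument-mismatch`.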