I'm doing 2D FFTs of 2D arrays of complex numbers using pyFFTW. These arrays can get very large (~128 GiB), so execution time is crucial. (The background is wavefront propagation in optical physics.)
Have a look at the following toy code:
import numpy as np
import pyfftw
import multiprocessing
a = np.random.rand(16384, 16384) + 1j*np.random.rand(16384, 16384)
fft = pyfftw.FFTW(a, a, axes = (0, 1), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
a = fft()
The FFT takes a few seconds to execute on my modern 64-bit machine.
Both the result and the execution time remain the same when doing the 2D FFT in two steps (1D-FFTs of all columns and of all rows):
fft = pyfftw.FFTW(a, a, axes = (0,), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
a = fft()
fft = pyfftw.FFTW(a, a, axes = (1,), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
a = fft()
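The equivalence of the two approaches can be checked on a small array. The following is only a sketch: it uses `numpy.fft` instead of pyFFTW and an assumed 64 x 64 stand-in array, so it confirms the result, not the timing.

```python
import numpy as np

# Small stand-in array (assumed size); the 16384 x 16384 case follows the same logic.
rng = np.random.default_rng(0)
a = rng.random((64, 64)) + 1j * rng.random((64, 64))

full = np.fft.fft2(a)                # one 2D transform over both axes
cols = np.fft.fft(a, axis=0)         # 1D FFTs along axis 0 (all columns)
two_step = np.fft.fft(cols, axis=1)  # then 1D FFTs along axis 1 (all rows)

# The two results should agree up to floating-point rounding.
same_result = np.allclose(full, two_step)
print(same_result)
```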
However, timing these steps individually shows that the column-FFT is roughly 10 times slower than the row-FFT.
The reason, I guess, is that the array is stored in memory row by row (row-major order). Indeed, a.flags gives
C_CONTIGUOUS : True
F_CONTIGUOUS : False
OWNDATA : True
WRITEABLE : True
ALIGNED : True
UPDATEIFCOPY : False
while a.strides gives
(262144, 16)
So, the array is C-contiguous and seems to be correctly aligned. However, removing the flag 'FFTW_UNALIGNED' makes the column-FFT roughly another 10 times slower (while the row-FFT becomes slightly faster).
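The stride values and the base-address alignment can be inspected directly. This is a minimal sketch using NumPy only; the small 4 x 8 array is a hypothetical stand-in, and the remark about the ALIGNED flag reflects my understanding that it describes alignment for the element type, not necessarily the wider alignment that SIMD code prefers.

```python
import numpy as np

n = 16384
itemsize = np.dtype(np.complex128).itemsize  # 16 bytes per complex number
# For a C-ordered n x n array: one row step = n * 16 bytes, one column step = 16 bytes.
expected_strides = (n * itemsize, itemsize)  # matches the (262144, 16) reported above

# A small array shows the same pattern without allocating 4 GiB.
small = np.zeros((4, 8), dtype=np.complex128)
small_strides = small.strides                # (8 * 16, 16)

# Assumption: ALIGNED = True only says the elements are aligned for their dtype.
# The offset of the base address from a 16- or 32-byte boundary is a separate question.
address = small.ctypes.data
offset_16 = address % 16
offset_32 = address % 32
print(expected_strides, small_strides, offset_16, offset_32)
```

A nonzero offset would mean that the buffer does not start on that boundary, even though ALIGNED is reported as True.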
Hence, my question is:
Might there be something wrong with the alignment, or is a 10 times slower access to columns than to rows the physical limit for C-contiguous arrays?
EDIT: Indeed, a factor of 10 seems to be too large. Let's compare simple read/write access of rows and columns:
a[:,0:16384:2]*=1j
and
a[0:16384:2,:]*=1j
Multiplying the columns with an even index (first variant) is about 2 times slower than multiplying the rows with an even index (second variant).
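This comparison could be timed roughly as follows. It is a sketch under assumptions: the array size is reduced to 1024 so that it runs quickly, and the helper names `scale_even_columns` and `scale_even_rows` are made up for illustration. The measured ratio will depend on the machine and on the array size.

```python
import timeit
import numpy as np

n = 1024  # assumed smaller size; the array in question uses 16384
a = np.random.rand(n, n) + 1j * np.random.rand(n, n)

def scale_even_columns():
    # Strided access: touches every other element within each row.
    a[:, 0:n:2] *= 1j

def scale_even_rows():
    # Contiguous access: touches every other row in full.
    a[0:n:2, :] *= 1j

# Take the best of several repeats to reduce noise from other processes.
t_cols = min(timeit.repeat(scale_even_columns, number=5, repeat=3))
t_rows = min(timeit.repeat(scale_even_rows, number=5, repeat=3))
print(f"even columns: {t_cols:.4f} s, even rows: {t_rows:.4f} s, ratio: {t_cols / t_rows:.2f}")
```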
EDIT: The exact code entered in ipython is
In [1]: import pyfftw
In [2]: import multiprocessing
In [3]: a = np.random.rand(16384, 16384) + 1j*np.random.rand(16384, 16384)
In [4]: fft = pyfftw.FFTW(a, a, axes = (0, 1), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
In [5]: %timeit a = fft()