Matching Torch STFT with Accelerate


I'm trying to re-implement Torch's STFT code in Swift with Accelerate/vDSP to produce a log Mel spectrogram by post-processing the STFT, so I can use the Mel spectrogram as input for a CoreML port of OpenAI's Whisper.

PyTorch's native STFT / Mel code produces this spectrogram (it's clipped because I imported raw float32s into Photoshop):

[image: PyTorch's log Mel spectrogram]

and mine:

[image: my log Mel spectrogram]

Obviously the two things to notice are the differing values and the lifted frequency components.

The STFT docs (https://pytorch.org/docs/stable/generated/torch.stft.html) define each element of the output as:

    X[ω, m] = Σ_{k = 0}^{win_length − 1} window[k] · input[m × hop_length + k] · exp(−j · 2π · ω · k / win_length)

I believe I'm properly handling window[k] · input[m × hop_length + k], but I'm a bit lost as to how to calculate the exponent, what −j refers to in the documentation, and how to express that final complex exponential in vDSP. Also, if it's a sum, how do I get the 200 elements I need!?
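From reading around, I think −j is the imaginary unit, and exp(−j·θ) expands via Euler's formula into cos(θ) − j·sin(θ), so each output bin ω would get a real part (the cosine sum) and an imaginary part (the negated sine sum). A literal, deliberately slow transcription of the formula for one already-windowed frame would then look something like this (referenceDFT is just my own scratch function, not anything from vDSP — please correct me if I've misread the docs):

    import Foundation

    // Naive O(n^2) evaluation of the torch.stft formula for a single frame m.
    // `frame` is assumed to already be window[k] * input[m * hopLength + k].
    // exp(-j·θ) = cos(θ) - j·sin(θ) (Euler), with j the imaginary unit.
    func referenceDFT(frame: [Float]) -> (real: [Float], imaginary: [Float]) {
        let n = frame.count                      // win_length (== numFFT here)
        let binCount = n / 2 + 1                 // torch's onesided output size
        var real = [Float](repeating: 0, count: binCount)
        var imaginary = [Float](repeating: 0, count: binCount)
        for w in 0 ..< binCount {                // ω - the frequency bin index
            for k in 0 ..< n {
                let theta = 2.0 * Float.pi * Float(w * k) / Float(n)
                real[w] += frame[k] * cos(theta)        // real part of exp(-jθ)
                imaginary[w] -= frame[k] * sin(theta)   // imaginary part
            }
        }
        return (real, imaginary)
    }

If that's right, the FFT should already be computing all of these sums for every ω at once, so presumably I shouldn't be evaluating the exponent myself at all, and the real problem is matching vDSP's packing and scaling to torch's output?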


My code follows:

  func processData(audio: [Int16]) -> [Float]
    {
        assert(self.sampleCount == audio.count)
            
        var audioFloat:[Float] = [Float](repeating: 0, count: audio.count)
                
        vDSP.convertElements(of: audio, to: &audioFloat)
        
        vDSP.divide(audioFloat, 32768.0, result: &audioFloat)
        
        // Up to this point, Python and swift are numerically identical
             
        // insert numFFT/2 samples before and numFFT/2 after so we have an extra numFFT amount to process
        // TODO: Is this strictly necessary?
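        // (From the torch docs: stft with center=True - the default, which Whisper
        // uses - pads the input by n_fft/2 on each side, so this mirrors that. But
        // torch's default pad_mode is "reflect", not zeros, so a plain zero pad
        // may differ slightly at the edges.)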
        audioFloat.insert(contentsOf: [Float](repeating: 0, count: self.numFFT/2), at: 0)
        audioFloat.append(contentsOf: [Float](repeating: 0, count: self.numFFT/2))

        // Split Complex arrays holding the FFT results
        var allSampleReal = [[Float]](repeating: [Float](repeating: 0, count: self.numFFT/2), count: self.melSampleCount)
        var allSampleImaginary = [[Float]](repeating: [Float](repeating: 0, count: self.numFFT/2), count: self.melSampleCount)

        
        // Step 2 - we need to create a 200 x 3000 matrix of STFT results - torch.stft outputs complex numbers, so we keep separate real and imaginary parts
        for (m) in 0 ..< self.melSampleCount
        {
            // Slice numFFT samples every hop count (barf) and FFT each slice
            // audioFrame holds plain real samples here; below it gets reinterpreted
            // as interleaved complex pairs for vDSP's packed real FFT
            var audioFrame = Array<Float>( audioFloat[ (m * self.hopCount) ..< ( (m * self.hopCount) + self.numFFT) ] )
            
            // Copy of audioFrame original samples
            let audioFrameOriginal = audioFrame
            
            assert(audioFrame.count == self.numFFT)
            
            // Split Complex arrays holding a single FFT result of our audio frame, which get copied into the allSample arrays
            var sampleReal:[Float] = [Float](repeating: 0, count: self.numFFT/2)
            var sampleImaginary:[Float] = [Float](repeating: 0, count: self.numFFT/2)

            sampleReal.withUnsafeMutableBytes { unsafeReal in
                sampleImaginary.withUnsafeMutableBytes { unsafeImaginary in
                    
                    vDSP.multiply(audioFrame,
                                  hanningWindow,
                                  result: &audioFrame)

                    var complexSignal = DSPSplitComplex(realp: unsafeReal.bindMemory(to: Float.self).baseAddress!,
                                                        imagp: unsafeImaginary.bindMemory(to: Float.self).baseAddress!)
                           
                    audioFrame.withUnsafeBytes { unsafeAudioBytes in
                        vDSP.convert(interleavedComplexVector: [DSPComplex](unsafeAudioBytes.bindMemory(to: DSPComplex.self)),
                                     toSplitComplexVector: &complexSignal)
                    }
                    
                    // Step 3 - creating the FFT
                    self.fft.forward(input: complexSignal, output: &complexSignal)                    
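                    // NB (my assumption): vDSP's packed real FFT (vDSP_fft_zrip
                    // under the hood) returns values scaled by 2 relative to the
                    // textbook DFT, so the results likely need dividing by 2 to
                    // match torch.stft numerically.
                    // Also, vDSP packs DC into realp[0] and Nyquist into imagp[0],
                    // whereas torch returns numFFT/2 + 1 separate onesided bins.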
                }
            }

            // We need to match: https://pytorch.org/docs/stable/generated/torch.stft.html
            // At this point, I'm unsure how to continue?
            
//            let twoπ = Float.pi * 2
//            let freqstep:Float = Float(16000 / (self.numFFT/2))
//
//            var w:Float = 0.0
//            for (k) in 0 ..< self.numFFT/2
//            {
//                let j:Float = sampleImaginary[k]
//                let sample = audioFrame[k]
//
//                let exponent = -j * ( (twoπ * freqstep * Float(k) ) / Float((self.numFFT/2)))
//
//                w += powf(sample, exponent)
//            }
            
            
            allSampleReal[m] = sampleReal
            allSampleImaginary[m] = sampleImaginary
        }
        
        // We now have allSample split complex arrays holding 3000 frames of 200-dimensional real and imaginary FFT results
        
        // We create flattened 3000 x 200 arrays of the real and imaginary parts
        var flattenedReal:[Float] = allSampleReal.flatMap { $0 }
        var flattenedImaginary:[Float] = allSampleImaginary.flatMap { $0 }

        // ... this is as far as I've gotten
    }
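
For completeness, this is the post-processing I plan to bolt on once the STFT matches, based on my reading of Whisper's audio.py (power spectrum → Mel filterbank → log10 → clamp). The melFilters matrix here is hypothetical: I intend to load Whisper's 80 x 201 filterbank from its mel_filters.npz asset and flatten it row-major. Note Whisper also drops the final STFT frame (stft[..., :-1]) before this step, which I've left out:

    import Accelerate

    // Sketch of Whisper-style log-Mel post-processing, under the assumptions above.
    // real/imaginary are the flattened frameCount x binCount STFT results;
    // melFilters is melCount x binCount, row-major (melCount is 80 for Whisper).
    func logMelSpectrogram(real: [Float], imaginary: [Float],
                           melFilters: [Float], melCount: Int,
                           frameCount: Int, binCount: Int) -> [Float]
    {
        // Power spectrum per element: |X|^2 = re^2 + im^2
        var power = vDSP.square(real)
        vDSP.add(power, vDSP.square(imaginary), result: &power)

        // Transpose (frameCount x binCount) -> (binCount x frameCount),
        // since Whisper computes filters @ magnitudes
        var powerT = [Float](repeating: 0, count: power.count)
        vDSP_mtrans(power, 1, &powerT, 1, vDSP_Length(binCount), vDSP_Length(frameCount))

        // mel (melCount x frameCount) = melFilters (melCount x binCount) * powerT
        var mel = [Float](repeating: 0, count: melCount * frameCount)
        vDSP_mmul(melFilters, 1, powerT, 1, &mel, 1,
                  vDSP_Length(melCount), vDSP_Length(frameCount), vDSP_Length(binCount))

        // log10 with a 1e-10 floor, then Whisper's clamp to (max - 8) and (x + 4) / 4
        var logMel = vForce.log10(vDSP.clip(mel, to: 1e-10 ... Float.greatestFiniteMagnitude))
        let ceiling = vDSP.maximum(logMel)
        logMel = vDSP.clip(logMel, to: (ceiling - 8.0) ... Float.greatestFiniteMagnitude)
        return vDSP.divide(vDSP.add(4.0, logMel), 4.0)
    }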
