TypeError: max() received an invalid combination of arguments when trying to use beam search decoding

I'm trying to run a simple example of decoding wav2vec2 output with beam search (without a language model):

from pyctcdecode       import build_ctcdecoder
from transformers      import Wav2Vec2ForCTC, Wav2Vec2Processor
from torchaudio.utils  import download_asset

import torch
import librosa

processor        = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model            = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

FILE_NAME        = "tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SPEECH_FILE      = download_asset(FILE_NAME)

speech, sr       = librosa.load(SPEECH_FILE, sr=16000)
input_values     = processor(speech, sampling_rate=16000, return_tensors="pt").input_values

logits           = model(input_values).logits
vocabulary       = list(processor.tokenizer.get_vocab().keys())
log_probs        = torch.nn.functional.log_softmax(logits[0], dim=-1)

decoder          = build_ctcdecoder(vocabulary)
text             = decoder.decode(log_probs)

I'm getting the error:

TypeError: max() received an invalid combination of arguments - got (keepdims=bool, out=NoneType, axis=int, ), but expected one of:
 * ()
 * (Tensor other)
 * (int dim, bool keepdim)
 * (name dim, bool keepdim)

As you can see, I'm using pyctcdecode.

How can I decode the output of the wav2vec2 model with a beam search algorithm?

Best answer, by Claudio P:

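According to this issue on GitHub, this error should be resolved by passing the decoder a NumPy array instead of a PyTorch tensor. pyctcdecode runs NumPy reductions on its input. When np.max receives an object that is not an ndarray, it forwards the call to that object's own max(axis=..., out=..., keepdims=...) method, and torch.Tensor.max does not accept those keyword names. That is exactly the argument list shown in the error. Convert the log-probabilities right before decoding. The .detach() is needed because the model output still requires grad, and calling .numpy() on it directly raises a RuntimeError:

log_probs = torch.nn.functional.log_softmax(logits[0], dim=-1).detach().numpy()
text = decoder.decode(log_probs)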
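To see what a beam search over these log-probabilities involves, here is a minimal CTC prefix beam search without a language model, in plain Python. It is a sketch for intuition, not pyctcdecode's implementation. It assumes the input is a list of frames, each a list of log-probabilities over the vocabulary, and that the CTC blank sits at index 0. The function names and the blank_index parameter are hypothetical names introduced for this sketch, so check which id your tokenizer uses for <pad>.

```python
import math
from collections import defaultdict

NEG_INF = float("-inf")


def log_add(*values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    if top == NEG_INF:
        return NEG_INF
    return top + math.log(sum(math.exp(v - top) for v in values))


def ctc_prefix_beam_search(log_probs, beam_width=10, blank_index=0):
    """Return the most probable collapsed token-id sequence (no language model).

    log_probs: list of frames, each a list of log-probabilities per token.
    blank_index: assumed position of the CTC blank in the vocabulary.
    """
    # Each prefix (tuple of token ids) carries two scores:
    # (log P(prefix, path ends in blank), log P(prefix, path ends in non-blank)).
    beams = {(): (0.0, NEG_INF)}
    for frame in log_probs:
        candidates = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (p_blank, p_token) in beams.items():
            last = prefix[-1] if prefix else None
            for token, lp in enumerate(frame):
                if token == blank_index:
                    # Emitting blank leaves the prefix unchanged.
                    b, nb = candidates[prefix]
                    candidates[prefix] = (log_add(b, p_blank + lp, p_token + lp), nb)
                    continue
                extended = prefix + (token,)
                b, nb = candidates[extended]
                if token == last:
                    # A repeated token only starts a new symbol after a blank...
                    candidates[extended] = (b, log_add(nb, p_blank + lp))
                    # ...otherwise it merges into the symbol already at the end.
                    b, nb = candidates[prefix]
                    candidates[prefix] = (b, log_add(nb, p_token + lp))
                else:
                    candidates[extended] = (b, log_add(nb, p_blank + lp, p_token + lp))
        # Prune to the beam_width prefixes with the highest total probability.
        ranked = sorted(candidates.items(), key=lambda kv: log_add(*kv[1]), reverse=True)
        beams = dict(ranked[:beam_width])
    best_prefix, _ = max(beams.items(), key=lambda kv: log_add(*kv[1]))
    return list(best_prefix)
```

With the question's pipeline, something like ids = ctc_prefix_beam_search(log_probs.tolist()) followed by "".join(processor.tokenizer.convert_ids_to_tokens(ids)).replace("|", " ") should produce a transcript, assuming the vocabulary uses "|" as its word delimiter, as facebook/wav2vec2-base-960h does. The search loops in pure Python, so it is much slower than pyctcdecode and is only meant to show the mechanics.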