There seems to be a subtle difference between the JAX and PyTorch implementations of BART in the Hugging Face Transformers library: with JAX, the BART decoder can be initialized as a non-causal decoder, but with PyTorch it cannot.
How can I initialize the BART decoder as a non-causal decoder in PyTorch?
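For context, the causal/non-causal distinction comes down to the mask applied in the decoder's self-attention. The snippet below is a plain PyTorch illustration of that mask, not the Transformers BART code and not a fix for `BartDecoder`. It assumes PyTorch 2.0 or later for `torch.nn.functional.scaled_dot_product_attention`. The `self_attention` helper is hypothetical and uses identity projections (q = k = v = x) to keep the example small.

```python
import torch
import torch.nn.functional as F


def self_attention(x: torch.Tensor, causal: bool) -> torch.Tensor:
    """Hypothetical single-head self-attention with identity projections.

    x has shape (batch, seq_len, dim). With causal=True, position i may only
    attend to positions <= i (decoder-style masking). With causal=False,
    every position attends to the whole sequence (no look-ahead mask).
    """
    return F.scaled_dot_product_attention(x, x, x, is_causal=causal)


torch.manual_seed(0)
x = torch.randn(1, 4, 8)  # (batch=1, seq_len=4, dim=8)

causal_out = self_attention(x, causal=True)
full_out = self_attention(x, causal=False)

# Under a causal mask the first token sees only itself, so its output is
# exactly its own value vector.
print(torch.allclose(causal_out[0, 0], x[0, 0]))  # True

# Without the mask the first token also mixes in later tokens.
print(torch.allclose(full_out[0, 0], x[0, 0]))

# The last token sees the full sequence in both modes, so its outputs match.
print(torch.allclose(causal_out[0, -1], full_out[0, -1], atol=1e-6))
```

Only the earlier positions differ between the two modes. A non-causal decoder is therefore the same computation with the look-ahead mask dropped from self-attention, while the cross-attention to the encoder output is unaffected.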