Spiral Loss Function for Music Encoding


I am trying to develop an autoencoder for music generation; in pursuit of that end, I am attempting to develop a loss function which captures musical relationships.

My current idea is a 'spiral' loss function: if the system predicts the correct note in a different octave, the loss should be smaller than if the note is simply wrong. Additionally, notes close to the correct one, such as B or D when the target is C, should also incur small losses. Conceptually, one can think of this as measuring the distance between two points on a coil or spiral, arranged so that the same note in different octaves lies along a line parallel to the coil's axis, separated by some loop distance.
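To make the geometry concrete, here is a rough sketch of that helix (the helper names and the exact parametrization are my own illustration, reusing the same r and d constants as the loss below): the pitch class sets the angle around the coil, and the absolute pitch sets the height, so octave-equivalent notes stack vertically above one another.

```python
import math

def spiral_coords(note, r=10.0, d=5.0):
    # Place a MIDI note on a helix: pitch class -> angle, pitch -> height,
    # so the same pitch class in different octaves stacks vertically.
    theta = 2.0 * math.pi * (note % 12) / 12.0
    return (r * math.cos(theta), r * math.sin(theta), d * (note / 12.0))

def spiral_distance(a, b):
    # Straight-line (chord) distance between two notes on the helix.
    xa, ya, za = spiral_coords(a)
    xb, yb, zb = spiral_coords(b)
    return math.sqrt((xa - xb) ** 2 + (ya - yb) ** 2 + (za - zb) ** 2)
```

With these constants, an octave error (e.g. MIDI 60 vs. 72) costs exactly d = 5, while a tritone error in the same octave costs roughly 2r, which is the kind of ordering I am after.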

I am working in PyTorch, and my input representation is a 36×36 tensor, where the rows represent notes (MIDI range 48:84, the middle three octaves of a piano) and the columns represent time steps (1 column = 1/100th of a second). The values in the matrix are either 0 or 1, indicating whether a note is on at a given time step.
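For reference, a piano roll of this shape can be built roughly like this (the helper name and the (pitch, start_step, end_step) note format here are just illustrative, not part of my actual pipeline):

```python
import torch

def to_piano_roll(notes, n_steps=36, low=48, high=84):
    # notes: iterable of (midi_pitch, start_step, end_step) triples.
    # Rows are pitches from `low` to `high` (exclusive), columns are time steps.
    roll = torch.zeros(high - low, n_steps)
    for pitch, start, end in notes:
        if low <= pitch < high:
            roll[pitch - low, start:end] = 1.0
    return roll

# Middle C for 10 steps, then E for 10 steps.
roll = to_piano_roll([(60, 0, 10), (64, 10, 20)])
```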

Here is my current implementation of the loss:

def SpiralLoss():
    def spiral_loss(input, output):
        loss = torch.zeros(1)
        d = 5
        r = 10
        for i in range(input.size(0)):
            for j in range(input.size(3)):
                # max along dim 1 because each slice is a column vector of notes
                inval, inind = torch.max(input[i, :, :, j], 1)
                outval, outind = torch.max(output[i, :, :, j], 1)
                note_loss = (r * 30 * (inind % 12 - outind % 12)).float()
                octave_loss = (d * (inind // 12 - outind // 12)).float()
                loss += torch.sqrt(note_loss ** 2 + octave_loss ** 2)
        return loss
    return spiral_loss

The problem with this loss is that the index returned by max (the argmax) is not differentiable. I cannot think of a way to make this loss differentiable, and was wondering if anyone might have any ideas or suggestions?

I'm not sure if this is the right place for a post like this, and so if it isn't, I would really appreciate any pointers towards a better location.

1 Answer

Taking the maximum here is problematic not only because of differentiability: if you take only the maximum of the output, and it happens to be in the right place, slightly lower values in wrong positions go unpunished entirely.

One rough idea would be to use a standard L1 or L2 loss on the difference between the input and a modified output: the output could be multiplied by a weight mask that penalizes octave and note differences differently, for example:

def create_mask(input_column):
    r = 10
    d = 5
    mask = torch.zeros(input_column.size())
    _, max_ind = torch.max(input_column, 0)
    max_ind = int(max_ind.item())
    for i in range(mask.size(0)):
        # penalize pitch-class distance (weight r) and octave distance (weight d)
        mask[i] = r * (abs(i - max_ind) % 12) + d * (abs(i - max_ind) // 12)
    return mask

This is only roughly written, not something ready to use, but in principle it should do the job. The mask should be created with requires_grad=False (or detached), since it is an exact constant computed from each input. That way you take the maximum over the input only, and never over the output.
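To show how this could plug into a loss, here is a minimal sketch (masked_column_loss is my own name, and create_mask is repeated so the example runs on its own): the mask is detached, so gradients flow only through the squared error on the output.

```python
import torch

def create_mask(input_column):
    # Same mask as above, repeated for a self-contained example.
    r, d = 10, 5
    mask = torch.zeros(input_column.size())
    max_ind = int(torch.max(input_column, 0)[1].item())
    for i in range(mask.size(0)):
        mask[i] = r * (abs(i - max_ind) % 12) + d * (abs(i - max_ind) // 12)
    return mask

def masked_column_loss(input_col, output_col):
    # The mask is a constant built from the target column; detaching it
    # keeps the loss differentiable with respect to the network output.
    mask = create_mask(input_col).detach()
    return torch.sum(mask * (input_col - output_col) ** 2)

target = torch.zeros(36); target[12] = 1.0            # note at row 12 (MIDI 60)
octave_up = torch.zeros(36); octave_up[24] = 1.0      # same pitch class, +1 octave
semitone_off = torch.zeros(36); semitone_off[13] = 1.0
```

With these weights the octave error is penalized with weight d = 5 and the semitone error with weight r = 10, so the octave mistake is the cheaper of the two, as intended.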

I hope it helps!