How to implement custom loss function of ODEs with static covariates in Darts?

I am using Darts to predict the Lorenz system, a 3D multivariate chaotic time series generated by 3 ordinary differential equations. I've created a TFT model trained on synthetic data; each simulated dataset is put into its own series:

    from darts import TimeSeries

    train_series = TimeSeries.from_group_dataframe(training_df,
                                                   group_cols='simNum',
                                                   time_col='time_idx',
                                                   value_cols=['x', 'y', 'z'],
                                                   static_cols=['sigma', 'rho', 'beta'])

where sigma, rho, and beta are constant static covariates, bounded and randomized for each group (this particular attempt has 200 simulations). All columns of the input dataframe are floats and are scaled with a Scaler, except for time_idx, an integer time index that resets for each group.
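To reproduce a setup like this, here is a minimal sketch of how a `training_df` with that layout could be generated. It is an assumption, not the generator used in the question: trajectories are integrated with a classical RK4 step, and the parameter bounds (sigma in [8, 12], rho in [24, 32], beta in [2, 3]), the initial-state range, the step `dt`, and the helper names `lorenz_rhs`, `rk4_step` and `make_training_df` are all hypothetical.

```python
import numpy as np
import pandas as pd


def lorenz_rhs(state, sigma, rho, beta):
    """Right-hand side of the Lorenz equations for one (x, y, z) state."""
    x, y, z = state
    return np.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])


def rk4_step(state, dt, sigma, rho, beta):
    """Advance one state by dt with the classical fourth-order Runge-Kutta scheme."""
    k1 = lorenz_rhs(state, sigma, rho, beta)
    k2 = lorenz_rhs(state + 0.5 * dt * k1, sigma, rho, beta)
    k3 = lorenz_rhs(state + 0.5 * dt * k2, sigma, rho, beta)
    k4 = lorenz_rhs(state + dt * k3, sigma, rho, beta)
    return state + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def make_training_df(n_sims=200, n_steps=500, dt=0.01, seed=0):
    """Long-format frame with one row per (simNum, time_idx)."""
    rng = np.random.default_rng(seed)
    frames = []
    for sim in range(n_sims):
        # Assumed bounds around the classic (10, 28, 8/3) parameter set
        sigma = rng.uniform(8.0, 12.0)
        rho = rng.uniform(24.0, 32.0)
        beta = rng.uniform(2.0, 3.0)
        state = rng.uniform(-10.0, 10.0, size=3)
        traj = np.empty((n_steps, 3))
        for t in range(n_steps):
            traj[t] = state
            state = rk4_step(state, dt, sigma, rho, beta)
        frames.append(pd.DataFrame({
            "simNum": sim,
            "time_idx": np.arange(n_steps),  # integer index restarting at 0 per group
            "x": traj[:, 0], "y": traj[:, 1], "z": traj[:, 2],
            "sigma": sigma, "rho": rho, "beta": beta,  # constant within each group
        }))
    return pd.concat(frames, ignore_index=True)
```

`make_training_df(n_sims=200)` yields one group per `simNum`, with `time_idx` restarting at 0 and the three parameters constant within each group, which matches the `group_cols`, `time_col` and `static_cols` arguments used above. The values are unscaled; any Scaler would be applied afterwards.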

My goal is to incorporate the 3 ODEs of the Lorenz system into the loss function as a residual to be minimized, so I currently have:

    import torch
    import torch.nn as nn

    def grad(outputs, inputs):
        """Computes the partial derivative of
        an output with respect to an input."""
        # torch.autograd.grad returns a tuple (one entry per input), so take the first
        return torch.autograd.grad(
            outputs,
            inputs,
            grad_outputs=torch.ones_like(outputs),
            create_graph=True
        )[0]

    def lorenz(x, y, z, s, r, b):
        # Function modeling the Lorenz attractor system
        x_dot = s * (y - x)
        y_dot = x * (r - z) - y
        z_dot = (x * y) - (b * z)
        return x_dot, y_dot, z_dot

    def odeLoss(model):
        # Define collocation points (timeMax and TOTAL_SAMPLES are defined elsewhere)
        ts = torch.linspace(0, timeMax, TOTAL_SAMPLES).view(-1, 1).requires_grad_(True)
        # Run collocation points through the network
        out = model.predict(ts, TOTAL_SAMPLES)
        # Get gradient
        dT = grad(out, ts)
        # Compute ODE (placeholder: this is the part I don't know how to write)
        odeOut = 0
        # Return MSE of ODE
        return torch.mean(odeOut**2)
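Written out, the residual that odeLoss is meant to penalize follows directly from the lorenz function: at each collocation time t_i, with parameters sigma, rho and beta, the time derivative of each predicted component should equal the corresponding right-hand side, and the loss is the mean squared mismatch over the N collocation points.

```latex
\begin{aligned}
r_x(t_i) &= \frac{d\hat{x}}{dt}(t_i) - \sigma\,\bigl(\hat{y}(t_i) - \hat{x}(t_i)\bigr) \\
r_y(t_i) &= \frac{d\hat{y}}{dt}(t_i) - \Bigl[\hat{x}(t_i)\,\bigl(\rho - \hat{z}(t_i)\bigr) - \hat{y}(t_i)\Bigr] \\
r_z(t_i) &= \frac{d\hat{z}}{dt}(t_i) - \Bigl[\hat{x}(t_i)\,\hat{y}(t_i) - \beta\,\hat{z}(t_i)\Bigr] \\
\mathcal{L}_{\mathrm{ODE}} &= \frac{1}{N}\sum_{i=1}^{N}\Bigl[r_x(t_i)^2 + r_y(t_i)^2 + r_z(t_i)^2\Bigr]
\end{aligned}
```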

I intend to add odeLoss to the model parameters, on top of a data loss term, using loss_fn = MSELoss() + odeLoss().
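Note that `MSELoss() + odeLoss()` would not run as written: `nn.Module` objects do not support `+`, so the usual pattern is to add the values of the two terms inside a single module whose `forward(y_pred, y_true)` signature matches how a loss function is called. A minimal sketch, where `CombinedLoss`, `physics_fn` and `physics_weight` are hypothetical names and the weight is a hyperparameter you would have to tune:

```python
import torch
import torch.nn as nn


class CombinedLoss(nn.Module):
    """Data-fit MSE plus a weighted physics penalty computed from the predictions.

    physics_fn is any callable mapping y_pred to a scalar tensor; it is a
    hypothetical hook for illustration, not part of Darts.
    """

    def __init__(self, physics_fn, physics_weight=0.1):
        super().__init__()
        self.data_loss = nn.MSELoss()
        self.physics_fn = physics_fn
        self.physics_weight = physics_weight

    def forward(self, y_pred, y_true):
        # Add the *outputs* of the two terms; the loss objects themselves have no "+"
        data_term = self.data_loss(y_pred, y_true)
        physics_term = self.physics_fn(y_pred)
        return data_term + self.physics_weight * physics_term
```

In Darts' `TFTModel`, a module like this would be passed as `loss_fn`, together with `likelihood=None` so the model is trained deterministically instead of with its default quantile likelihood.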

The odeLoss function is where I need help: I'm not sure how to generate model predictions over the collocation points, or how to incorporate the ODEs into the function when the static covariates are unknown. Do I simply generate 3 random values for sigma, rho, and beta over the collocation points and feed those into model.predict and the odeOut computation? I'm at an impasse and not sure how to proceed. I believe Darts uses PyTorch and PyTorch Lightning, so I might try defining a subclass that inherits from a trainer.
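The collocation approach described above, sampling random sigma, rho and beta alongside the collocation times, can be sketched for a plain PyTorch network. This is a swapped-in, assumed setup rather than a Darts model: Darts' `predict()` takes a forecast horizon `n` and returns a `TimeSeries`, so it does not keep the autograd graph, and a TFT consumes windows of past values rather than a continuous time input. `LorenzPINN`, `time_derivative` and `ode_loss` are hypothetical names, and the parameter bounds are assumed. Differentiating each output column separately matters: calling `torch.autograd.grad` on the full `(N, 3)` output with `grad_outputs=torch.ones_like(outputs)`, as the `grad` helper does, returns the sum dx/dt + dy/dt + dz/dt rather than three separate derivatives.

```python
import torch
import torch.nn as nn


class LorenzPINN(nn.Module):
    """Hypothetical MLP mapping (t, sigma, rho, beta) -> (x, y, z)."""

    def __init__(self, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(4, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, 3),
        )

    def forward(self, t, params):
        # t: (N, 1) collocation times, params: (N, 3) = (sigma, rho, beta)
        return self.net(torch.cat([t, params], dim=1))


def time_derivative(out, t):
    """Return d out[:, k] / dt for every column k, keeping the graph for backprop."""
    cols = []
    for k in range(out.shape[1]):
        # Rows are independent, so the gradient of the column sum gives per-row derivatives
        (g,) = torch.autograd.grad(out[:, k].sum(), t, create_graph=True)
        cols.append(g)
    return torch.cat(cols, dim=1)


def ode_loss(model, n_points, t_max, bounds, generator=None):
    """Mean squared Lorenz residual at random collocation points and parameters."""
    # Collocation times in [0, t_max); requires_grad so d/dt can flow through the network
    t = (torch.rand(n_points, 1, generator=generator) * t_max).requires_grad_(True)
    # One random (sigma, rho, beta) per point, drawn uniformly from the assumed bounds
    lo = torch.tensor([lower for lower, _ in bounds])
    hi = torch.tensor([upper for _, upper in bounds])
    params = lo + (hi - lo) * torch.rand(n_points, 3, generator=generator)

    out = model(t, params)
    d = time_derivative(out, t)
    x, y, z = out[:, 0], out[:, 1], out[:, 2]
    s, r, b = params[:, 0], params[:, 1], params[:, 2]

    # Derivative minus Lorenz right-hand side, one column per equation
    residual = torch.stack([
        d[:, 0] - s * (y - x),
        d[:, 1] - (x * (r - z) - y),
        d[:, 2] - (x * y - b * z),
    ], dim=1)
    return torch.mean(residual ** 2)
```

On its own this residual is also minimized by trivial solutions (the constant state at the origin satisfies the Lorenz equations for any parameters), so it is normally added to a data or initial-condition loss.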

Answer:

Here you go: just use the residuals of the Lorenz equations as the loss:

    import torch
    import torch.nn as nn

    class LorenzLoss(nn.Module):
        def __init__(self, sigma, rho, beta, dt):
            super().__init__()
            self.sigma = sigma
            self.rho = rho
            self.beta = beta
            self.dt = dt

        def forward(self, y_pred, y_true):
            # Extract the predicted x, y and z components (the last axis holds the components)
            x, y, z = y_pred[..., 0], y_pred[..., 1], y_pred[..., 2]

            # Residuals of the Lorenz system: right-hand side minus a
            # backward-difference derivative from y_true to y_pred
            dx = self.sigma * (y - x) - (x - y_true[..., 0]) / self.dt
            dy = x * (self.rho - z) - y - (y - y_true[..., 1]) / self.dt
            dz = x * y - self.beta * z - (z - y_true[..., 2]) / self.dt

            # Sum the squared residuals to compute the loss
            loss = torch.mean(dx**2 + dy**2 + dz**2)

            return loss
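One caveat about the loss above: `(x - y_true[..., 0]) / self.dt` is a time derivative only if `y_true` is read as the state one step before `y_pred`, whereas a forecasting loss receives predictions and targets for the same time steps. A hedged alternative sketch differentiates the predicted window along its time axis with forward differences (an explicit Euler discretization) and adds the plain data MSE. Assumptions, not confirmed Darts behavior: `y_pred` and `y_true` have shape `(batch, time, 3)` with a constant step `dt` between consecutive `time_idx` values; values are in unscaled units (the Lorenz equations do not hold for Scaler-transformed data unless the scaling is inverted first); and `sigma`, `rho`, `beta` are either scalars or `(batch,)` tensors. `LorenzFDLoss` and `physics_weight` are hypothetical names.

```python
import torch
import torch.nn as nn


class LorenzFDLoss(nn.Module):
    """Data MSE plus a finite-difference Lorenz residual on the predicted window.

    Assumes y_pred / y_true of shape (batch, time, 3), components ordered
    (x, y, z), unscaled values and a constant step dt between samples.
    """

    def __init__(self, sigma, rho, beta, dt, physics_weight=1.0):
        super().__init__()
        self.sigma, self.rho, self.beta = sigma, rho, beta
        self.dt = dt
        self.physics_weight = physics_weight
        self.mse = nn.MSELoss()

    @staticmethod
    def _per_batch(p, ref):
        # Scalars pass through; a (batch,) tensor is matched to ref's dtype/device
        # and reshaped to (batch, 1) so it broadcasts over the time axis.
        return p.to(ref).view(-1, 1) if torch.is_tensor(p) else p

    def forward(self, y_pred, y_true):
        x, y, z = y_pred[..., 0], y_pred[..., 1], y_pred[..., 2]  # each (batch, time)
        s = self._per_batch(self.sigma, y_pred)
        r = self._per_batch(self.rho, y_pred)
        b = self._per_batch(self.beta, y_pred)

        # Current states and forward-difference derivatives (state[t+1] - state[t]) / dt
        x0, y0, z0 = x[:, :-1], y[:, :-1], z[:, :-1]
        dxdt = (x[:, 1:] - x0) / self.dt
        dydt = (y[:, 1:] - y0) / self.dt
        dzdt = (z[:, 1:] - z0) / self.dt

        # Explicit-Euler residuals: derivative minus Lorenz right-hand side at state[t]
        rx = dxdt - s * (y0 - x0)
        ry = dydt - (x0 * (r - z0) - y0)
        rz = dzdt - (x0 * y0 - b * z0)

        physics = torch.mean(rx ** 2 + ry ** 2 + rz ** 2)
        return self.mse(y_pred, y_true) + self.physics_weight * physics
```

As far as I know, Darts calls `loss_fn(y_pred, y_true)` without the static covariates, so inside `TFTModel(..., likelihood=None, loss_fn=...)` only the scalar form applies, which is exact for a single parameter set; passing per-series tensors would need a custom training loop that tracks which series each batch row came from. The window also needs at least two time steps (`output_chunk_length >= 2`) for the differences to exist.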