I am using darts to predict the Lorenz system, a 3D multivariate chaotic time series generated by 3 ordinary differential equations. I've created a TFT model trained on synthetic data; each of the datasets is put into a series:
```python
from darts import TimeSeries

train_series = TimeSeries.from_group_dataframe(
    training_df,
    group_cols="simNum",
    time_col="time_idx",
    value_cols=["x", "y", "z"],
    static_cols=["sigma", "rho", "beta"],
)
```
where `sigma`, `rho`, and `beta` are constant static covariates, bounded and randomized for each group (this particular attempt has 200 simulations). All columns of the input dataframe are floats (scaled using a `Scaler`), except for `time_idx`, which is an integer time index that resets for each group.
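For reference, a `training_df` with the layout described above could be produced along these lines. This is a minimal sketch, not the data used in the question: the parameter bounds, initial condition, step size `dt`, trajectory length, and the helper names (`lorenz_rhs`, `simulate`, `make_training_df`) are all hypothetical, and the integrator is a plain fixed-step RK4.

```python
import numpy as np
import pandas as pd


def lorenz_rhs(state, sigma, rho, beta):
    """Right-hand side of the Lorenz system for a state vector [x, y, z]."""
    x, y, z = state
    return np.array([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])


def simulate(sigma, rho, beta, n_steps=250, dt=0.01, x0=(1.0, 1.0, 1.0)):
    """Integrate the Lorenz system with a fixed-step fourth-order Runge-Kutta scheme."""
    traj = np.empty((n_steps, 3))
    state = np.asarray(x0, dtype=float)
    for i in range(n_steps):
        traj[i] = state
        k1 = lorenz_rhs(state, sigma, rho, beta)
        k2 = lorenz_rhs(state + 0.5 * dt * k1, sigma, rho, beta)
        k3 = lorenz_rhs(state + 0.5 * dt * k2, sigma, rho, beta)
        k4 = lorenz_rhs(state + dt * k3, sigma, rho, beta)
        state = state + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return traj


def make_training_df(n_sims=200, n_steps=250, seed=0):
    """One group per simulation, with static sigma/rho/beta repeated on every row."""
    rng = np.random.default_rng(seed)
    frames = []
    for sim in range(n_sims):
        # Hypothetical bounds around the classic (10, 28, 8/3) parameters.
        sigma = rng.uniform(8.0, 12.0)
        rho = rng.uniform(24.0, 32.0)
        beta = rng.uniform(2.0, 3.0)
        traj = simulate(sigma, rho, beta, n_steps=n_steps)
        frames.append(pd.DataFrame({
            "simNum": sim,
            "time_idx": np.arange(n_steps),  # resets to 0 for every group
            "x": traj[:, 0], "y": traj[:, 1], "z": traj[:, 2],
            "sigma": sigma, "rho": rho, "beta": beta,
        }))
    return pd.concat(frames, ignore_index=True)


training_df = make_training_df(n_sims=200, n_steps=250)
```

The resulting frame can be passed straight to `TimeSeries.from_group_dataframe` as in the call above; scaling would happen afterwards.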
My goal is to incorporate the 3 ODEs of the Lorenz system into the loss function as a residual to be minimized, so I currently have:
```python
import torch
import torch.nn as nn


def grad(outputs, inputs):
    """Computes the partial derivative of
    an output with respect to an input."""
    # torch.autograd.grad returns a tuple (one entry per input), so take [0]
    return torch.autograd.grad(
        outputs,
        inputs,
        grad_outputs=torch.ones_like(outputs),
        create_graph=True,
    )[0]


def lorenz(x, y, z, s, r, b):
    # Right-hand side of the Lorenz attractor system
    x_dot = s * (y - x)
    y_dot = x * (r - z) - y
    z_dot = (x * y) - (b * z)
    return x_dot, y_dot, z_dot


def odeLoss(model):
    # Collocation points (timeMax and TOTAL_SAMPLES are defined elsewhere)
    ts = torch.linspace(0, timeMax, TOTAL_SAMPLES).view(-1, 1).requires_grad_(True)
    # Run collocation points through the network
    # (this is the part I can't get to work with darts' predict())
    out = model.predict(ts, TOTAL_SAMPLES)
    # Gradient of the output with respect to time
    dT = grad(out, ts)
    # ODE residual: placeholder, not implemented yet
    odeOut = torch.zeros_like(ts)
    # Return MSE of the ODE residual
    return torch.mean(odeOut**2)
```
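For the residual itself, the derivative has to come from a model that is differentiable with respect to its time input. As far as I know, darts' `predict()` returns a `TimeSeries` rather than a tensor attached to the autograd graph, so this sketch swaps in a small plain-PyTorch MLP (`PINNNet`, hypothetical) that takes `(t, sigma, rho, beta)` and outputs `(x, y, z)`. It also illustrates two details of the `grad` helper: `torch.autograd.grad` returns a tuple, and calling it once on a 3-column output with `ones_like` would sum the three derivatives, so each state is differentiated separately. The network size and the name `ode_residual_loss` are assumptions.

```python
import torch
import torch.nn as nn


def grad(outputs, inputs):
    """d(outputs)/d(inputs); autograd.grad returns a tuple, so take element 0."""
    return torch.autograd.grad(
        outputs, inputs, grad_outputs=torch.ones_like(outputs), create_graph=True
    )[0]


def lorenz(x, y, z, s, r, b):
    """Lorenz right-hand side (same as in the question)."""
    return s * (y - x), x * (r - z) - y, x * y - b * z


class PINNNet(nn.Module):
    """Hypothetical MLP mapping (t, sigma, rho, beta) -> (x, y, z)."""

    def __init__(self, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(4, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, 3),
        )

    def forward(self, t, params):
        return self.net(torch.cat([t, params], dim=1))


def ode_residual_loss(model, t, params):
    """Mean squared Lorenz residual at collocation times t (N, 1) with params (N, 3)."""
    t = t.requires_grad_(True)
    out = model(t, params)
    x, y, z = out[:, 0:1], out[:, 1:2], out[:, 2:3]
    s, r, b = params[:, 0:1], params[:, 1:2], params[:, 2:3]
    # Differentiate each state on its own; one grad call on `out` would sum them.
    dx, dy, dz = grad(x, t), grad(y, t), grad(z, t)
    x_dot, y_dot, z_dot = lorenz(x, y, z, s, r, b)
    res = torch.cat([dx - x_dot, dy - y_dot, dz - z_dot], dim=1)
    return torch.mean(res**2)
```

If `x`, `y`, `z` and the parameters are scaled, the residual should be evaluated on unscaled values (or with the scale factors carried through), since the Lorenz equations hold in physical units.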
I intend to pass this to the model's parameters alongside a data loss term, using `loss_fn = MSELoss() + odeLoss()`.
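Note that `MSELoss() + odeLoss()` will not work as written: it adds an `nn.Module` to a tensor, and `odeLoss` is called without its `model` argument. The usual pattern is to compute both terms as tensors and add them with a weight. A minimal sketch, where `total_loss` and `lambda_ode` are hypothetical names and `physics_loss` stands for whatever the ODE-residual function returns. As far as I can tell, darts calls a custom `loss_fn` with only the prediction and target tensors, so getting the physics term into darts' training step is a separate problem not shown here.

```python
import torch
import torch.nn as nn


def total_loss(pred, target, physics_loss, lambda_ode=0.1):
    """Weighted sum of a data-fit term and a physics (ODE residual) term.

    lambda_ode is a hypothetical weighting hyperparameter balancing the two.
    """
    data_loss = nn.functional.mse_loss(pred, target)
    return data_loss + lambda_ode * physics_loss
```

The weight matters in practice: the residual of a chaotic system can be orders of magnitude larger than the data MSE early in training, so `lambda_ode` usually needs tuning.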
The `odeLoss` function is where I need help: I'm not sure how to generate model predictions over the collocation points, or how to incorporate the ODEs into the function when the static covariates are unknown. Do I simply generate 3 random values for sigma, rho, and beta over the collocation points and use them in both the `model.predict` call and the `odeOut` computation? I'm at an impasse and not sure how to proceed. I believe darts is built on PyTorch and PyTorch Lightning, so I might try defining a subclass that inherits from a trainer.
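On the unknown static covariates: one option, common in parametric physics-informed training, is to sample the parameters as well as the times, drawing each collocation point's `(sigma, rho, beta)` from the same ranges the training simulations used, so the residual is enforced across that whole range rather than at a single arbitrary triple. A minimal sketch; the bounds `PARAM_LOW`/`PARAM_HIGH` and the name `sample_collocation` are hypothetical placeholders for the ranges actually used to generate the data.

```python
import torch

# Hypothetical bounds; use the same ranges the training simulations were drawn from.
PARAM_LOW = torch.tensor([8.0, 24.0, 2.0])    # sigma, rho, beta
PARAM_HIGH = torch.tensor([12.0, 32.0, 3.0])


def sample_collocation(n_points, t_max, generator=None):
    """Draw collocation times in [0, t_max] and a random (sigma, rho, beta) per point."""
    t = torch.rand(n_points, 1, generator=generator) * t_max
    u = torch.rand(n_points, 3, generator=generator)
    params = PARAM_LOW + u * (PARAM_HIGH - PARAM_LOW)
    return t, params
```

The sampled `params` columns play the role of `s`, `r`, `b` in the `lorenz` function, and they must also be fed to the network so that its prediction depends on them.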
Here you go: just use the residuals of the ODE functions as the loss:
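A minimal sketch of that idea for a sequence model, assuming the predictions are available as a tensor `pred` of shape `(batch, T, 3)` in unscaled units, spaced `dt` apart in physical time, with each trajectory's `(sigma, rho, beta)` in `params`. Because a forecast is a sequence indexed by `time_idx` rather than a function of a continuous time input, this swaps the autograd derivative for a central finite difference. The function name is hypothetical, and wiring it into darts' training loop is not shown.

```python
import torch


def lorenz(x, y, z, s, r, b):
    """Lorenz right-hand side (same as in the question)."""
    return s * (y - x), x * (r - z) - y, x * y - b * z


def fd_ode_residual_loss(pred, params, dt):
    """Mean squared Lorenz residual of predicted trajectories via central differences.

    pred:   (batch, T, 3) predicted x, y, z in unscaled units, T >= 3
    params: (batch, 3) sigma, rho, beta for each trajectory
    dt:     physical time between consecutive time_idx steps
    """
    # Central difference approximates d/dt at the interior points 1..T-2.
    deriv = (pred[:, 2:, :] - pred[:, :-2, :]) / (2.0 * dt)
    mid = pred[:, 1:-1, :]
    x, y, z = mid[..., 0], mid[..., 1], mid[..., 2]
    # (batch, 1) parameter columns broadcast over the time axis
    s, r, b = (params[:, i:i + 1] for i in range(3))
    x_dot, y_dot, z_dot = lorenz(x, y, z, s, r, b)
    res = deriv - torch.stack([x_dot, y_dot, z_dot], dim=-1)
    return torch.mean(res**2)
```

The central difference is only second-order accurate in `dt`, so with a coarse sampling interval the residual will not reach zero even for an exact trajectory, which is worth keeping in mind when choosing the weight of this term.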