Forecasting using fable and future, time & memory issues


I'm using fable and future to forecast in parallel. Unfortunately, with each iteration of the for loop, the model() step takes more time and consumes more memory. What I am trying to do is step forward one week at a time and, at each step, forecast a few weeks ahead, possibly with several models at once.

The data I pass to model() is less than 1% larger at each step, but the time it takes to compute keeps growing: in the run below it goes from about 33 seconds in the first cycle to about 157 seconds in the ninth. Below is a simplified example. In my real case I do some computations on the values up to that point before passing them to model(), which makes the increase in compute time at each model() call even worse.
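As a rough check on the shape of that growth, the elapsed times reported by tic()/toc() in the output further down can be compared from step to step. This is only a sketch in base R on the nine logged values: a roughly constant difference would point to linear growth, a roughly constant ratio to exponential growth.

```r
# Elapsed seconds for model(), copied from the tic()/toc() output of the run below
elapsed <- c(33.27, 30.15, 50.63, 64.41, 81.08, 105.32, 123.16, 144.39, 156.76)

# Step-to-step differences: roughly constant under linear growth
step_diff <- diff(elapsed)

# Step-to-step ratios: roughly constant under exponential growth
step_ratio <- elapsed[-1] / elapsed[-length(elapsed)]

print(round(step_diff, 2))
print(round(step_ratio, 2))
```

On these values the differences stay in a similar range after the first step while the ratios drift down toward 1, which looks closer to linear than to exponential growth, although nine points are too few to be conclusive.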

I did some investigation, and the time increase seems to come from this line in fabletools. I also ran with the debug option of the future package enabled, and the relevant code whose compute time increases there is this.

I believe that more data than needed is being passed to each cluster node on each subsequent iteration of the loop. Is there a way to avoid this and ensure that only cur_training_data is passed down the stack?
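One way to test this hypothesis is to measure how many bytes an object occupies once serialized, since that approximates what has to be sent to a worker running in a separate R session. The sketch below uses only base R; serialized_size() and make_closure() are hypothetical helpers, not part of future or fabletools. It illustrates a general R behaviour (a closure carries its enclosing environment when serialized) and is not a confirmed diagnosis of this case.

```r
# Hypothetical helper: number of bytes needed to serialize an object
serialized_size <- function(x) length(serialize(x, connection = NULL))

# Hypothetical example: a closure created next to a large local object
make_closure <- function(n) {
  big <- numeric(n)    # stands in for data sitting in the calling frame
  function(x) x + 1    # never uses `big`, but its environment still holds it
}

small <- serialized_size(make_closure(10))
large <- serialized_size(make_closure(1e6))

# The function body is identical, yet the serialized payload differs by
# roughly the size of `big` (8 bytes per double)
print(c(small = small, large = large))
```

Calling the same helper on cur_training_data and on the other objects handed to model() at each iteration would show whether the payload grows faster than the data itself.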

Alternatively, maybe my whole strategy is off. I saw that stretch_tsibble() could be a way to do this, but I fear that replicating the training data for each time step would increase the memory footprint a lot, which is why I went with the loop and filter. Is there a better way to do this in general?
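To put a number on that footprint concern, the row count of a stretched (expanding-window) layout can be estimated before building it. This is a back-of-the-envelope sketch in base R, assuming one observation per week per symbol with no gaps, the date range and 9 cycles from the example below, and 200 symbols; the variable names are illustrative.

```r
first_week <- as.Date("2015-01-02")   # first weekly observation in the example
cutoffs <- seq(as.Date("2020-07-03"), by = "week", length.out = 9)  # forecast cycles
num_symbols <- 200

# Rows per symbol in each window: weeks from first_week up to, but excluding, the cutoff
window_rows <- as.numeric(difftime(cutoffs, first_week, units = "days")) / 7

# Rows if every window is materialised side by side, versus the largest single window
stretched_rows <- sum(window_rows) * num_symbols
single_rows <- max(window_rows) * num_symbols
replication_factor <- stretched_rows / single_rows

print(c(stretched = stretched_rows, single = single_rows))
print(round(replication_factor, 2))
```

Under these assumptions the stretched layout is a little under 9 times the size of a single window, so with three columns it stays small in absolute terms. The replication does grow with the number of cycles, though.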

Thanks very much for reading.

library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(dtplyr)
library(tidyr)
library(tsibbledata)
library(fable)
#> Loading required package: fabletools
library(tsibble)
library(logger)
library(future)
library(tidyquant)
#> Loading required package: lubridate
#> 
#> Attaching package: 'lubridate'
#> The following object is masked from 'package:tsibble':
#> 
#>     interval
#> The following objects are masked from 'package:base':
#> 
#>     date, intersect, setdiff, union
#> Loading required package: PerformanceAnalytics
#> Loading required package: xts
#> Loading required package: zoo
#> 
#> Attaching package: 'zoo'
#> The following object is masked from 'package:tsibble':
#> 
#>     index
#> The following objects are masked from 'package:base':
#> 
#>     as.Date, as.Date.numeric
#> 
#> Attaching package: 'xts'
#> The following objects are masked from 'package:dplyr':
#> 
#>     first, last
#> 
#> Attaching package: 'PerformanceAnalytics'
#> The following object is masked from 'package:graphics':
#> 
#>     legend
#> Loading required package: quantmod
#> Loading required package: TTR
#> Registered S3 method overwritten by 'quantmod':
#>   method            from
#>   as.zoo.data.frame zoo
#> Version 0.4-0 included new data defaults. See ?getSymbols.
#> == Need to Learn tidyquant? =======================================================================================================================
#> Business Science offers a 1-hour course - Learning Lab #9: Performance Analysis & Portfolio Optimization with tidyquant!
#> </> Learn more at: https://university.business-science.io/p/learning-labs-pro </>
#> 
#> Attaching package: 'tidyquant'
#> The following object is masked from 'package:fable':
#> 
#>     VAR
library(tictoc)

# Set up some variables
my_vars <- NULL
# Value variable
my_vars$unit_column_name <- "adjusted"
# Whether to run in parallel or not
my_vars$run_parallel <- TRUE
# Number of cycles to forecast for
my_vars$num_cycles_to_forecast <- 9
# Weeks to predict for in each cycle
my_vars$weeks_to_predict <- 13
# Number of stock symbols
my_vars$num_stock_symbols <- 200

# Get stock data
stocks <- tq_index("SP500")
#> Getting holdings for SP500
symbols <- stocks %>% pull(symbol)

# Get stock price data
stock_prices  <- tq_get(symbols[1:my_vars$num_stock_symbols], 
                        get = "stock.prices", 
                        from = "2015-01-01")
#> Warning: Problem with `mutate()` input `data..`.
#> i x = 'BRK.B', get = 'stock.prices': Error in getSymbols.yahoo(Symbols = "BRK.B", env = <environment>, verbose = FALSE, : Unable to import "BRK.B".
#> BRK.B download failed after two attempts. Error message:
#> HTTP error 404.
#>  Removing BRK.B.
#> i Input `data..` is `purrr::map(...)`.
#> Warning: x = 'BRK.B', get = 'stock.prices': Error in getSymbols.yahoo(Symbols = "BRK.B", env = <environment>, verbose = FALSE, : Unable to import "BRK.B".
#> BRK.B download failed after two attempts. Error message:
#> HTTP error 404.
#>  Removing BRK.B.

# Convert to tsibble
stock_prices_tsibble <- stock_prices %>%
  as_tsibble(key = c(symbol),
             index = date)

# Prepare the training data, just need weekly data
weekly_stocks_tsibble <- stock_prices_tsibble %>%
  setNames(tolower(names(.))) %>%
  rename(forecast_week = date) %>%
  select(symbol, forecast_week, adjusted) %>%
  mutate(weekday = lubridate::wday(forecast_week)) %>%
  filter(weekday == 6) %>%
  select(-weekday) %>%
  as_tsibble(key = c(symbol),
             index = forecast_week) %>%
  fill_gaps(!!as.name(my_vars$unit_column_name) := 0)

# Get the cycles we want to forecast for
my_vars$cycles_to_forecast <- weekly_stocks_tsibble %>% 
  slice_tail(n = my_vars$num_cycles_to_forecast) %>% 
  pull(forecast_week)


run_my_models <- function(cycles_to_forecast, actuals_column_name, weeks_to_predict, run_parallel, actuals_tsibble, ...) {
  
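  # Set up parallel workers if requested.
  # multisession replaces multiprocess, which is deprecated in newer
  # versions of future (on Windows it resolved to multisession anyway)
  if (run_parallel == TRUE) {
    plan(multisession)
  }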
  
  # Tibble to hold results
  log_info("Creating holder tibble")
  holder_data_frame <- tibble()
  
  # Fit through cycles
  for (i in seq_along(cycles_to_forecast)) {
    # i <- 1
    
    cur_cycle <- cycles_to_forecast[i]
    
    log_info("Running for cycle {i}/{length(cycles_to_forecast)}: {cur_cycle}")
    
    # Prepare current cycles training data
    cur_training_data <- actuals_tsibble %>%
      filter(forecast_week < cur_cycle)
    
    # Check that there are rows in the training data; skip the cycle if not
    if (nrow(cur_training_data) <= 0) {
      log_warn("No rows in current cycle training data")
      next
    }
    
    log_info("Training data: {min(cur_training_data$forecast_week)} - {max(cur_training_data$forecast_week)}")
    tic()
    # Fit models
    cur_fit <- cur_training_data %>%
      model(...)
    
    log_info("Models fitted")
    toc()
    
    # Predict
    predictions <- cur_fit %>% 
      forecast(h = weeks_to_predict, 
               point_forecast = list(forecasted_units = mean))
    
    log_info("Predictions generated")
    
    # Collect useful prediction information
    cur_fit_formatted <- cur_fit %>%
      as_tibble() %>%
      # mutate_if(~!is.character(.), print) %>%
      pivot_longer(cols = -c(symbol),
                   names_to = "method",
                   values_to = "method_specifics") %>%
      lazy_dt()
    
    collected_predictions <- predictions %>%
      as_tibble() %>%
      lazy_dt() %>%
      rename(method = .model) %>%
      left_join(cur_fit_formatted, by = c("symbol", "method")) %>%
      mutate(forecast_cycle = cur_cycle) %>%
      select(symbol, forecast_cycle, forecast_week, forecasted_units, method, method_specifics)
    
    log_info("Predictions colected")
    
    holder_data_frame <- holder_data_frame %>%
      bind_rows(as_tibble(collected_predictions))
    
  }
  
  return(holder_data_frame)
}

model_predictions <- run_my_models(cycles_to_forecast = my_vars$cycles_to_forecast,
                                   actuals_column_name = my_vars$unit_column_name,
                                   weeks_to_predict = my_vars$weeks_to_predict,
                                   run_parallel = my_vars$run_parallel,
                                   actuals_tsibble = weekly_stocks_tsibble,
                                   # Model definitions
                                   arima = ARIMA(!!as.name(my_vars$unit_column_name)))
#> INFO [2020-09-03 08:25:35] Creating holder tibble
#> INFO [2020-09-03 08:25:35] Running for cycle 1/9: 2020-07-03
#> INFO [2020-09-03 08:25:35] Training data: 2015-01-02 - 2020-06-26
#> INFO [2020-09-03 08:26:08] Models fitted
#> 33.27 sec elapsed
#> INFO [2020-09-03 08:26:10] Predictions generated
#> INFO [2020-09-03 08:26:10] Predictions colected
#> INFO [2020-09-03 08:26:11] Running for cycle 2/9: 2020-07-10
#> INFO [2020-09-03 08:26:11] Training data: 2015-01-02 - 2020-07-03
#> INFO [2020-09-03 08:26:42] Models fitted
#> 30.15 sec elapsed
#> INFO [2020-09-03 08:26:44] Predictions generated
#> INFO [2020-09-03 08:26:44] Predictions colected
#> INFO [2020-09-03 08:26:44] Running for cycle 3/9: 2020-07-17
#> INFO [2020-09-03 08:26:44] Training data: 2015-01-02 - 2020-07-10
#> INFO [2020-09-03 08:27:35] Models fitted
#> 50.63 sec elapsed
#> INFO [2020-09-03 08:27:37] Predictions generated
#> INFO [2020-09-03 08:27:37] Predictions colected
#> INFO [2020-09-03 08:27:38] Running for cycle 4/9: 2020-07-24
#> INFO [2020-09-03 08:27:38] Training data: 2015-01-02 - 2020-07-17
#> INFO [2020-09-03 08:28:43] Models fitted
#> 64.41 sec elapsed
#> INFO [2020-09-03 08:28:45] Predictions generated
#> INFO [2020-09-03 08:28:45] Predictions colected
#> INFO [2020-09-03 08:28:45] Running for cycle 5/9: 2020-07-31
#> INFO [2020-09-03 08:28:45] Training data: 2015-01-02 - 2020-07-24
#> INFO [2020-09-03 08:30:06] Models fitted
#> 81.08 sec elapsed
#> INFO [2020-09-03 08:30:09] Predictions generated
#> INFO [2020-09-03 08:30:09] Predictions colected
#> INFO [2020-09-03 08:30:09] Running for cycle 6/9: 2020-08-07
#> INFO [2020-09-03 08:30:09] Training data: 2015-01-02 - 2020-07-31
#> INFO [2020-09-03 08:31:55] Models fitted
#> 105.32 sec elapsed
#> INFO [2020-09-03 08:31:57] Predictions generated
#> INFO [2020-09-03 08:31:57] Predictions colected
#> INFO [2020-09-03 08:31:57] Running for cycle 7/9: 2020-08-14
#> INFO [2020-09-03 08:31:57] Training data: 2015-01-02 - 2020-08-07
#> INFO [2020-09-03 08:34:00] Models fitted
#> 123.16 sec elapsed
#> INFO [2020-09-03 08:34:02] Predictions generated
#> INFO [2020-09-03 08:34:02] Predictions colected
#> INFO [2020-09-03 08:34:02] Running for cycle 8/9: 2020-08-21
#> INFO [2020-09-03 08:34:02] Training data: 2015-01-02 - 2020-08-14
#> INFO [2020-09-03 08:36:27] Models fitted
#> 144.39 sec elapsed
#> INFO [2020-09-03 08:36:29] Predictions generated
#> INFO [2020-09-03 08:36:29] Predictions colected
#> INFO [2020-09-03 08:36:29] Running for cycle 9/9: 2020-08-28
#> INFO [2020-09-03 08:36:29] Training data: 2015-01-02 - 2020-08-21
#> INFO [2020-09-03 08:39:06] Models fitted
#> 156.76 sec elapsed
#> INFO [2020-09-03 08:39:08] Predictions generated
#> INFO [2020-09-03 08:39:08] Predictions colected

sessionInfo()

R version 4.0.2 (2020-06-22)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows Server x64 (build 14393)

Matrix products: default

locale:
[1] LC_COLLATE=English_Ireland.1252  LC_CTYPE=English_Ireland.1252    LC_MONETARY=English_Ireland.1252 LC_NUMERIC=C                    
[5] LC_TIME=English_Ireland.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] ps_1.3.4        digest_0.6.25   crayon_1.3.4    R6_2.4.1        lifecycle_0.2.0 reprex_0.3.0    magrittr_1.5    evaluate_0.14  
 [9] pillar_1.4.4    rlang_0.4.7     rstudioapi_0.11 fs_1.5.0        callr_3.4.3     whisker_0.4     vctrs_0.3.3     ellipsis_0.3.1 
[17] rmarkdown_2.3   tools_4.0.2     processx_3.4.3  xfun_0.16       compiler_4.0.2  pkgconfig_2.0.3 clipr_0.7.0     htmltools_0.5.0
[25] knitr_1.29      tibble_3.0.3   