Tidymodels: strange error message from `last_fit()` with a custom data split

87 views

I am learning the ropes of the new tidymodels framework, so I may be misunderstanding something fundamental.

Below is a self-contained example with a real dataset (taken from my work). Please take it as a given that I need to use all observations except the most recent one as the training set, and only the most recent observation as the test set (so in this case the test set consists of a single observation).

However, the final `last_fit()` call fails with an error I cannot decipher. Any suggestions are appreciated.

Thanks!

library(tidyverse) 

library(tidymodels)


## Example data in dput() format: a tibble (class tbl_df) with 17 rows, one per
## year (1998, 2002, then 2004-2018), holding lag-1 and lag-2 versions of the
## capital, employment, GDP, price-index and value series plus `berd`.
## Rows are in ascending `year` order, so the last row is the most recent
## observation.
df_ini <- structure(list(year = c(1998, 2002, 2004, 2005, 2006, 2007, 2008, 
2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018), 
    capital_n1132g_lag_1 = c(3446.5, 4091.1, 3655.1, 3633.3, 
    3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6, 
    3718.6, 3467.9, 4214.2, 4237.4, 4450.2), capital_n117g_lag_1 = c(4920.9, 
    7810.6, 8560.3, 8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 
    11554.3, 11849.9, 13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 
    17647.1, 18273.8), capital_n11mg_lag_1 = c(16846, 19605, 
    19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6, 
    20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2, 
    29790.1), employment_be_lag_1 = c(2834.42, 2839.72, 2765.53, 
    2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1, 
    2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9
    ), employment_c_lag_1 = c(2612.76, 2623.69, 2552.89, 2518.57, 
    2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1, 
    2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48), employment_j_lag_1 = c(292.93, 
    389.2, 389.45, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 
    392.18, 410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4
    ), employment_k_lag_1 = c(505.33, 507.12, 510.25, 504.63, 
    515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98, 
    524.13, 518.89, 511.57, 505.32, 496.41), employment_mn_lag_1 = c(945.59, 
    1217.96, 1289.55, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 
    1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 
    2021.51, 2109.71), employment_oq_lag_1 = c(3065.87, 3191.75, 
    3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 
    3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 
    4171.72), employment_total_lag_1 = c(14509.58, 15127.99, 
    15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 16356.53, 
    16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 17039.6, 
    17142.13, 17365.32, 17650.21), gdp_b1gq_lag_1 = c(187849.7, 
    220525, 231862.5, 242348.3, 254075, 267824.4, 283978, 293761.9, 
    288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 333146.1, 
    344269.3, 357608, 369341.3), gdp_p3_lag_1 = c(139695.2, 161175.8, 
    169405.6, 176316.4, 185871.1, 194102, 200944.4, 208857.1, 
    213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6, 
    249404.3, 257166.5, 265900.2), gdp_p61_lag_1 = c(50117.6, 
    71948.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1, 
    91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3, 
    129183.6, 131524, 140057.8), gdp_p62_lag_1 = c(19441, 26444.4, 
    28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9, 
    39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5, 
    59584.7), price_index_lag_1 = c(1.2, 2.3, 1.3, 2, 2.1, 1.7, 
    2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 1, 2.2), value_be_lag_1 = c(40533.1, 
    48207.1, 48673.2, 50737.6, 52955.2, 56872.4, 60864.9, 61029, 
    56837.8, 58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4, 
    71152.6, 72698.8), value_c_lag_1 = c(33441.8, 40446.6, 40467.4, 
    42014.6, 44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3, 
    51467.7, 53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196
    ), value_j_lag_1 = c(5483.7, 7326.1, 7934.1, 7756.1, 8134.2, 
    8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1, 
    10361.4, 10695.4, 11455.3, 11720.6), value_k_lag_1 = c(9210.6, 
    9977.3, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2, 
    12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9, 
    13236.4, 13744.1), value_mn_lag_1 = c(10444, 14061.4, 15706.6, 
    16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 
    24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6
    ), value_oq_lag_1 = c(29902.7, 34179.2, 36126.8, 37329.6, 
    38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 
    49381.5, 50261.7, 51624.3, 53715, 55926.4, 57637.1), value_total_lag_1 = c(167323.4, 
    197076.7, 207247.6, 216098.3, 225888.1, 239076, 253604.6, 
    262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1, 
    307037.7, 318952.7, 329396.1), capital_n1132g_lag_2 = c(3599.2, 
    3996.9, 3638.4, 3655.1, 3633.3, 3616.2, 3450.7, 3596.8, 3867.2, 
    3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 3467.9, 4214.2, 4237.4
    ), capital_n117g_lag_2 = c(4636.2, 7008.5, 8369.6, 8560.3, 
    8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 
    13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1), capital_n11mg_lag_2 = c(17181.5, 
    19677.8, 18749.6, 19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 
    23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 
    25019.2, 27608.2), employment_be_lag_2 = c(2870.33, 2840.19, 
    2775.22, 2765.53, 2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 
    2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 
    2750.52), employment_c_lag_2 = c(2626.2, 2621.08, 2562.53, 
    2552.89, 2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97, 
    2447.65, 2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97
    ), employment_j_lag_2 = c(275.08, 374.56, 400.75, 389.45, 
    387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 410.85, 
    419.75, 427.59, 438.96, 440.33, 460.84), employment_k_lag_2 = c(500.9, 
    505.13, 502.42, 510.25, 504.63, 515.39, 523.45, 536.6, 550.14, 
    546.68, 539.96, 536.58, 534.98, 524.13, 518.89, 511.57, 505.32
    ), employment_mn_lag_2 = c(904.38, 1143.78, 1248.01, 1289.55, 
    1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 1762.55, 
    1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51), employment_oq_lag_2 = c(3028.85, 
    3162.77, 3241.36, 3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 
    3577.75, 3683.85, 3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 
    4002.74, 4095.59), employment_total_lag_2 = c(14404.29, 15019.87, 
    15113.52, 15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 
    16356.53, 16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 
    17039.6, 17142.13, 17365.32), gdp_b1gq_lag_2 = c(186928.7, 
    213606.4, 226735.3, 231862.5, 242348.3, 254075, 267824.4, 
    283978, 293761.9, 288044.1, 295896.6, 310128.6, 318653.1, 
    323910.2, 333146.1, 344269.3, 357608), gdp_p3_lag_2 = c(140335.8, 
    156117.3, 164107.8, 169405.6, 176316.4, 185871.1, 194102, 
    200944.4, 208857.1, 213630.1, 218947.2, 227250.8, 233638.1, 
    238329.3, 243860.6, 249404.3, 257166.5), gdp_p61_lag_2 = c(44541.4, 
    67701.6, 74691.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 
    113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 
    126109.3, 129183.6, 131524), gdp_p62_lag_2 = c(19504.2, 24888.9, 
    28063.4, 28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 
    38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 
    55885.5), value_be_lag_2 = c(40076.7, 46109.4, 47967.1, 48673.2, 
    50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6, 
    61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6), value_c_lag_2 = c(32955.4, 
    38908.4, 40192.9, 40467.4, 42014.6, 44229, 47735.5, 51552.4, 
    51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169, 
    57458.7, 60962.8), value_j_lag_2 = c(5576.8, 6313.9, 7737.1, 
    7934.1, 7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9, 
    9217.1, 9405.1, 9802.1, 10361.4, 10695.4, 11455.3), value_k_lag_2 = c(9191, 
    10458, 10225.2, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 
    13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 
    13482.9, 13236.4), value_mn_lag_2 = c(10092, 12942.5, 15074, 
    15706.6, 16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 
    23255.2, 24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7
    ), value_oq_lag_2 = c(30224.3, 33251.5, 35065.6, 36126.8, 
    37329.6, 38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 
    47980.9, 49381.5, 50261.7, 51624.3, 53715, 55926.4), value_total_lag_2 = c(167141.8, 
    190624.9, 202353.5, 207247.6, 216098.3, 225888.1, 239076, 
    253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 
    297230.1, 307037.7, 318952.7), berd = c(2146.085, 3130.884, 
    3556.479, 4207.669, 4448.676, 4845.861, 5232.63, 5092.902, 
    5520.422, 5692.841, 6540.457, 6778.42, 7324.679, 7498.488, 
    7824.51, 7888.444, 8461.72)), row.names = c(NA, -17L), class = c("tbl_df", 
"tbl", "data.frame"))





set.seed(1234)  ## to make the results reproducible (vfold_cv() below is random)

## Custom split: the test set is only the most recent observation (the row
## with the largest `year`); every other row goes into the training set.
## Selecting the row by `year` rather than by position keeps this correct even
## if the rows are not sorted, and seq_len()/setdiff() avoid the seq(0) pitfall
## of seq(nrow(df_ini) - 1) when there is a single row.
## See https://github.com/tidymodels/rsample/issues/158

test_row <- which.max(df_ini$year)

indices <- list(
  analysis   = setdiff(seq_len(nrow(df_ini)), test_row),
  assessment = test_row
)

df_split <- make_splits(indices, df_ini)

## NOTE(review): a split built with make_splits() carries no `id` element.
## The split returned by initial_split() (the commented-out alternative below,
## which works) is created through an rset and does carry one. tune appears to
## bind that id onto each split's metrics and predictions. With no id it binds
## a 0-row tibble, which is the likely cause of last_fit() failing with
## "arguments imply differing number of rows: 2, 0" (2 = the rmse/rsq rows).
## Attaching an id explicitly avoids that; versions of tune that build their
## own id for last_fit() should simply overwrite it.
df_split$id <- tibble(id = "train/test split")

## df_split <- initial_split(df_ini) ## with the default splitting,
## ## the code works

df_train <- training(df_split)
df_test <- testing(df_split)

## NOTE(review): these are yearly observations, so random v-fold CV lets later
## years inform predictions for earlier ones; a time-ordered resampling scheme
## such as rsample::rolling_origin() may be more appropriate here.
folded_data <- vfold_cv(df_train, v = 3)



## Preprocessing recipe: `berd` is the outcome and `year` is demoted to an ID
## role so it is not used as a predictor. Zero-variance predictors are then
## removed, and the remaining non-nominal predictors are centred and scaled.
glmnet_recipe <- recipe(formula = berd ~ ., data = df_train)
glmnet_recipe <- update_role(glmnet_recipe, year, new_role = "ID")
glmnet_recipe <- step_zv(glmnet_recipe, all_predictors())
glmnet_recipe <- step_normalize(glmnet_recipe, all_predictors(), -all_nominal())

## Elastic-net linear regression via glmnet; `penalty` and `mixture` are left
## as tune() placeholders to be filled in by the grid search.
glmnet_spec <- linear_reg(penalty = tune(), mixture = tune())
glmnet_spec <- set_mode(glmnet_spec, "regression")
glmnet_spec <- set_engine(glmnet_spec, "glmnet")

## Bundle the recipe and the model specification into a single workflow.
glmnet_workflow <- workflow()
glmnet_workflow <- add_recipe(glmnet_workflow, glmnet_recipe)
glmnet_workflow <- add_model(glmnet_workflow, glmnet_spec)

## Candidate grid: 20 log-spaced penalties crossed with 6 mixture values
## (120 combinations).
glmnet_grid <- tidyr::crossing(
  penalty = 10^seq(-6, -1, length.out = 20),
  mixture = c(0.05, 0.2, 0.4, 0.6, 0.8, 1)
)

## Grid search over penalty/mixture using the cross-validation folds, keeping
## the out-of-fold predictions (save_pred = TRUE) for later inspection.
glmnet_tune <- tune_grid(
  glmnet_workflow,
  resamples = folded_data,
  grid = glmnet_grid,
  control = control_grid(save_pred = TRUE)
)

print(collect_metrics(glmnet_tune))
#> # A tibble: 240 x 8
#>       penalty mixture .metric .estimator    mean     n std_err .config 
#>         <dbl>   <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>   
#>  1 0.000001      0.05 rmse    standard   375.        3 48.9    Model001
#>  2 0.000001      0.05 rsq     standard     0.929     3  0.0420 Model001
#>  3 0.00000183    0.05 rmse    standard   375.        3 48.9    Model002
#>  4 0.00000183    0.05 rsq     standard     0.929     3  0.0420 Model002
#>  5 0.00000336    0.05 rmse    standard   375.        3 48.9    Model003
#>  6 0.00000336    0.05 rsq     standard     0.929     3  0.0420 Model003
#>  7 0.00000616    0.05 rmse    standard   375.        3 48.9    Model004
#>  8 0.00000616    0.05 rsq     standard     0.929     3  0.0420 Model004
#>  9 0.0000113     0.05 rmse    standard   375.        3 48.9    Model005
#> 10 0.0000113     0.05 rsq     standard     0.929     3  0.0420 Model005
#> # … with 230 more rows

## `metric` is passed by name: newer tune versions expect it named rather than
## positional (naming it also works with older versions).
print(show_best(glmnet_tune, metric = "rmse"))
#> # A tibble: 5 x 8
#>      penalty mixture .metric .estimator  mean     n std_err .config 
#>        <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>   
#> 1 0.000001      0.05 rmse    standard    375.     3    48.9 Model001
#> 2 0.00000183    0.05 rmse    standard    375.     3    48.9 Model002
#> 3 0.00000336    0.05 rmse    standard    375.     3    48.9 Model003
#> 4 0.00000616    0.05 rmse    standard    375.     3    48.9 Model004
#> 5 0.0000113     0.05 rmse    standard    375.     3    48.9 Model005

## Pick the penalty/mixture pair with the lowest cross-validated RMSE
## (`metric` named, as newer tune versions expect) and plug it into the
## workflow.
best_net <- select_best(glmnet_tune, metric = "rmse")


final_net <- finalize_workflow(
  glmnet_workflow,
  best_net
)


## Fit the finalized workflow on the analysis (training) rows of `df_split`
## and evaluate it on the assessment (test) rows. Here that is a single row
## (split [16/1] in the output below). R-squared is undefined for one
## observation, so only RMSE is requested instead of the default metric set.
final_res_net <- last_fit(final_net, df_split, metrics = metric_set(rmse))
## Output observed when `df_split` came from make_splits() without an `id`
## element and the default metrics were used:
#> x : internal: Error in data.frame(..., check.names = FALSE): arguments imply...
#> Warning: All models failed in [fit_resamples()]. See the `.notes` column.


print(final_res_net)
#> Warning: This tuning result has notes. Example notes on model fitting include:
#> internal: Error in data.frame(..., check.names = FALSE): arguments imply differing number of rows: 2, 0
#> # Resampling results
#> # Monte Carlo cross-validation (0.94/0.059) with 1 resamples  
#> # A tibble: 1 x 5
#>   splits         id               .metrics .notes           .predictions
#>   <list>         <chr>            <list>   <list>           <list>      
#> 1 <split [16/1]> train/test split <NULL>   <tibble [1 × 1]> <NULL>

## Predictions for the held-out row(s), alongside the observed `berd`.
final_fit <- final_res_net %>%
  collect_predictions()

Created on 2020-10-15 by the reprex package (v0.3.0.9001)

0

There are 0 best solutions below