I am learning the ropes of the new tidymodels framework, so I may be misunderstanding something fundamental.
I provide a self-contained example with a real dataset (real in the sense that it is taken from my work). Please take it as given that I need to use all the observations apart from the most recent one as the training set and only the most recent observation as the test set (so in this case the test set consists of a single observation).
However, I get an error I cannot decipher. Any suggestion is appreciated.
Thanks!
library(tidyverse)
library(tidymodels)
## Example dataset, stored as a tibble (class "tbl_df") with 17 rows.
## Columns: `year` (1998, 2002, 2004-2018, in ascending order), a set of
## capital_*, employment_*, gdp_*, price_index_* and value_* variables with
## `_lag_1` / `_lag_2` suffixes, and `berd` (the last column).
df_ini <- structure(list(year = c(1998, 2002, 2004, 2005, 2006, 2007, 2008,
2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018),
capital_n1132g_lag_1 = c(3446.5, 4091.1, 3655.1, 3633.3,
3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6,
3718.6, 3467.9, 4214.2, 4237.4, 4450.2), capital_n117g_lag_1 = c(4920.9,
7810.6, 8560.3, 8679.9, 8938.9, 9823.8, 10467.1, 11047.1,
11554.3, 11849.9, 13465.4, 13927.5, 15510.2, 15754.4, 16584.7,
17647.1, 18273.8), capital_n11mg_lag_1 = c(16846, 19605,
19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6,
20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2,
29790.1), employment_be_lag_1 = c(2834.42, 2839.72, 2765.53,
2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1,
2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9
), employment_c_lag_1 = c(2612.76, 2623.69, 2552.89, 2518.57,
2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1,
2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48), employment_j_lag_1 = c(292.93,
389.2, 389.45, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91,
392.18, 410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4
), employment_k_lag_1 = c(505.33, 507.12, 510.25, 504.63,
515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98,
524.13, 518.89, 511.57, 505.32, 496.41), employment_mn_lag_1 = c(945.59,
1217.96, 1289.55, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76,
1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83,
2021.51, 2109.71), employment_oq_lag_1 = c(3065.87, 3191.75,
3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85,
3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59,
4171.72), employment_total_lag_1 = c(14509.58, 15127.99,
15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 16356.53,
16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 17039.6,
17142.13, 17365.32, 17650.21), gdp_b1gq_lag_1 = c(187849.7,
220525, 231862.5, 242348.3, 254075, 267824.4, 283978, 293761.9,
288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 333146.1,
344269.3, 357608, 369341.3), gdp_p3_lag_1 = c(139695.2, 161175.8,
169405.6, 176316.4, 185871.1, 194102, 200944.4, 208857.1,
213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6,
249404.3, 257166.5, 265900.2), gdp_p61_lag_1 = c(50117.6,
71948.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1,
91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3,
129183.6, 131524, 140057.8), gdp_p62_lag_1 = c(19441, 26444.4,
28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9,
39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5,
59584.7), price_index_lag_1 = c(1.2, 2.3, 1.3, 2, 2.1, 1.7,
2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 1, 2.2), value_be_lag_1 = c(40533.1,
48207.1, 48673.2, 50737.6, 52955.2, 56872.4, 60864.9, 61029,
56837.8, 58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4,
71152.6, 72698.8), value_c_lag_1 = c(33441.8, 40446.6, 40467.4,
42014.6, 44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3,
51467.7, 53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196
), value_j_lag_1 = c(5483.7, 7326.1, 7934.1, 7756.1, 8134.2,
8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1,
10361.4, 10695.4, 11455.3, 11720.6), value_k_lag_1 = c(9210.6,
9977.3, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2,
12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9,
13236.4, 13744.1), value_mn_lag_1 = c(10444, 14061.4, 15706.6,
16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2,
24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6
), value_oq_lag_1 = c(29902.7, 34179.2, 36126.8, 37329.6,
38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9,
49381.5, 50261.7, 51624.3, 53715, 55926.4, 57637.1), value_total_lag_1 = c(167323.4,
197076.7, 207247.6, 216098.3, 225888.1, 239076, 253604.6,
262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1,
307037.7, 318952.7, 329396.1), capital_n1132g_lag_2 = c(3599.2,
3996.9, 3638.4, 3655.1, 3633.3, 3616.2, 3450.7, 3596.8, 3867.2,
3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 3467.9, 4214.2, 4237.4
), capital_n117g_lag_2 = c(4636.2, 7008.5, 8369.6, 8560.3,
8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9,
13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1), capital_n11mg_lag_2 = c(17181.5,
19677.8, 18749.6, 19381.2, 19433.5, 20051.6, 20569.8, 22646.1,
23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 24057.7, 23832.8,
25019.2, 27608.2), employment_be_lag_2 = c(2870.33, 2840.19,
2775.22, 2765.53, 2731.08, 2709.59, 2708.39, 2774.06, 2795.6,
2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 2725.66, 2735.69,
2750.52), employment_c_lag_2 = c(2626.2, 2621.08, 2562.53,
2552.89, 2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97,
2447.65, 2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97
), employment_j_lag_2 = c(275.08, 374.56, 400.75, 389.45,
387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 410.85,
419.75, 427.59, 438.96, 440.33, 460.84), employment_k_lag_2 = c(500.9,
505.13, 502.42, 510.25, 504.63, 515.39, 523.45, 536.6, 550.14,
546.68, 539.96, 536.58, 534.98, 524.13, 518.89, 511.57, 505.32
), employment_mn_lag_2 = c(904.38, 1143.78, 1248.01, 1289.55,
1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 1762.55,
1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51), employment_oq_lag_2 = c(3028.85,
3162.77, 3241.36, 3280.36, 3317.09, 3401.65, 3476.63, 3508.01,
3577.75, 3683.85, 3759.23, 3798.35, 3850.17, 3877.24, 3924.06,
4002.74, 4095.59), employment_total_lag_2 = c(14404.29, 15019.87,
15113.52, 15212.11, 15307.28, 15491.61, 15762.92, 16050.92,
16356.53, 16269.97, 16392.87, 16647.79, 16820.66, 16879.06,
17039.6, 17142.13, 17365.32), gdp_b1gq_lag_2 = c(186928.7,
213606.4, 226735.3, 231862.5, 242348.3, 254075, 267824.4,
283978, 293761.9, 288044.1, 295896.6, 310128.6, 318653.1,
323910.2, 333146.1, 344269.3, 357608), gdp_p3_lag_2 = c(140335.8,
156117.3, 164107.8, 169405.6, 176316.4, 185871.1, 194102,
200944.4, 208857.1, 213630.1, 218947.2, 227250.8, 233638.1,
238329.3, 243860.6, 249404.3, 257166.5), gdp_p61_lag_2 = c(44541.4,
67701.6, 74691.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2,
113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1,
126109.3, 129183.6, 131524), gdp_p62_lag_2 = c(19504.2, 24888.9,
28063.4, 28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8,
38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8,
55885.5), value_be_lag_2 = c(40076.7, 46109.4, 47967.1, 48673.2,
50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6,
61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6), value_c_lag_2 = c(32955.4,
38908.4, 40192.9, 40467.4, 42014.6, 44229, 47735.5, 51552.4,
51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169,
57458.7, 60962.8), value_j_lag_2 = c(5576.8, 6313.9, 7737.1,
7934.1, 7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9,
9217.1, 9405.1, 9802.1, 10361.4, 10695.4, 11455.3), value_k_lag_2 = c(9191,
10458, 10225.2, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7,
13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4,
13482.9, 13236.4), value_mn_lag_2 = c(10092, 12942.5, 15074,
15706.6, 16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490,
23255.2, 24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7
), value_oq_lag_2 = c(30224.3, 33251.5, 35065.6, 36126.8,
37329.6, 38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6,
47980.9, 49381.5, 50261.7, 51624.3, 53715, 55926.4), value_total_lag_2 = c(167141.8,
190624.9, 202353.5, 207247.6, 216098.3, 225888.1, 239076,
253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3,
297230.1, 307037.7, 318952.7), berd = c(2146.085, 3130.884,
3556.479, 4207.669, 4448.676, 4845.861, 5232.63, 5092.902,
5520.422, 5692.841, 6540.457, 6778.42, 7324.679, 7498.488,
7824.51, 7888.444, 8461.72)), row.names = c(NA, -17L), class = c("tbl_df",
"tbl", "data.frame"))
set.seed(1234) ## to make the results reproducible
## Custom split: the test (assessment) set is only the last row of df_ini,
## i.e. the most recent observation; every earlier row is the training
## (analysis) set.
## see https://github.com/tidymodels/rsample/issues/158
## seq_len() is used instead of seq() so that a one-row data frame yields an
## empty analysis set rather than the indices c(1, 0).
indices <- list(
  analysis = seq_len(nrow(df_ini) - 1),
  assessment = nrow(df_ini)
)
df_split <- make_splits(indices, df_ini)
## df_split <- initial_split(df_ini) ## with the default splitting,
## ## the code works
df_train <- training(df_split)
df_test <- testing(df_split)
## 3-fold cross-validation on the training rows only; the number of folds is
## passed by name (`v`) rather than by position.
folded_data <- vfold_cv(df_train, v = 3)
## Preprocessing recipe built step by step on the training data:
## `berd` is the outcome and all remaining columns enter via `.`.
glmnet_recipe <- recipe(formula = berd ~ ., data = df_train)
## `year` is given the role "ID" instead of its default role.
glmnet_recipe <- update_role(glmnet_recipe, year, new_role = "ID")
## Zero-variance filter on the predictors, followed by normalization of the
## predictors that are not nominal.
glmnet_recipe <- step_zv(glmnet_recipe, all_predictors())
glmnet_recipe <- step_normalize(
  glmnet_recipe,
  all_predictors(), -all_nominal()
)
## Linear regression specification on the "glmnet" engine in regression mode.
## `penalty` and `mixture` are left as tune() placeholders so that their
## values can be supplied later.
glmnet_spec <- set_engine(
  set_mode(
    linear_reg(penalty = tune(), mixture = tune()),
    "regression"
  ),
  "glmnet"
)
## Workflow that bundles the preprocessing recipe with the model specification
## (recipe added first, then the model, as in the piped form).
glmnet_workflow <- add_model(
  add_recipe(workflow(), glmnet_recipe),
  glmnet_spec
)
## Candidate hyperparameters: every combination of 20 penalty values
## (10^-6 up to 10^-1, evenly spaced in the exponent) and 6 mixture values,
## i.e. 120 candidates.
glmnet_grid <- tidyr::crossing(
  penalty = 10^seq(from = -6, to = -1, length.out = 20),
  mixture = c(0.05, 0.2, 0.4, 0.6, 0.8, 1)
)
## Evaluate every candidate in glmnet_grid on the cross-validation folds;
## save_pred = TRUE asks for the per-fold predictions to be kept as well.
glmnet_tune <- glmnet_workflow %>%
  tune_grid(
    resamples = folded_data,
    grid = glmnet_grid,
    control = control_grid(save_pred = TRUE)
  )
## Metrics summarised per candidate (recorded reprex output follows).
glmnet_tune %>%
  collect_metrics() %>%
  print()
#> # A tibble: 240 x 8
#> penalty mixture .metric .estimator mean n std_err .config
#> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.000001 0.05 rmse standard 375. 3 48.9 Model001
#> 2 0.000001 0.05 rsq standard 0.929 3 0.0420 Model001
#> 3 0.00000183 0.05 rmse standard 375. 3 48.9 Model002
#> 4 0.00000183 0.05 rsq standard 0.929 3 0.0420 Model002
#> 5 0.00000336 0.05 rmse standard 375. 3 48.9 Model003
#> 6 0.00000336 0.05 rsq standard 0.929 3 0.0420 Model003
#> 7 0.00000616 0.05 rmse standard 375. 3 48.9 Model004
#> 8 0.00000616 0.05 rsq standard 0.929 3 0.0420 Model004
#> 9 0.0000113 0.05 rmse standard 375. 3 48.9 Model005
#> 10 0.0000113 0.05 rsq standard 0.929 3 0.0420 Model005
#> # … with 230 more rows
## Best candidates ranked by cross-validated RMSE (recorded reprex output
## follows). `metric` is passed by name: newer tune versions place `metric`
## after `...` in show_best()/select_best(), so a positional "rmse" is no
## longer matched to it; the named form works in old and new versions alike.
print(show_best(glmnet_tune, metric = "rmse"))
#> # A tibble: 5 x 8
#> penalty mixture .metric .estimator mean n std_err .config
#> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.000001 0.05 rmse standard 375. 3 48.9 Model001
#> 2 0.00000183 0.05 rmse standard 375. 3 48.9 Model002
#> 3 0.00000336 0.05 rmse standard 375. 3 48.9 Model003
#> 4 0.00000616 0.05 rmse standard 375. 3 48.9 Model004
#> 5 0.0000113 0.05 rmse standard 375. 3 48.9 Model005
## Keep the single best candidate by RMSE and substitute its penalty/mixture
## values for the tune() placeholders in the workflow.
best_net <- select_best(glmnet_tune, metric = "rmse")
final_net <- finalize_workflow(glmnet_workflow, best_net)
## Final fit of the finalized workflow on the custom split (training rows for
## fitting, the single held-out row for assessment).
## NOTE(review): the recorded output below shows this step failing; the script
## reports that the default initial_split() works, so the one-row assessment
## set is presumably involved. The root cause is not visible here. TODO confirm.
final_res_net <- final_net %>%
  last_fit(split = df_split)
#> x : internal: Error in data.frame(..., check.names = FALSE): arguments imply...
#> Warning: All models failed in [fit_resamples()]. See the `.notes` column.
final_res_net %>%
  print()
#> Warning: This tuning result has notes. Example notes on model fitting include:
#> internal: Error in data.frame(..., check.names = FALSE): arguments imply differing number of rows: 2, 0
#> # Resampling results
#> # Monte Carlo cross-validation (0.94/0.059) with 1 resamples
#> # A tibble: 1 x 5
#> splits id .metrics .notes .predictions
#> <list> <chr> <list> <list> <list>
#> 1 <split [16/1]> train/test split <NULL> <tibble [1 × 1]> <NULL>
## Extract the held-out predictions from the last_fit() result.
final_fit <- collect_predictions(final_res_net)
Created on 2020-10-15 by the reprex package (v0.3.0.9001)