I have built a function that makes a model spec. I wanted to try adding a class to the resulting object so that I could write generic functions for building workflows, fitting, and so on. Here is the function:
internal_make_spec_tbl <- function(.model_tbl){
  # Tidyeval ----
  model_tbl <- .model_tbl
  # Checks ----
  if (!inherits(model_tbl, "tidyaml_base_tbl")){
    rlang::abort(
      message = "The model tibble must come from the make base tbl function.",
      use_cli_format = TRUE
    )
  }
  # Manipulation: add a factor id column and move it to the front
  model_factor_tbl <- model_tbl |>
    dplyr::mutate(.model_id = dplyr::row_number() |>
                    forcats::as_factor()) |>
    dplyr::select(.model_id, dplyr::everything())
  # Make a group split object list (one single-row tibble per model)
  models_list <- model_factor_tbl |>
    dplyr::group_split(.model_id)
  # Make the model spec objects using purrr imap
  model_spec <- models_list |>
    purrr::imap(
      .f = function(obj, id){
        # Pull the engine (2), mode (3) and function (4) columns and pluck the value
        pe <- obj |> dplyr::pull(2) |> purrr::pluck(1)
        pm <- obj |> dplyr::pull(3) |> purrr::pluck(1)
        pf <- obj |> dplyr::pull(4) |> purrr::pluck(1)
        ret <- match.fun(pf)(mode = pm, engine = pe)
        # Build the "<engine>_<function>" label, e.g. "gee_linear_reg"
        mod_class <- paste0(base::tolower(pe), "_", base::tolower(pf))
        # Append the label as the last class and also store it as an attribute
        class(ret) <- c(class(ret), mod_class)
        attr(ret, ".tidyaml_mod_class") <- mod_class
        # Return the result
        return(ret)
      }
    )
  # Return
  # Make sure to return as a tibble
  model_spec_ret <- model_factor_tbl |>
    dplyr::mutate(model_spec = model_spec) |>
    dplyr::mutate(.model_id = as.integer(.model_id))
  return(model_spec_ret)
}
When I append a class to the model spec (as the last element of its class vector), fitting the workflow fails, and I get a grep() warning that I don't understand: argument 'pattern' has length > 1.
Any ideas why? I wanted to do this because some models, like the gee engine from multilevelmod, require the workflow to be built differently, as described in the parsnip documentation: https://parsnip.tidymodels.org/reference/details_logistic_reg_gee.html#other-details. So I wanted to try method dispatch, with something like internal_make_wflw_tbl.gee_linear_reg.
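To make the dispatch idea concrete, here is a minimal base R sketch of what I have in mind. The generic name is borrowed from above, but the generic, its methods, and their return values are all placeholders (nothing here exists in tidyAML), and the spec is a bare stand-in rather than a real parsnip object.

```r
# Hypothetical generic: dispatches on the class vector of the model spec
internal_make_wflw_tbl <- function(spec, ...) {
  UseMethod("internal_make_wflw_tbl")
}

# Fallback for specs that need no special handling (placeholder body)
internal_make_wflw_tbl.default <- function(spec, ...) {
  "standard workflow"
}

# Special case for the gee engine (placeholder body)
internal_make_wflw_tbl.gee_linear_reg <- function(spec, ...) {
  "gee workflow"
}

# Stand-ins for model specs; real parsnip specs carry more structure
plain_spec <- structure(list(), class = c("linear_reg", "model_spec"))
gee_spec   <- structure(
  list(),
  class = c("linear_reg", "model_spec", "gee_linear_reg")
)

plain_result <- internal_make_wflw_tbl(plain_spec)
gee_result   <- internal_make_wflw_tbl(gee_spec)
print(plain_result)
print(gee_result)
```

Because the extra class sits last, S3 only reaches the gee method when no method exists for the classes before it, which is the case in this sketch.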
Here is an example: first how it currently works, and then how it fails.
library(tidyAML)
library(dplyr)
library(recipes)
# How it currently works
mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
  .parsnip_eng = c("lm","glm","gee"),
  .parsnip_fns = "linear_reg"
)
rec_obj <- recipe(mpg ~ ., data = mtcars)
splits_obj <- create_splits(mtcars, "initial_split")
mod_tbl <- mod_spec_tbl |>
  mutate(wflw = internal_make_wflw(mod_spec_tbl, rec_obj))
Error in `.f()`:
! parsnip could not locate an implementation for `linear_reg` regression model specifications
using the `gee` engine.
ℹ The parsnip extension package multilevelmod implements support for this specification.
ℹ Please install (if needed) and load to continue.
internal_make_fitted_wflw(mod_tbl, splits_obj)
Error in UseMethod("fit"): no applicable method for 'fit' applied to an object of class "NULL"
[[1]]
══ Workflow [trained] ═════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()
── Preprocessor ───────────────────────────────────────────────────────────────────────────────────────
0 Recipe Steps
── Model ──────────────────────────────────────────────────────────────────────────────────────────────
Call:
stats::lm(formula = ..y ~ ., data = data)
Coefficients:
(Intercept) cyl disp hp drat wt qsec
-9.56517 0.29867 0.03116 -0.01767 2.84035 -3.41110 0.86809
vs am gear carb
2.58991 4.01435 2.23389 -0.84017
[[2]]
NULL
[[3]]
══ Workflow [trained] ═════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()
── Preprocessor ───────────────────────────────────────────────────────────────────────────────────────
0 Recipe Steps
── Model ──────────────────────────────────────────────────────────────────────────────────────────────
Call: stats::glm(formula = ..y ~ ., family = stats::gaussian, data = data)
Coefficients:
(Intercept) cyl disp hp drat wt qsec
-9.56517 0.29867 0.03116 -0.01767 2.84035 -3.41110 0.86809
vs am gear carb
2.58991 4.01435 2.23389 -0.84017
Degrees of Freedom: 23 Total (i.e. Null); 13 Residual
Null Deviance: 972.7
Residual Deviance: 83.84 AIC: 122.1
The error above is expected: I have not loaded multilevelmod in my session, so the gee model fails while the other two fit.
Now the failure when a class is added to ret in the function above:
mod_tbl <- make_regression_base_tbl()
mod_tbl <- mod_tbl |>
  filter(
    .parsnip_engine %in% c("lm", "glm", "gee") &
      .parsnip_fns == "linear_reg"
  )
class(mod_tbl) <- c("tidyaml_mod_spec_tbl", class(mod_tbl))
mod_spec_tbl <- internal_make_spec_tblv2(mod_tbl)
mod_wflw_tbl <- mod_spec_tbl |>
  mutate(wflw = internal_make_wflw(mod_spec_tbl, rec_obj))
Error in `.f()`:
! parsnip could not locate an implementation for `linear_reg` regression model specifications
using the `gee` engine.
ℹ The parsnip extension package multilevelmod implements support for this specification.
ℹ Please install (if needed) and load to continue.
internal_make_fitted_wflw(mod_wflw_tbl, splits_obj)
Error in `rlang::env_get()`:
! `nm` must be a string.
Error in UseMethod("fit"): no applicable method for 'fit' applied to an object of class "NULL"
Error in `rlang::env_get()`:
! `nm` must be a string.
[[1]]
NULL
[[2]]
NULL
[[3]]
NULL
Warning messages:
1: In grep(model, env_obj, value = TRUE) :
argument 'pattern' has length > 1 and only the first element will be used
2: In grep(model, env_obj, value = TRUE) :
argument 'pattern' has length > 1 and only the first element will be used
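For what it's worth, the same warning can be reproduced in base R without parsnip. This sketch assumes, without my having verified it in the parsnip source, that the fitting code derives a lookup key from every class other than "model_spec". The helper key_of() and the registry vector are made up for illustration.

```r
# Stand-in for a parsnip model spec; the real object has more fields
spec <- structure(list(), class = c("linear_reg", "model_spec"))

# Assumed (unverified) way a lookup key could be derived from the classes
key_of <- function(x) setdiff(class(x), "model_spec")

key_before <- key_of(spec)

# Append the extra class in the last position, as in my function
class(spec) <- c(class(spec), "gee_linear_reg")
key_after <- key_of(spec)

print(length(key_before))  # 1: a single string
print(length(key_after))   # 2: no longer a single string

# Made-up registry of names to search
registry <- c("linear_reg_fit", "linear_reg_predict")

# A pattern of length 2 makes grep() warn and use only the first element;
# the handler prints the warning text and lets grep() finish
res <- withCallingHandlers(
  grep(key_after, registry, value = TRUE),
  warning = function(w) {
    message("warning: ", conditionMessage(w))
    invokeRestart("muffleWarning")
  }
)
print(res)
```

If that assumption holds, a key of length 2 would also explain the `nm` must be a string error, since a length-2 character vector is not a single string.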