Error in parsnip fit after adding a class to a model specification


I have built a function that makes a model spec. I wanted to try adding a class to the resulting object so that I could write generic functions (with S3 methods) for building workflows, fitting, and so on. Here is the function:

internal_make_spec_tbl <- function(.model_tbl){

  # Tidyeval ----
  model_tbl <- .model_tbl

  # Checks ----
  if (!inherits(model_tbl, "tidyaml_base_tbl")){
    rlang::abort(
      message = "The model tibble must come from the make base tbl function.",
      use_cli_format = TRUE
    )
  }

  # Manipulation
  model_factor_tbl <- model_tbl |>
    dplyr::mutate(.model_id = dplyr::row_number() |>
                    forcats::as_factor()) |>
    dplyr::select(.model_id, dplyr::everything())

  # Make a group split object list
  models_list <- model_factor_tbl |>
    dplyr::group_split(.model_id)

  # Make the model spec objects using purrr imap
  model_spec <- models_list |>
    purrr::imap(
      .f = function(obj, id){

        # Pull the engine, mode, and parsnip function name (columns 2 to 4)
        pe <- obj |> dplyr::pull(2) |> purrr::pluck(1)
        pm <- obj |> dplyr::pull(3) |> purrr::pluck(1)
        pf <- obj |> dplyr::pull(4) |> purrr::pluck(1)

        ret <- match.fun(pf)(mode = pm, engine = pe)

        # Add parsnip engine and fns as class (appended in the last position)
        mod_class <- paste0(base::tolower(pe), "_", base::tolower(pf))
        class(ret) <- c(class(ret), mod_class)

        # Store the same label as an attribute, then return the result
        attr(ret, ".tidyaml_mod_class") <- mod_class
        return(ret)
      }
    )

  # Return
  # Make sure to return as a tibble
  model_spec_ret <- model_factor_tbl |>
    dplyr::mutate(model_spec = model_spec)  |>
    dplyr::mutate(.model_id = as.integer(.model_id))

  return(model_spec_ret)
}

When I append a class to the model spec (in the last position of its class vector), fitting the workflow fails, and I also get a grep() warning that argument 'pattern' has length > 1, which I do not understand.
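To make the warning concrete, here is a minimal base R sketch that does not use parsnip at all. It assumes (this is a guess, not something confirmed above) that the model name passed to grep() is derived from the spec's class vector by dropping "model_spec". The stand-in `spec` object and the `env_obj` names are hypothetical.

```r
# Stand-in for a parsnip model spec (hypothetical; no parsnip needed here).
spec <- structure(list(), class = c("linear_reg", "model_spec"))

# Append a custom class in the last position, as the function above does.
class(spec) <- c(class(spec), "gee_linear_reg")

# ASSUMPTION: the model name is derived by dropping "model_spec" from the
# class vector. With the extra class this yields two strings instead of one.
model <- setdiff(class(spec), "model_spec")

# Hypothetical registry names to search through.
env_obj <- c("linear_reg_fit", "linear_reg_modes", "logistic_reg_fit")

# grep() accepts a single pattern; capture the warning it raises when
# it is handed a character vector of length 2.
grep_warning <- tryCatch(
  grep(model, env_obj, value = TRUE),
  warning = function(w) conditionMessage(w)
)

print(model)
print(grep_warning)
```

If that assumption holds, any code that expects a single model name from the class vector would receive two names once the extra class is present, which would be consistent with both the grep() warning and the "`nm` must be a string" error shown further down.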

Any ideas why? I wanted to do this because some models, such as gee from multilevelmod, require the workflow to be built differently, as described in the parsnip documentation: https://parsnip.tidymodels.org/reference/details_logistic_reg_gee.html#other-details. So I wanted to try method dispatch, with methods like internal_make_wflw_tbl.gee_linear_reg or something similar.
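One possible direction, sketched here in base R only, is to dispatch on the `.tidyaml_mod_class` attribute that the function above already sets, leaving the spec's class vector untouched. The generic `make_wflw` and its methods are hypothetical names for illustration, not tidyAML or parsnip functions, and the stand-in specs are plain lists.

```r
# Hypothetical generic: dispatch on the stored attribute, not on class(spec).
# UseMethod() picks the method from the class of its second argument, while
# the method itself still receives the original `spec`.
make_wflw <- function(spec, ...) {
  key <- attr(spec, ".tidyaml_mod_class", exact = TRUE)
  UseMethod("make_wflw", structure(list(), class = key))
}

# Hypothetical method for the gee engine, where the workflow is built differently.
make_wflw.gee_linear_reg <- function(spec, ...) {
  "gee-specific workflow"
}

# Fallback for every other engine/function combination.
make_wflw.default <- function(spec, ...) {
  "standard workflow"
}

# Stand-ins for model specs; the class vector is left as parsnip would set it.
spec_gee <- structure(
  list(),
  class = c("linear_reg", "model_spec"),
  .tidyaml_mod_class = "gee_linear_reg"
)
spec_lm <- structure(
  list(),
  class = c("linear_reg", "model_spec"),
  .tidyaml_mod_class = "lm_linear_reg"
)

print(make_wflw(spec_gee))
print(make_wflw(spec_lm))
print(class(spec_gee))
```

The design choice in this sketch is that nothing downstream sees a modified class vector, so code that inspects `class(spec)` keeps working, while the custom label still drives method selection.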

Here is an example: first how it currently works, then how it fails.

library(tidyAML)
library(dplyr)
library(recipes)

# How it currently works
mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
 .parsnip_eng = c("lm","glm","gee"),
 .parsnip_fns = "linear_reg"
)

rec_obj <- recipe(mpg ~ ., data = mtcars)
splits_obj <- create_splits(mtcars, "initial_split")

mod_tbl <- mod_spec_tbl |>
 mutate(wflw = internal_make_wflw(mod_spec_tbl, rec_obj))
Error in `.f()`:
! parsnip could not locate an implementation for `linear_reg` regression model specifications
  using the `gee` engine.
ℹ The parsnip extension package multilevelmod implements support for this specification.
ℹ Please install (if needed) and load to continue.


internal_make_fitted_wflw(mod_tbl, splits_obj)
Error in UseMethod("fit"): no applicable method for 'fit' applied to an object of class "NULL"

[[1]]
══ Workflow [trained] ═════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ───────────────────────────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ──────────────────────────────────────────────────────────────────────────────────────────────

Call:
stats::lm(formula = ..y ~ ., data = data)

Coefficients:
(Intercept)          cyl         disp           hp         drat           wt         qsec  
   -9.56517      0.29867      0.03116     -0.01767      2.84035     -3.41110      0.86809  
         vs           am         gear         carb  
    2.58991      4.01435      2.23389     -0.84017  


[[2]]
NULL

[[3]]
══ Workflow [trained] ═════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ───────────────────────────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ──────────────────────────────────────────────────────────────────────────────────────────────

Call:  stats::glm(formula = ..y ~ ., family = stats::gaussian, data = data)

Coefficients:
(Intercept)          cyl         disp           hp         drat           wt         qsec  
   -9.56517      0.29867      0.03116     -0.01767      2.84035     -3.41110      0.86809  
         vs           am         gear         carb  
    2.58991      4.01435      2.23389     -0.84017  

Degrees of Freedom: 23 Total (i.e. Null);  13 Residual
Null Deviance:      972.7 
Residual Deviance: 83.84    AIC: 122.1

The above error is fine: I have not loaded multilevelmod into my session, so the gee model fails as I expect it to.

Now the failure, when a class is added to ret in the above function:

mod_tbl <- make_regression_base_tbl()
mod_tbl <- mod_tbl |>
  filter(
    .parsnip_engine %in% c("lm", "glm", "gee") &
      .parsnip_fns == "linear_reg"
  )

class(mod_tbl) <- c("tidyaml_mod_spec_tbl", class(mod_tbl))

mod_spec_tbl <- internal_make_spec_tblv2(mod_tbl)
mod_wflw_tbl <- mod_spec_tbl |>
  mutate(wflw = internal_make_wflw(mod_spec_tbl, rec_obj))
Error in `.f()`:
! parsnip could not locate an implementation for `linear_reg` regression model specifications
  using the `gee` engine.
ℹ The parsnip extension package multilevelmod implements support for this specification.
ℹ Please install (if needed) and load to continue.


internal_make_fitted_wflw(mod_wflw_tbl, splits_obj)
Error in `rlang::env_get()`:
! `nm` must be a string.

Error in UseMethod("fit"): no applicable method for 'fit' applied to an object of class "NULL"

Error in `rlang::env_get()`:
! `nm` must be a string.

[[1]]
NULL

[[2]]
NULL

[[3]]
NULL

Warning messages:
1: In grep(model, env_obj, value = TRUE) :
  argument 'pattern' has length > 1 and only the first element will be used
2: In grep(model, env_obj, value = TRUE) :
  argument 'pattern' has length > 1 and only the first element will be used