
This function is a wrapper to fit many different types of linear classification models on a (grouped) tibble.



Arguments

.data
A data frame, data frame extension (e.g. a tibble), or a lazy data frame (e.g. from dbplyr or dtplyr). The data frame can be grouped.


formula
An object of class "formula": a symbolic description of the model to be fitted.


...
Name-function pairs of models to be estimated. See 'Details'.


.cv
Type of 'rsample' cross validation procedure used to determine optimal hyperparameter values. Default is .cv = "none". See 'Details'.


.cv_args
Additional settings to pass to the 'rsample' cross validation function.


.weights
Optional name of a column containing sample weights.


.mask
Optional vector of column names to ignore. Can be useful when using the 'y ~ .' formula syntax.


.return_slices
Logical. Should the output of individual cross validation slices be returned, or only the final fit? Default is .return_slices = FALSE.


.tune_each_group
Logical. Should optimal hyperparameters be selected for each group, or once across all groups? Default is .tune_each_group = TRUE.


.force_cv
Logical. Should models be evaluated across all cross validation slices, even if no hyperparameters are tuned? Default is .force_cv = TRUE.


Value

A 'tidyfit.models' frame containing model details for each group.

The 'tidyfit.models' frame consists of four components:

  1. A group of identifying columns (e.g. model name, data groups, grid IDs).

  2. A 'model_object' column, which contains the fitted model.

  3. A nested 'settings' column containing model arguments and hyperparameters.

  4. Columns showing errors, warnings and messages (if applicable).

Coefficients, predictions, fitted values or residuals can be accessed using the built-in coef, predict, fitted and resid methods. Note that all coefficients are transformed to ensure comparability across methods.


Details

classify() fits all models passed in the ... argument using the m() function. The models can be passed as name-function pairs (e.g. lasso = m("lasso")) or without a name.

Hyperparameters are either tuned automatically using the '.cv' and '.cv_args' arguments or passed directly to m() (e.g. lasso = m("lasso", lambda = 0.5)). See the individual model functions (?m) for an overview of hyperparameters.
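The call below sketches how named models, fixed hyperparameters and cross-validated tuning combine in one classify() call. It assumes tidyfit, dplyr and glmnet are installed; the labels lasso_fixed and lasso_tuned are hypothetical names, and the lambda values and v = 5 are arbitrary choices for illustration.

```r
library(tidyfit)

# Binary response: 1 if the industry return is positive, 0 otherwise
data <- tidyfit::Factor_Industry_Returns
data <- dplyr::mutate(data, Return = ifelse(Return > 0, 1, 0))

set.seed(42)  # vfold_cv assigns rows to folds at random
fit <- classify(
  data, Return ~ .,
  # Named model with a single, fixed hyperparameter value
  lasso_fixed = m("lasso", lambda = 0.01),
  # Named model with several candidate values, compared via cross validation
  lasso_tuned = m("lasso", lambda = c(0.001, 0.01, 0.1)),
  .cv = "vfold_cv", .cv_args = list(v = 5),  # 5-fold CV via rsample::vfold_cv
  .mask = c("Date", "Industry")              # exclude non-predictor columns
)

coef(fit)
```

The names given to each m() call are used to label the models in the output frame, which makes it easier to compare several configurations of the same method side by side.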

Cross validation is performed using the 'rsample' package, with possible methods including:

  • 'initial_split' (simple train-test split)

  • 'initial_time_split' (train-test split with retained order)

  • 'vfold_cv' (also known as k-fold cross validation)

  • 'loo_cv' (leave-one-out)

  • 'rolling_origin' (generalized time series cross validation, e.g. rolling or expanding windows)

  • 'sliding_window', 'sliding_index', 'sliding_period' (specialized time series splits)

  • 'bootstraps'

  • 'group_vfold_cv', 'group_bootstraps'

See the 'rsample' package documentation for all available methods.
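To make the difference between rolling and expanding windows in 'rolling_origin' concrete, here is a minimal base-R sketch of index-based splits. The helper rolling_splits() is hypothetical: it mirrors the documented meaning of the initial, assess and cumulative arguments of rsample::rolling_origin(), ignores skip and lag, and is not rsample's implementation.

```r
# Hypothetical helper, not part of rsample or tidyfit.
# Builds index-based rolling-origin splits for n time-ordered rows.
#   initial:    number of rows in the first analysis set
#   assess:     number of rows in each assessment set
#   cumulative: TRUE = expanding window, FALSE = rolling window of fixed size
rolling_splits <- function(n, initial, assess, cumulative = TRUE) {
  stopifnot(n >= initial + assess)
  origins <- seq(initial, n - assess)  # last analysis row of each split
  lapply(origins, function(end) {
    start <- if (cumulative) 1 else end - initial + 1
    list(analysis = start:end, assessment = (end + 1):(end + assess))
  })
}

expanding <- rolling_splits(n = 10, initial = 6, assess = 2, cumulative = TRUE)
rolling   <- rolling_splits(n = 10, initial = 6, assess = 2, cumulative = FALSE)

length(expanding)        # 3 splits
expanding[[3]]$analysis  # 1 2 3 4 5 6 7 8 (window grows)
rolling[[3]]$analysis    # 3 4 5 6 7 8 (window slides, size stays 6)
rolling[[3]]$assessment  # 9 10
```

In both variants the assessment rows always come after the analysis rows, which is what makes these splits suitable for time series, unlike the random fold assignment of 'vfold_cv'.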

The negative log loss is used to evaluate performance during cross validation.

Setting the '.weights' argument automatically passes the weights to the underlying model functions. Weights are also taken into account during cross validation, where a weighted version of the loss function is calculated.
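One common way to compute a weighted log loss is the weighted mean of -log(probability assigned to the observed class); the base-R sketch below shows this for both binary and multinomial responses. The function weighted_log_loss() is hypothetical and is not tidyfit's internal code; the exact weighting scheme, and whether tidyfit reports this value or its negation as a score to maximize, are assumptions here.

```r
# Hypothetical helper: weighted mean of -log(p_observed_class).
# Assumption: tidyfit's internal weighting may differ in detail.
#   prob:  n x K matrix of predicted class probabilities (each row sums to 1)
#   truth: integer vector of observed class indices in 1..K
#   w:     non-negative sample weights (equal weights by default)
weighted_log_loss <- function(prob, truth, w = rep(1, length(truth)), eps = 1e-15) {
  p_obs <- prob[cbind(seq_along(truth), truth)]  # probability of the observed class
  p_obs <- pmax(p_obs, eps)                      # guard against log(0)
  sum(w * -log(p_obs)) / sum(w)
}

# Binary case: columns are P(Return = 0) and P(Return = 1)
p1    <- c(0.9, 0.2, 0.6)
prob  <- cbind(1 - p1, p1)
truth <- c(1, 0, 1) + 1                          # map 0/1 outcomes to columns 1/2

weighted_log_loss(prob, truth)                   # about 0.280 (equal weights)
weighted_log_loss(prob, truth, w = c(1, 1, 2))   # about 0.338 (third row counts double)
```

The same function handles a multinomial response unchanged: prob simply has one column per class, and truth indexes the observed column.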

classify() can handle both binomial and multinomial response distributions; however, not all underlying methods support a multinomial response.


Author

Johann Pfitzinger


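Examples

library(tidyfit)

# Binary response: 1 if the industry return is positive, 0 otherwise
data <- tidyfit::Factor_Industry_Returns
data <- dplyr::mutate(data, Return = ifelse(Return > 0, 1, 0))

# Fit a lasso classifier for two lambda values, ignoring the Date and Industry columns
fit <- classify(data, Return ~ ., m("lasso", lambda = c(0.001, 0.1)), .mask = c("Date", "Industry"))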

# Print the models frame
tidyr::unnest(fit, settings)
#> # A tibble: 2 × 9
#>   model estimator_fct  `size (MB)` grid_id  model_object weights alpha
#>   <chr> <chr>                <dbl> <chr>    <list>       <list>  <dbl>
#> 1 lasso glmnet::glmnet        2.04 #001|001 <tidyFit>    <NULL>      1
#> 2 lasso glmnet::glmnet        2.04 #001|002 <tidyFit>    <NULL>      1
#> # ℹ 2 more variables: family <chr>, lambda <dbl>

# View coefficients
coef(fit)
#> # A tibble: 9 × 5
#> # Groups:   model [1]
#>   model term        estimate grid_id  model_info      
#>   <chr> <chr>          <dbl> <chr>    <list>          
#> 1 lasso (Intercept)   0.310  #001|001 <tibble [1 × 2]>
#> 2 lasso Mkt-RF        0.230  #001|001 <tibble [1 × 2]>
#> 3 lasso (Intercept)   0.0574 #001|002 <tibble [1 × 2]>
#> 4 lasso Mkt-RF        0.574  #001|002 <tibble [1 × 2]>
#> 5 lasso SMB           0.0167 #001|002 <tibble [1 × 2]>
#> 6 lasso HML           0.0347 #001|002 <tibble [1 × 2]>
#> 7 lasso RMW           0.146  #001|002 <tibble [1 × 2]>
#> 8 lasso CMA           0.0963 #001|002 <tibble [1 × 2]>
#> 9 lasso RF            0.295  #001|002 <tibble [1 × 2]>