The goal of bonsai is to provide bindings for additional tree-based model engines for use with the {parsnip} package.
If you’re not familiar with parsnip, you can read more about the package on its website.
To get started, load bonsai with:
library(bonsai)
#> Loading required package: parsnip
To illustrate how to use the package, we’ll fit some models to a dataset containing measurements on 3 different species of penguins. Loading in that data and checking it out:
library(modeldata)
data(penguins)
str(penguins)
#> tibble [344 × 7] (S3: tbl_df/tbl/data.frame)
#> $ species : Factor w/ 3 levels "Adelie","Chinstrap",..: 1 1 1 1 1 1 1 1 1 1 ...
#> $ island : Factor w/ 3 levels "Biscoe","Dream",..: 3 3 3 3 3 3 3 3 3 3 ...
#> $ bill_length_mm : num [1:344] 39.1 39.5 40.3 NA 36.7 39.3 38.9 39.2 34.1 42 ...
#> $ bill_depth_mm : num [1:344] 18.7 17.4 18 NA 19.3 20.6 17.8 19.6 18.1 20.2 ...
#> $ flipper_length_mm: int [1:344] 181 186 195 NA 193 190 181 195 193 190 ...
#> $ body_mass_g : int [1:344] 3750 3800 3250 NA 3450 3650 3625 4675 3475 4250 ...
#> $ sex : Factor w/ 2 levels "female","male": 2 1 1 NA 1 2 1 2 NA NA ...
Specifically, making use of our knowledge of which island they live on and measurements of their flipper length, we will predict their species using a decision tree. We’ll first do so using the "rpart" engine, which is supported with parsnip alone:
# set seed for reproducibility
set.seed(1)
# specify and fit model
dt_mod <-
  decision_tree() %>%
  set_engine(engine = "rpart") %>%
  set_mode(mode = "classification") %>%
  fit(
    formula = species ~ flipper_length_mm + island,
    data = penguins
  )
dt_mod
#> parsnip model object
#>
#> n= 344
#>
#> node), split, n, loss, yval, (yprob)
#> * denotes terminal node
#>
#> 1) root 344 192 Adelie (0.441860465 0.197674419 0.360465116)
#> 2) flipper_length_mm< 206.5 214 64 Adelie (0.700934579 0.294392523 0.004672897)
#> 4) island=Biscoe,Torgersen 96 1 Adelie (0.989583333 0.000000000 0.010416667) *
#> 5) island=Dream 118 55 Chinstrap (0.466101695 0.533898305 0.000000000)
#> 10) flipper_length_mm< 192.5 59 20 Adelie (0.661016949 0.338983051 0.000000000) *
#> 11) flipper_length_mm>=192.5 59 16 Chinstrap (0.271186441 0.728813559 0.000000000) *
#> 3) flipper_length_mm>=206.5 130 7 Gentoo (0.015384615 0.038461538 0.946153846)
#> 6) island=Dream,Torgersen 7 2 Chinstrap (0.285714286 0.714285714 0.000000000) *
#> 7) island=Biscoe 123 0 Gentoo (0.000000000 0.000000000 1.000000000) *
From this output, we can see that the model first splits on flipper length to separate the species, and then makes use of a mix of flipper length and island to ultimately make a species prediction.
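To generate predictions from this fitted model, we can use parsnip’s predict() method as usual. A minimal sketch, predicting species for a few arbitrarily chosen penguins:
# predict species for a few penguins (rows chosen only for illustration)
predict(dt_mod, new_data = penguins[c(1, 153, 344), ])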
A benefit of using parsnip and bonsai is that, to use a different implementation of decision trees, we simply change the engine argument passed to set_engine(); all other elements of the interface stay the same. For instance, to instead use "partykit", which implements a type of decision tree called a conditional inference tree, as our backend:
decision_tree() %>%
  set_engine(engine = "partykit") %>%
  set_mode(mode = "classification") %>%
  fit(
    formula = species ~ flipper_length_mm + island,
    data = penguins
  )
#> parsnip model object
#>
#>
#> Model formula:
#> species ~ flipper_length_mm + island
#>
#> Fitted party:
#> [1] root
#> | [2] island in Biscoe
#> | | [3] flipper_length_mm <= 203
#> | | | [4] flipper_length_mm <= 196: Adelie (n = 38, err = 0.0%)
#> | | | [5] flipper_length_mm > 196: Adelie (n = 7, err = 14.3%)
#> | | [6] flipper_length_mm > 203: Gentoo (n = 123, err = 0.0%)
#> | [7] island in Dream, Torgersen
#> | | [8] island in Dream
#> | | | [9] flipper_length_mm <= 192: Adelie (n = 59, err = 33.9%)
#> | | | [10] flipper_length_mm > 192: Chinstrap (n = 65, err = 26.2%)
#> | | [11] island in Torgersen: Adelie (n = 52, err = 0.0%)
#>
#> Number of inner nodes: 5
#> Number of terminal nodes: 6
This model, unlike the first, relies on recursive conditional inference to generate its splits. As such, we can see that it generates slightly different results. Read more about this implementation of decision trees in ?details_decision_tree_partykit.
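If you’d like to inspect the engine’s own model object rather than the parsnip wrapper, parsnip provides extract_fit_engine(). A quick sketch, using the rpart fit we assigned earlier:
# pull the underlying rpart object out of the parsnip model
extract_fit_engine(dt_mod)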
One generalization of a decision tree is a random forest, which fits a large number of decision trees, each independently of the others. The fitted random forest model combines predictions from the individual decision trees to generate its predictions.
bonsai introduces support for random forests using the partykit engine, which implements an algorithm called a conditional random forest. Conditional random forests are a type of random forest that uses conditional inference trees (like the one we fit above!) for its constituent decision trees.
To fit a conditional random forest with partykit, our code looks pretty similar to what we needed to fit a conditional inference tree. Just switch out decision_tree() with rand_forest() and remember to keep the engine set as "partykit":
rf_mod <-
  rand_forest() %>%
  set_engine(engine = "partykit") %>%
  set_mode(mode = "classification") %>%
  fit(
    formula = species ~ flipper_length_mm + island,
    data = penguins
  )
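As with any parsnip model, we can generate hard class predictions or class probabilities from the fitted forest. A minimal sketch; na.omit() appears here only because we haven’t otherwise handled the missing values in penguins:
# predicted class probabilities for a few complete rows
predict(rf_mod, new_data = na.omit(penguins)[1:5, ], type = "prob")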
Read more about this implementation of random forests in ?details_rand_forest_partykit.
Another generalization of a decision tree is a series of decision trees where each tree depends on the results of previous trees; this is called a boosted tree. bonsai implements an additional parsnip engine for this model type called "lightgbm". To make use of it, start out with a boost_tree() model spec and set engine = "lightgbm":
bt_mod <-
  boost_tree() %>%
  set_engine(engine = "lightgbm") %>%
  set_mode(mode = "classification") %>%
  fit(
    formula = species ~ flipper_length_mm + island,
    data = penguins
  )
bt_mod
#> parsnip model object
#>
#> <lgb.Booster>
#> Public:
#> add_valid: function (data, name)
#> best_iter: -1
#> best_score: NA
#> current_iter: function ()
#> dump_model: function (num_iteration = NULL, feature_importance_type = 0L)
#> eval: function (data, name, feval = NULL)
#> eval_train: function (feval = NULL)
#> eval_valid: function (feval = NULL)
#> finalize: function ()
#> initialize: function (params = list(), train_set = NULL, modelfile = NULL,
#> lower_bound: function ()
#> params: list
#> predict: function (data, start_iteration = NULL, num_iteration = NULL,
#> raw: NA
#> record_evals: list
#> reset_parameter: function (params, ...)
#> rollback_one_iter: function ()
#> save: function ()
#> save_model: function (filename, num_iteration = NULL, feature_importance_type = 0L)
#> save_model_to_string: function (num_iteration = NULL, feature_importance_type = 0L)
#> set_train_data_name: function (name)
#> to_predictor: function ()
#> update: function (train_set = NULL, fobj = NULL)
#> upper_bound: function ()
#> Private:
#> eval_names: NULL
#> get_eval_info: function ()
#> handle: lgb.Booster.handle
#> higher_better_inner_eval: NULL
#> init_predictor: NULL
#> inner_eval: function (data_name, data_idx, feval = NULL)
#> inner_predict: function (idx)
#> is_predicted_cur_iter: list
#> name_train_set: training
#> name_valid_sets: list
#> num_class: 3
#> num_dataset: 1
#> predict_buffer: list
#> set_objective_to_none: FALSE
#> train_set: lgb.Dataset, R6
#> train_set_version: 1
#> valid_sets: list
Read more about this implementation of boosted trees in ?details_boost_tree_lightgbm.
Each of these model specs and engines has several arguments and tuning parameters that can greatly affect the user experience and results. We recommend reading about each of these parameters and tuning them when you find them relevant for your modeling use case.
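As a sketch of what that can look like, main arguments such as the number of trees are set in the model spec, while engine-specific parameters (here, lightgbm’s num_leaves) are passed through set_engine(). The values below are illustrative, not recommendations:
# illustrative (not tuned!) main and engine-specific arguments
boost_tree(trees = 500, tree_depth = 3) %>%
  set_engine(engine = "lightgbm", num_leaves = 7) %>%
  set_mode(mode = "classification") %>%
  fit(
    formula = species ~ flipper_length_mm + island,
    data = penguins
  )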