PD and ICE curves with ORSF

Partial dependence (PD)

Partial dependence (PD) shows the expected prediction from a model as a function of a single predictor or multiple predictors. The expectation is marginalized over the values of all other predictors, giving something like a multivariable adjusted estimate of the model’s prediction.
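Here is a minimal, model-agnostic sketch of how PD can be computed. For each grid value, set the predictor of interest to that value for every row, predict, and average the predictions. The logistic glm on simulated data and the helper partial_dependence() are illustrative assumptions and are not part of aorsf.

```r
# Minimal, model-agnostic sketch of partial dependence (PD).
# Assumption: a logistic glm on simulated data stands in for the ORSF ensemble.
set.seed(1)
n <- 200
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$y <- rbinom(n, size = 1, prob = plogis(-0.5 + 1.2 * sim$x1 + 0.4 * sim$x2))
model <- glm(y ~ x1 + x2, family = binomial, data = sim)

# Hypothetical helper (not part of aorsf): PD for one variable over a grid.
partial_dependence <- function(model, data, variable, grid) {
  means <- vapply(grid, function(value) {
    data_mod <- data
    data_mod[[variable]] <- value   # every row gets the same value of 'variable'
    preds <- predict(model, newdata = data_mod, type = "response")
    mean(preds)                     # average over the observed values of the other predictors
  }, numeric(1))
  data.frame(value = grid, mean = means)
}

pd_x1 <- partial_dependence(model, sim, "x1", grid = c(-1, 0, 1))
pd_x1
```

Every row keeps its own value of x2. The average therefore marginalizes over the observed distribution of the other predictors, and this is what gives PD its "adjusted" flavor.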

We begin by fitting an ORSF ensemble. We set oobag_pred_horizon to 5 years (expressed in days, the unit of time) when fitting the ensemble. As a result, any aorsf function we pass this ensemble to will compute predictions at 5 years unless told otherwise.


library(aorsf)

pred_horizon <- 365.25 * 5

set.seed(329730)

index_train <- sample(nrow(pbc_orsf), 150) 

pbc_orsf_train <- pbc_orsf[index_train, ]
pbc_orsf_test <- pbc_orsf[-index_train, ]

fit <- orsf(data = pbc_orsf_train, 
            formula = Surv(time, status) ~ . - id,
            oobag_pred_horizon = pred_horizon)

fit
#> ---------- Oblique random survival forest
#> 
#>      Linear combinations: Accelerated
#>           N observations: 150
#>                 N events: 52
#>                  N trees: 500
#>       N predictors total: 17
#>    N predictors per node: 5
#>  Average leaves per tree: 12
#> Min observations in leaf: 5
#>       Min events in leaf: 1
#>           OOB stat value: 0.83
#>            OOB stat type: Harrell's C-statistic
#>      Variable importance: anova
#> 
#> -----------------------------------------

Three ways to compute PD

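You can compute PD three ways with aorsf:

- using in-bag predictions for the training data,
- using out-of-bag predictions for the training data, or
- using predictions for a new set of data, such as a testing set.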

Let’s re-fit our ORSF to all available data before proceeding to the next sections.


set.seed(329730)

fit <- orsf(pbc_orsf, 
            Surv(time, status) ~ . -id,
            oobag_pred_horizon = pred_horizon)
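A small bagged ensemble can illustrate the difference between in-bag and out-of-bag predictions. Each model is fit to a bootstrap sample. A row's in-bag prediction averages the models that were trained on it, and its out-of-bag prediction averages only the models that never saw it. Here, bagged linear models on simulated data stand in for ORSF trees, so this sketch illustrates the general idea rather than aorsf's internals.

```r
# Sketch: in-bag vs out-of-bag (OOB) predictions for a bagged ensemble.
# Assumption: bagged lm() fits on simulated data stand in for ORSF trees.
set.seed(2)
n <- 100
n_models <- 50
sim <- data.frame(x = runif(n, 0, 10))
sim$y <- 2 + 0.5 * sim$x + rnorm(n)

inbag <- matrix(FALSE, nrow = n, ncol = n_models)   # was row i used by model b?
preds <- matrix(NA_real_, nrow = n, ncol = n_models) # prediction of model b for row i
for (b in seq_len(n_models)) {
  rows <- sample(n, replace = TRUE)                  # bootstrap sample
  inbag[unique(rows), b] <- TRUE
  fit_b <- lm(y ~ x, data = sim[rows, ])
  preds[, b] <- predict(fit_b, newdata = sim)
}
oob <- !inbag

# In-bag: average over models trained on the row.
pred_inbag <- rowSums(preds * inbag) / rowSums(inbag)
# OOB: average only over models that never saw the row.
pred_oob <- rowSums(preds * oob) / rowSums(oob)

# Each bootstrap sample contains roughly 63% of the distinct rows.
mean(inbag)
```

Out-of-bag predictions behave much like predictions for unseen data, which makes them a sensible choice for PD on the training set.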

One variable, one horizon

Computing PD for a single variable is straightforward:


pd_sex <- orsf_pd_oob(fit, pred_spec = list(sex = c("m", "f")))

pd_sex
#>    pred_horizon    sex      mean        lwr      medn       upr
#>           <num> <fctr>     <num>      <num>     <num>     <num>
#> 1:      1826.25      m 0.3527805 0.03974647 0.2414356 0.9444124
#> 2:      1826.25      f 0.2932127 0.01115776 0.1417599 0.9591495

The output shows that, at 5 years after baseline, the expected predicted mortality risk is higher for men than for women (0.353 versus 0.293).
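To put "higher" in numbers, you can compute the absolute and relative differences from the two means printed in the pd_sex output. The values below are copied from that output.

```r
# Mean predicted risk at 5 years, copied from the pd_sex output.
risk_m <- 0.3527805
risk_f <- 0.2932127

risk_diff  <- risk_m - risk_f   # absolute difference, about 0.060
risk_ratio <- risk_m / risk_f   # relative difference, about 1.20

round(c(difference = risk_diff, ratio = risk_ratio), 3)
```

That is roughly 6 percentage points, or about 20% higher expected risk for men.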

One variable, moving horizon

What if the effect of a predictor varies over time? PD can show this.


library(ggplot2)

pd_sex_tv <- orsf_pd_oob(fit, pred_spec = list(sex = c("m", "f")),
                         pred_horizon = seq(365, 365*5))

ggplot(pd_sex_tv, aes(x = pred_horizon, y = mean, color = sex)) + 
 geom_line() +
 labs(x = 'Time since baseline',
      y = 'Expected risk')

From inspection, we can see that males have higher risk than females and the difference in that risk grows over time. This can also be seen by viewing the ratio of expected risk over time:


library(data.table)

ratio_tv <- pd_sex_tv[
 , .(ratio = mean[sex == 'm'] / mean[sex == 'f']), by = pred_horizon
]

ggplot(ratio_tv, aes(x = pred_horizon, y = ratio)) + 
 geom_line(color = 'grey') + 
 geom_smooth(color = 'black', se = FALSE) + 
 labs(x = 'time since baseline',
      y = 'ratio in expected risk for males versus females')
#> `geom_smooth()` using method = 'gam' and formula 'y ~ s(x, bs = "cs")'
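If you prefer base R to data.table, you can compute the same per-horizon ratio with split() and sapply(). The small data frame pd_toy is a hypothetical stand-in with the same columns as pd_sex_tv. Its values are illustrative, not model output.

```r
# Hypothetical stand-in for pd_sex_tv (illustrative values, not model output).
pd_toy <- data.frame(
  pred_horizon = rep(c(365, 730, 1095), each = 2),
  sex = rep(c("m", "f"), times = 3),
  mean = c(0.10, 0.09, 0.20, 0.16, 0.30, 0.22)
)

# Split by horizon, then divide the male mean by the female mean in each group.
ratio_by_horizon <- sapply(split(pd_toy, pd_toy$pred_horizon), function(d) {
  d$mean[d$sex == "m"] / d$mean[d$sex == "f"]
})

ratio_toy <- data.frame(pred_horizon = as.numeric(names(ratio_by_horizon)),
                        ratio = unname(ratio_by_horizon))
ratio_toy
```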

Multiple variables, marginally

If you want to compute PD marginally for multiple variables, just list the variable values in pred_spec and specify expand_grid = FALSE.


pd_two_vars <-  
 orsf_pd_oob(fit,
             pred_spec = list(sex = c("m", "f"), bili = 1:5),
             expand_grid = FALSE)

pd_two_vars
#>    pred_horizon variable value  level      mean        lwr      medn       upr
#>           <num>   <char> <num> <char>     <num>      <num>     <num>     <num>
#> 1:      1826.25      sex    NA      m 0.3527805 0.03974647 0.2414356 0.9444124
#> 2:      1826.25      sex    NA      f 0.2932127 0.01115776 0.1417599 0.9591495
#> 3:      1826.25     bili     1   <NA> 0.2340055 0.01296442 0.1171454 0.8702835
#> 4:      1826.25     bili     2   <NA> 0.2851655 0.03823950 0.1735778 0.9065770
#> 5:      1826.25     bili     3   <NA> 0.3431764 0.06659531 0.2542600 0.9238422
#> 6:      1826.25     bili     4   <NA> 0.3944246 0.10249873 0.3127158 0.9380414
#> 7:      1826.25     bili     5   <NA> 0.4409193 0.14114665 0.3826310 0.9464794
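Marginal and joint PD differ in how the values in pred_spec are turned into rows. The base R sketch below builds both layouts for the sex and bili values used in pd_two_vars. It shows the bookkeeping only and is not how aorsf implements it.

```r
# Same values as the pred_spec passed to orsf_pd_oob() for pd_two_vars.
spec <- list(sex = c("m", "f"), bili = 1:5)

# Joint: every combination of values -> 2 * 5 = 10 rows.
joint <- expand.grid(spec, stringsAsFactors = FALSE)

# Marginal: each variable on its own -> 2 + 5 = 7 rows,
# matching the 7 rows printed for pd_two_vars.
marginal <- do.call(rbind, lapply(names(spec), function(v) {
  data.frame(variable = v, value = as.character(spec[[v]]))
}))

nrow(joint)     # 10
nrow(marginal)  # 7
```

The number of joint rows grows multiplicatively with each added variable, while the number of marginal rows grows additively.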

Would it be tedious to do this for every variable? You bet. That's why we made a function for it, orsf_summarize_uni(). As a bonus, the printed output is sorted from most to least important variable.


pd_smry <- orsf_summarize_uni(fit)

pd_smry
#> 
#> -- bili (VI Rank: 1) -------------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>             0.80 0.2290016 0.1148601 0.04327406 0.3528169
#>              1.4 0.2494666 0.1376638 0.05627886 0.3771831
#>              3.5 0.3704706 0.2889571 0.15656217 0.5533343
#> 
#> -- age (VI Rank: 2) --------------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>               42 0.2704463 0.1295873 0.04026421 0.4548720
#>               50 0.2984843 0.1625962 0.04439149 0.5203701
#>               57 0.3281206 0.2112682 0.06424226 0.5718942
#> 
#> -- protime (VI Rank: 3) ----------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>               10 0.2808684 0.1480830 0.04454335 0.5250508
#>               11 0.2942878 0.1494827 0.04801608 0.5385037
#>               11 0.3174015 0.1848578 0.06474437 0.5605459
#> 
#> -- ascites (VI Rank: 4) ----------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>                0 0.2934430 0.1454747 0.04363005 0.5425820
#>                1 0.4700569 0.3779503 0.27402738 0.6487942
#> 
#> -- sex (VI Rank: 5) --------------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>                m 0.3527805 0.2414356 0.10979219 0.5900754
#>                f 0.2932127 0.1417599 0.04241247 0.5316942
#> 
#> -- stage (VI Rank: 6) ------------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>                1 0.2617610 0.1370758 0.03951652 0.4547312
#>                2 0.2704346 0.1393500 0.03812598 0.4974763
#>                3 0.2920387 0.1591362 0.04597153 0.5402843
#>                4 0.3395445 0.2171918 0.08134507 0.5906341
#> 
#> -- copper (VI Rank: 7) -----------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>               43 0.2610046 0.1359129 0.04000928 0.4773423
#>               74 0.2790426 0.1495856 0.05226547 0.5081557
#>              129 0.3335202 0.2176733 0.10000455 0.5464926
#> 
#> -- spiders (VI Rank: 8) ----------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>                0 0.2900268 0.1442807 0.04234382 0.5274333
#>                1 0.3349321 0.2164031 0.07881800 0.5616628
#> 
#> -- albumin (VI Rank: 9) ----------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>              3.3 0.3139304 0.1689580 0.04677649 0.5803556
#>              3.5 0.2923041 0.1431321 0.04390161 0.5342228
#>              3.8 0.2789748 0.1319210 0.04560136 0.4855513
#> 
#> -- edema (VI Rank: 10) -----------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>                0 0.2900221 0.1454747 0.04216106 0.5435162
#>              0.5 0.3506983 0.2437012 0.09370586 0.6102717
#>                1 0.4375162 0.3381648 0.22289789 0.6563385
#> 
#> -- ast (VI Rank: 11) -------------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>               82 0.2831976 0.1388635 0.04124336 0.5293067
#>              117 0.2969493 0.1541120 0.04724622 0.5472410
#>              153 0.3174977 0.1715826 0.06233281 0.5819240
#> 
#> -- trig (VI Rank: 12) ------------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>               85 0.2940047 0.1479727 0.04129360 0.5354742
#>              108 0.3004809 0.1559818 0.04400752 0.5448964
#>              151 0.3132524 0.1782903 0.05679427 0.5430465
#> 
#> -- alk.phos (VI Rank: 13) --------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>              922 0.3018497 0.1551374 0.04519584 0.5536559
#>             1278 0.3026628 0.1630816 0.04488732 0.5525483
#>             2068 0.3065994 0.1665361 0.05124448 0.5511105
#> 
#> -- hepato (VI Rank: 14) ----------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>                0 0.2854740 0.1450785 0.03921781 0.5230005
#>                1 0.3183191 0.1777718 0.06326351 0.5503241
#> 
#> -- chol (VI Rank: 15) ------------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>              250 0.2851192 0.1443172 0.03853663 0.4941613
#>              310 0.2927041 0.1602519 0.04444628 0.5196398
#>              401 0.3158389 0.1882467 0.07194796 0.5447747
#> 
#> -- platelet (VI Rank: 16) --------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>              200 0.3059458 0.1594746 0.04880371 0.5761188
#>              257 0.3009282 0.1541210 0.04620275 0.5619634
#>              318 0.2977349 0.1524793 0.04545151 0.5500751
#> 
#> -- trt (VI Rank: 17) -------------------------------------
#> 
#>                  |---------------- risk ----------------|
#>            Value      Mean    Median     25th %    75th %
#>           <char>     <num>     <num>      <num>     <num>
#>  d_penicill_main 0.3056548 0.1689712 0.04896262 0.5482101
#>          placebo 0.2996642 0.1451760 0.04475777 0.5498884
#> 
#>  Predicted risk at time t = 1826.25 for top 17 predictors
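Each row of the summary appears to describe the distribution of predicted risk across observations when one predictor is fixed at a given value. The Mean column matches the PD mean; for example, sex = m gives 0.3527805 in both pd_sex and the summary. The sketch below builds rows of the same kind for a logistic glm on simulated data. Evaluating at the quartiles of the predictor is an assumption based on the printed values (e.g., bili 0.80, 1.4, 3.5). Check the orsf_summarize_uni() documentation for its actual defaults.

```r
# Assumption: a logistic glm on simulated data stands in for the ORSF model.
set.seed(3)
n <- 300
sim <- data.frame(x = rexp(n), z = rnorm(n))
sim$y <- rbinom(n, 1, plogis(-1 + 0.8 * sim$x + 0.3 * sim$z))
model <- glm(y ~ x + z, family = binomial, data = sim)

# Assumed evaluation points: the quartiles of x.
values <- quantile(sim$x, probs = c(0.25, 0.50, 0.75), names = FALSE)

# For each value, fix x for every row and summarize the predicted risks.
summary_rows <- do.call(rbind, lapply(values, function(v) {
  d <- sim
  d$x <- v
  p <- predict(model, newdata = d, type = "response")
  data.frame(Value = v, Mean = mean(p), Median = median(p),
             q25 = quantile(p, 0.25, names = FALSE),
             q75 = quantile(p, 0.75, names = FALSE))
}))
summary_rows
```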

It’s easy enough to turn this ‘summary’ object into a data.table for downstream plotting and tables.


head(as.data.table(pd_smry))
#>    variable importance  Value      Mean    Median     25th %    75th %
#>      <char>      <num> <char>     <num>     <num>      <num>     <num>
#> 1:     bili 0.01854553   0.80 0.2290016 0.1148601 0.04327406 0.3528169
#> 2:     bili 0.01854553    1.4 0.2494666 0.1376638 0.05627886 0.3771831
#> 3:     bili 0.01854553    3.5 0.3704706 0.2889571 0.15656217 0.5533343
#> 4:      age 0.00896020     42 0.2704463 0.1295873 0.04026421 0.4548720
#> 5:      age 0.00896020     50 0.2984843 0.1625962 0.04439149 0.5203701
#> 6:      age 0.00896020     57 0.3281206 0.2112682 0.06424226 0.5718942
#>    pred_horizon  level
#>           <num> <char>
#> 1:      1826.25   <NA>
#> 2:      1826.25   <NA>
#> 3:      1826.25   <NA>
#> 4:      1826.25   <NA>
#> 5:      1826.25   <NA>
#> 6:      1826.25   <NA>

Multiple variables, jointly

PD can show the expected value of a model’s predictions as a function of a specific predictor, or as a function of multiple predictors. For instance, we can estimate predicted risk as a joint function of bili, edema, and trt:


pred_spec <- list(bili = seq(1, 5, length.out = 20),
                  edema = levels(pbc_orsf$edema),
                  trt = levels(pbc_orsf$trt))

pd_bili_edema <- orsf_pd_oob(fit, pred_spec)

library(ggplot2)

ggplot(pd_bili_edema, aes(x = bili, y = mean, col = trt, linetype = edema)) + 
 geom_line() + 
 labs(y = 'Expected predicted risk')
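Joint PD works the same way as single-variable PD, except that several predictors are fixed at once, one combination at a time. The following is a minimal sketch that uses a logistic glm with an x-by-group interaction on simulated data in place of the ORSF ensemble.

```r
# Assumption: a logistic glm with an interaction stands in for the ORSF model.
set.seed(4)
n <- 300
sim <- data.frame(x = runif(n, 1, 5),
                  g = factor(sample(c("a", "b"), n, replace = TRUE)),
                  z = rnorm(n))
lin <- -2 + 0.5 * sim$x + 0.5 * (sim$g == "b") * sim$x + 0.3 * sim$z
sim$y <- rbinom(n, 1, plogis(lin))
model <- glm(y ~ x * g + z, family = binomial, data = sim)

# All combinations of the two predictors (like expand_grid in the PD call).
grid <- expand.grid(x = seq(1, 5, length.out = 5), g = levels(sim$g))

# For each combination, fix both predictors for every row and average.
grid$mean <- vapply(seq_len(nrow(grid)), function(i) {
  d <- sim
  d$x <- grid$x[i]
  d$g <- factor(grid$g[i], levels = levels(sim$g))  # keep factor levels consistent
  mean(predict(model, newdata = d, type = "response"))
}, numeric(1))

grid
```

The simulated model includes an x-by-g interaction, so the curve for g = "b" should rise more steeply than the one for g = "a". A joint PD plot can reveal patterns of this kind.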

From inspection,

Individual conditional expectations (ICE)

Unlike partial dependence, which shows the expected prediction as a function of one or multiple predictors, individual conditional expectations (ICE) show the prediction for an individual observation as a function of a predictor.

Just like PD, we can compute ICE using in-bag, out-of-bag, or testing data, and the same principles apply. We’ll use out-of-bag estimates here.
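The link between ICE and PD fits in a few lines. Compute a matrix of predictions with one row per observation and one column per grid value. Each row is then an ICE curve, and the column means form the PD curve. The glm with an x-by-w interaction on simulated data is an assumption standing in for the ORSF model.

```r
# Assumption: a logistic glm with an interaction stands in for the ORSF model.
set.seed(5)
n <- 100
sim <- data.frame(x = runif(n, 1, 10), w = rbinom(n, 1, 0.5))
sim$y <- rbinom(n, 1, plogis(-3 + 0.4 * sim$x - 0.3 * sim$w * sim$x))
model <- glm(y ~ x * w, family = binomial, data = sim)

grid <- seq(1, 10, length.out = 25)

# ice[i, j] = prediction for observation i when x is set to grid[j]
ice <- vapply(grid, function(value) {
  d <- sim
  d$x <- value
  unname(predict(model, newdata = d, type = "response"))
}, numeric(n))

pd <- colMeans(ice)   # the PD curve is the pointwise average of the ICE curves

dim(ice)              # 100 observations by 25 grid values
```

PD averages the ICE curves, so subgroups whose curves move in different directions can partly cancel out. Inspecting ICE curves can reveal this kind of heterogeneity, which PD hides.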

Visualizing ICE curves

Inspecting the ICE curves for each observation can help identify whether there is heterogeneity in a model's predictions. That is, does the effect of the variable follow the same pattern for all observations, or are there subgroups in which the variable affects risk differently?

I turn off boundary checking in orsf_ice_oob by setting boundary_checks = FALSE, which allows me to generate ICE curves that extend beyond the 90th percentile of bili.


pred_spec <- list(bili = seq(1, 10, length.out = 25))

ice_oob <- orsf_ice_oob(fit, pred_spec, boundary_checks = FALSE)

ice_oob
#>       pred_horizon id_variable id_row  bili       pred
#>              <num>       <int>  <int> <num>      <num>
#>    1:      1826.25           1      1     1 0.91930989
#>    2:      1826.25           1      2     1 0.10293802
#>    3:      1826.25           1      3     1 0.68484802
#>    4:      1826.25           1      4     1 0.31939985
#>    5:      1826.25           1      5     1 0.07395545
#>   ---                                                 
#> 6896:      1826.25          25    272    10 0.45183506
#> 6897:      1826.25          25    273    10 0.46421948
#> 6898:      1826.25          25    274    10 0.48277410
#> 6899:      1826.25          25    275    10 0.38452767
#> 6900:      1826.25          25    276    10 0.51664794
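Here is the idea behind a boundary check, sketched with a hypothetical helper, check_boundaries(). It warns when grid values fall outside a percentile range of the observed predictor, where the model has little data to support its predictions. The 10th and 90th percentile cutoffs are an assumption suggested by the text, and aorsf's actual rule may differ.

```r
# Hypothetical helper (not aorsf's implementation): flag grid values outside
# an assumed percentile range of the observed predictor.
check_boundaries <- function(x_observed, grid, probs = c(0.10, 0.90)) {
  limits <- quantile(x_observed, probs = probs, names = FALSE)
  outside <- grid < limits[1] | grid > limits[2]
  if (any(outside)) {
    warning(sum(outside), " grid value(s) fall outside the ",
            probs[1] * 100, "th-", probs[2] * 100, "th percentile range")
  }
  invisible(outside)
}

set.seed(6)
x_obs <- rexp(200, rate = 1 / 2)   # simulated right-skewed predictor
flags <- suppressWarnings(
  check_boundaries(x_obs, grid = seq(1, 10, length.out = 25))
)
flags
```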

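For plots, it is helpful to center the ICE data. I subtract each observation's initial predicted risk (i.e., its prediction when bili = 1) from all of that observation's conditional expectation values. As a result, every curve starts at 0, and the y-axis shows the change in predicted risk relative to bili = 1.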

Now we can visualize the curves.


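library(ggplot2)
library(data.table)

# Center each curve at its prediction for bili = 1 (the first grid value),
# so that every observation's curve starts at 0.
ice_oob[, pred := pred - pred[1], by = id_row]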

ggplot(ice_oob, aes(x = bili, 
                    y = pred, 
                    group = id_row)) + 
 geom_line(alpha = 0.15) + 
 labs(y = 'Change in predicted risk') +
 geom_smooth(se = FALSE, aes(group = 1))
#> `geom_smooth()` using method = 'gam' and formula 'y ~ s(x, bs = "cs")'
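The centering described earlier subtracts each observation's predicted risk at bili = 1 from its whole curve. You can also do this in base R with ave(). The data frame ice_toy is a hypothetical stand-in with the same id_row, bili, and pred columns as ice_oob, and its values are illustrative only.

```r
# Hypothetical stand-in for ice_oob (illustrative values, not model output).
ice_toy <- data.frame(
  id_row = rep(1:2, each = 3),
  bili   = rep(c(1, 5, 10), times = 2),
  pred   = c(0.10, 0.25, 0.40, 0.30, 0.35, 0.38)
)

# Order by bili within id_row so that the first value in each group is bili = 1.
ice_toy <- ice_toy[order(ice_toy$id_row, ice_toy$bili), ]

# Subtract each observation's prediction at bili = 1 from its whole curve.
ice_toy$pred_change <- ave(ice_toy$pred, ice_toy$id_row,
                           FUN = function(p) p - p[1])
ice_toy
```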

From inspection of the figure,