Partial dependence (PD) shows the expected prediction from a model as a function of a single predictor or multiple predictors. The expectation is marginalized over the values of all other predictors, giving something like a multivariable adjusted estimate of the model’s prediction.
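To make the definition concrete, here is a minimal sketch of PD computed by hand in base R. It is not how aorsf computes PD internally: the helper `partial_dependence` is hypothetical, and a linear model fit to the built-in `mtcars` data stands in for an ORSF ensemble. For each value of the predictor, every row is set to that value, predictions are made, and the predictions are averaged over the remaining predictors.

```r
# Hand-rolled partial dependence for one predictor.
# `partial_dependence` is a hypothetical helper, not part of aorsf.
partial_dependence <- function(model, data, variable, values) {
  sapply(values, function(v) {
    data_v <- data
    data_v[[variable]] <- v                 # set the predictor to v in every row
    mean(predict(model, newdata = data_v))  # average over the other predictors
  })
}

# Stand-in model (assumption): linear regression on a built-in data set
fit_lm <- lm(mpg ~ wt + hp, data = mtcars)

wt_values <- c(2, 3, 4)
pd_wt <- partial_dependence(fit_lm, mtcars, "wt", wt_values)

data.frame(wt = wt_values, mean = pd_wt)
```

Because this stand-in model is additive, the PD changes by exactly the `wt` coefficient for each unit increase in `wt`. A tree ensemble is not restricted to that shape, which is why PD is useful for interpreting it.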
Begin by fitting an ORSF ensemble. Set a prediction horizon of 5 years when we fit the ensemble so that any aorsf function that we pass this ensemble to will assume we want to compute predictions at 5 years.
library(aorsf)
# 5 years, expressed in days
pred_horizon <- 365.25 * 5
set.seed(329730)
# randomly select 150 rows for training; the remaining rows form the test set
index_train <- sample(nrow(pbc_orsf), 150)
pbc_orsf_train <- pbc_orsf[index_train, ]
pbc_orsf_test <- pbc_orsf[-index_train, ]
fit <- orsf(data = pbc_orsf_train,
            formula = Surv(time, status) ~ . - id,
            oobag_pred_horizon = pred_horizon)
fit
#> ---------- Oblique random survival forest
#>
#> Linear combinations: Accelerated
#> N observations: 150
#> N events: 52
#> N trees: 500
#> N predictors total: 17
#> N predictors per node: 5
#> Average leaves per tree: 12
#> Min observations in leaf: 5
#> Min events in leaf: 1
#> OOB stat value: 0.83
#> OOB stat type: Harrell's C-statistic
#> Variable importance: anova
#>
#> -----------------------------------------
You can compute PD three ways with aorsf:
using in-bag predictions for the training data
pd_inb <- orsf_pd_inb(fit, pred_spec = list(bili = 1:5))
pd_inb
#> pred_horizon bili mean lwr medn upr
#> <num> <int> <num> <num> <num> <num>
#> 1: 1826.25 1 0.2065186 0.01461416 0.09406926 0.8053158
#> 2: 1826.25 2 0.2352372 0.02673697 0.12477942 0.8206148
#> 3: 1826.25 3 0.2754197 0.04359767 0.17630939 0.8406553
#> 4: 1826.25 4 0.3303309 0.09237920 0.24319095 0.8544871
#> 5: 1826.25 5 0.3841395 0.15224112 0.30174988 0.8663482
using out-of-bag predictions for the training data
pd_oob <- orsf_pd_oob(fit, pred_spec = list(bili = 1:5))
pd_oob
#> pred_horizon bili mean lwr medn upr
#> <num> <int> <num> <num> <num> <num>
#> 1: 1826.25 1 0.2075896 0.01389732 0.09063976 0.7998756
#> 2: 1826.25 2 0.2352634 0.02628113 0.12935779 0.8152149
#> 3: 1826.25 3 0.2750782 0.04254451 0.18877830 0.8371582
#> 4: 1826.25 4 0.3302680 0.08806724 0.24827784 0.8441472
#> 5: 1826.25 5 0.3846734 0.14808075 0.29926304 0.8562432
using predictions for a new set of data
pd_test <- orsf_pd_new(fit,
                       new_data = pbc_orsf_test,
                       pred_spec = list(bili = 1:5))
pd_test
#> pred_horizon bili mean lwr medn upr
#> <num> <int> <num> <num> <num> <num>
#> 1: 1826.25 1 0.2541661 0.01581296 0.1912170 0.8103449
#> 2: 1826.25 2 0.2824737 0.03054392 0.2304441 0.8413602
#> 3: 1826.25 3 0.3205550 0.04959123 0.2736161 0.8495418
#> 4: 1826.25 4 0.3743186 0.10474085 0.3501337 0.8619464
#> 5: 1826.25 5 0.4258793 0.16727203 0.4032790 0.8626002
In-bag PD indicates relationships that the model has learned during training. This is helpful if your goal is to interpret the model.
Out-of-bag PD also indicates relationships that the model has learned during training, but using the out-of-bag data simulates application of the model to new data. This is helpful if you want to test your model’s reliability or fairness in new data but you don’t have access to a large testing set.
New data PD shows how the model predicts outcomes for observations it has not seen. This is helpful if you want to test your model’s reliability or fairness.
Let’s re-fit our ORSF to all available data before proceeding to the next sections.
set.seed(329730)
fit <- orsf(pbc_orsf,
            Surv(time, status) ~ . - id,
            oobag_pred_horizon = pred_horizon)
Computing PD for a single variable is straightforward:
pd_sex <- orsf_pd_oob(fit, pred_spec = list(sex = c("m", "f")))
pd_sex
#> pred_horizon sex mean lwr medn upr
#> <num> <fctr> <num> <num> <num> <num>
#> 1: 1826.25 m 0.3527805 0.03974647 0.2414356 0.9444124
#> 2: 1826.25 f 0.2932127 0.01115776 0.1417599 0.9591495
The output shows that the expected predicted mortality risk at 5 years after baseline is higher for men than for women (0.35 versus 0.29).
What if the effect of a predictor varies over time? PD can show this.
library(ggplot2)
# compute PD at every day from year 1 through year 5
pd_sex_tv <- orsf_pd_oob(fit, pred_spec = list(sex = c("m", "f")),
                         pred_horizon = seq(365, 365*5))
ggplot(pd_sex_tv, aes(x = pred_horizon, y = mean, color = sex)) +
  geom_line() +
  labs(x = 'Time since baseline',
       y = 'Expected risk')
From inspection, we can see that males have higher risk than females and the difference in that risk grows over time. This can also be seen by viewing the ratio of expected risk over time:
library(data.table)
ratio_tv <- pd_sex_tv[
  , .(ratio = mean[sex == 'm'] / mean[sex == 'f']), by = pred_horizon
]
ggplot(ratio_tv, aes(x = pred_horizon, y = ratio)) +
  geom_line(color = 'grey') +
  geom_smooth(color = 'black', se = FALSE) +
  labs(x = 'time since baseline',
       y = 'ratio in expected risk for males versus females')
#> `geom_smooth()` using method = 'gam' and formula 'y ~ s(x, bs = "cs")'
If you want to compute PD marginally for multiple variables, just list the variable values in pred_spec and specify expand_grid = FALSE.
pd_two_vars <-
  orsf_pd_oob(fit,
              pred_spec = list(sex = c("m", "f"), bili = 1:5),
              expand_grid = FALSE)
pd_two_vars
#> pred_horizon variable value level mean lwr medn upr
#> <num> <char> <num> <char> <num> <num> <num> <num>
#> 1: 1826.25 sex NA m 0.3527805 0.03974647 0.2414356 0.9444124
#> 2: 1826.25 sex NA f 0.2932127 0.01115776 0.1417599 0.9591495
#> 3: 1826.25 bili 1 <NA> 0.2340055 0.01296442 0.1171454 0.8702835
#> 4: 1826.25 bili 2 <NA> 0.2851655 0.03823950 0.1735778 0.9065770
#> 5: 1826.25 bili 3 <NA> 0.3431764 0.06659531 0.2542600 0.9238422
#> 6: 1826.25 bili 4 <NA> 0.3944246 0.10249873 0.3127158 0.9380414
#> 7: 1826.25 bili 5 <NA> 0.4409193 0.14114665 0.3826310 0.9464794
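The difference between a joint and a marginal specification can be illustrated in base R. This sketch only shows the shape of the two layouts for the same pred_spec; it is an illustration under that assumption, not a description of how aorsf builds them internally.

```r
pred_spec <- list(sex = c("m", "f"), bili = 1:5)

# Joint layout: every combination of sex and bili (2 * 5 rows)
joint <- expand.grid(pred_spec, stringsAsFactors = FALSE)
nrow(joint)
#> [1] 10

# Marginal layout: each variable is varied on its own (2 + 5 rows)
marginal <- do.call(rbind, lapply(names(pred_spec), function(v) {
  data.frame(variable = v, value = as.character(pred_spec[[v]]))
}))
nrow(marginal)
#> [1] 7
```

The marginal layout has one row per listed value, which matches the seven rows in the output above.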
Now would it be tedious if you wanted to do this for all the variables? You bet. That’s why we made a function for that. As a bonus, the printed output is sorted from most to least important variables.
pd_smry <- orsf_summarize_uni(fit)
pd_smry
#>
#> -- bili (VI Rank: 1) -------------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 0.80 0.2290016 0.1148601 0.04327406 0.3528169
#> 1.4 0.2494666 0.1376638 0.05627886 0.3771831
#> 3.5 0.3704706 0.2889571 0.15656217 0.5533343
#>
#> -- age (VI Rank: 2) --------------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 42 0.2704463 0.1295873 0.04026421 0.4548720
#> 50 0.2984843 0.1625962 0.04439149 0.5203701
#> 57 0.3281206 0.2112682 0.06424226 0.5718942
#>
#> -- protime (VI Rank: 3) ----------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 10 0.2808684 0.1480830 0.04454335 0.5250508
#> 11 0.2942878 0.1494827 0.04801608 0.5385037
#> 11 0.3174015 0.1848578 0.06474437 0.5605459
#>
#> -- ascites (VI Rank: 4) ----------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 0 0.2934430 0.1454747 0.04363005 0.5425820
#> 1 0.4700569 0.3779503 0.27402738 0.6487942
#>
#> -- sex (VI Rank: 5) --------------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> m 0.3527805 0.2414356 0.10979219 0.5900754
#> f 0.2932127 0.1417599 0.04241247 0.5316942
#>
#> -- stage (VI Rank: 6) ------------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 1 0.2617610 0.1370758 0.03951652 0.4547312
#> 2 0.2704346 0.1393500 0.03812598 0.4974763
#> 3 0.2920387 0.1591362 0.04597153 0.5402843
#> 4 0.3395445 0.2171918 0.08134507 0.5906341
#>
#> -- copper (VI Rank: 7) -----------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 43 0.2610046 0.1359129 0.04000928 0.4773423
#> 74 0.2790426 0.1495856 0.05226547 0.5081557
#> 129 0.3335202 0.2176733 0.10000455 0.5464926
#>
#> -- spiders (VI Rank: 8) ----------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 0 0.2900268 0.1442807 0.04234382 0.5274333
#> 1 0.3349321 0.2164031 0.07881800 0.5616628
#>
#> -- albumin (VI Rank: 9) ----------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 3.3 0.3139304 0.1689580 0.04677649 0.5803556
#> 3.5 0.2923041 0.1431321 0.04390161 0.5342228
#> 3.8 0.2789748 0.1319210 0.04560136 0.4855513
#>
#> -- edema (VI Rank: 10) -----------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 0 0.2900221 0.1454747 0.04216106 0.5435162
#> 0.5 0.3506983 0.2437012 0.09370586 0.6102717
#> 1 0.4375162 0.3381648 0.22289789 0.6563385
#>
#> -- ast (VI Rank: 11) -------------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 82 0.2831976 0.1388635 0.04124336 0.5293067
#> 117 0.2969493 0.1541120 0.04724622 0.5472410
#> 153 0.3174977 0.1715826 0.06233281 0.5819240
#>
#> -- trig (VI Rank: 12) ------------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 85 0.2940047 0.1479727 0.04129360 0.5354742
#> 108 0.3004809 0.1559818 0.04400752 0.5448964
#> 151 0.3132524 0.1782903 0.05679427 0.5430465
#>
#> -- alk.phos (VI Rank: 13) --------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 922 0.3018497 0.1551374 0.04519584 0.5536559
#> 1278 0.3026628 0.1630816 0.04488732 0.5525483
#> 2068 0.3065994 0.1665361 0.05124448 0.5511105
#>
#> -- hepato (VI Rank: 14) ----------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 0 0.2854740 0.1450785 0.03921781 0.5230005
#> 1 0.3183191 0.1777718 0.06326351 0.5503241
#>
#> -- chol (VI Rank: 15) ------------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 250 0.2851192 0.1443172 0.03853663 0.4941613
#> 310 0.2927041 0.1602519 0.04444628 0.5196398
#> 401 0.3158389 0.1882467 0.07194796 0.5447747
#>
#> -- platelet (VI Rank: 16) --------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> 200 0.3059458 0.1594746 0.04880371 0.5761188
#> 257 0.3009282 0.1541210 0.04620275 0.5619634
#> 318 0.2977349 0.1524793 0.04545151 0.5500751
#>
#> -- trt (VI Rank: 17) -------------------------------------
#>
#> |---------------- risk ----------------|
#> Value Mean Median 25th % 75th %
#> <char> <num> <num> <num> <num>
#> d_penicill_main 0.3056548 0.1689712 0.04896262 0.5482101
#> placebo 0.2996642 0.1451760 0.04475777 0.5498884
#>
#> Predicted risk at time t = 1826.25 for top 17 predictors
It’s easy enough to turn this ‘summary’ object into a data.table for downstream plotting and tables.
head(as.data.table(pd_smry))
#> variable importance Value Mean Median 25th % 75th %
#> <char> <num> <char> <num> <num> <num> <num>
#> 1: bili 0.01854553 0.80 0.2290016 0.1148601 0.04327406 0.3528169
#> 2: bili 0.01854553 1.4 0.2494666 0.1376638 0.05627886 0.3771831
#> 3: bili 0.01854553 3.5 0.3704706 0.2889571 0.15656217 0.5533343
#> 4: age 0.00896020 42 0.2704463 0.1295873 0.04026421 0.4548720
#> 5: age 0.00896020 50 0.2984843 0.1625962 0.04439149 0.5203701
#> 6: age 0.00896020 57 0.3281206 0.2112682 0.06424226 0.5718942
#> pred_horizon level
#> <num> <char>
#> 1: 1826.25 <NA>
#> 2: 1826.25 <NA>
#> 3: 1826.25 <NA>
#> 4: 1826.25 <NA>
#> 5: 1826.25 <NA>
#> 6: 1826.25 <NA>
PD can show the expected value of a model’s predictions as a function of a specific predictor, or as a function of multiple predictors. For instance, we can estimate predicted risk as a joint function of bili, edema, and trt:
pred_spec = list(bili = seq(1, 5, length.out = 20),
                 edema = levels(pbc_orsf_train$edema),
                 trt = levels(pbc_orsf$trt))
pd_bili_edema <- orsf_pd_oob(fit, pred_spec)
library(ggplot2)
ggplot(pd_bili_edema, aes(x = bili, y = medn, col = trt, linetype = edema)) +
  geom_line() +
  labs(y = 'Expected predicted risk')
From inspection,
The model’s predictions indicate slightly lower risk for the placebo group, and this difference does not seem to change much at different values of bili or edema.
There is a clear increase in predicted risk with higher levels of edema and with higher levels of bili.
The slope of predicted risk as a function of bili appears highest among patients with edema of 0.5.
Is the effect of bili modified by edema being 0.5? A quick sanity check with coxph suggests it is.
library(survival)
pbc_orsf$edema_05 <- ifelse(pbc_orsf$edema == '0.5', 'yes', 'no')
fit_cph <- coxph(Surv(time,status) ~ edema_05 * bili,
                 data = pbc_orsf)
anova(fit_cph)
#> Analysis of Deviance Table
#> Cox model: response is Surv(time, status)
#> Terms added sequentially (first to last)
#>
#> loglik Chisq Df Pr(>|Chi|)
#> NULL -550.19
#> edema_05 -546.83 6.7248 1 0.009508 **
#> bili -513.59 66.4689 1 3.555e-16 ***
#> edema_05:bili -510.54 6.1112 1 0.013433 *
#> ---
#> Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
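The interaction row of this table is a likelihood ratio test, and its arithmetic can be checked by hand. The sketch below copies the two log-likelihoods printed above (rounded to two decimals, so the result only approximates the reported Chisq of 6.1112 and p-value of 0.013433).

```r
# Log-likelihoods copied from the printed anova() table (rounded values)
loglik_main        <- -513.59   # model with edema_05 + bili
loglik_interaction <- -510.54   # model adding edema_05:bili

# Likelihood ratio statistic: twice the gain in log-likelihood
chisq <- 2 * (loglik_interaction - loglik_main)

# Upper tail of a chi-squared distribution with 1 degree of freedom
p_value <- pchisq(chisq, df = 1, lower.tail = FALSE)

round(c(chisq = chisq, p_value = p_value), 4)
```

A p-value near 0.013 is modest evidence that the slope for bili differs when edema is 0.5, which is in line with what the PD plot suggested.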
Unlike partial dependence, which shows the expected prediction as a function of one or multiple predictors, individual conditional expectations (ICE) show the prediction for an individual observation as a function of a predictor.
Just like PD, we can compute ICE using in-bag, out-of-bag, or testing data, and the same principles apply. We’ll use out-of-bag estimates here.
Inspecting the ICE curves for each observation can help identify whether there is heterogeneity in a model’s predictions. I.e., does the effect of the variable follow the same pattern for all the data, or are there groups where the variable impacts risk differently?
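As a minimal sketch of the idea, ICE can be computed by hand in base R. The helper `ice_curves` is hypothetical and a linear model with an interaction, fit to the built-in `mtcars` data, stands in for an ORSF ensemble. Each observation keeps its own prediction at each value of the predictor, and averaging those predictions recovers PD.

```r
# Hand-rolled ICE: one prediction per observation per grid value.
# `ice_curves` is a hypothetical helper, not part of aorsf.
ice_curves <- function(model, data, variable, values) {
  do.call(rbind, lapply(values, function(v) {
    data_v <- data
    data_v[[variable]] <- v   # set the predictor to v in every row
    data.frame(id_row = seq_len(nrow(data)),
               value  = v,
               pred   = unname(predict(model, newdata = data_v)))
  }))
}

# Stand-in model (assumption): the effect of wt is allowed to differ by am
fit_lm <- lm(mpg ~ wt * am, data = mtcars)

ice_wt <- ice_curves(fit_lm, mtcars, "wt", c(2, 3, 4))

# Averaging the ICE values at each grid value gives partial dependence
pd_wt <- tapply(ice_wt$pred, ice_wt$value, mean)

# Change in prediction from wt = 2 to wt = 3, one value per observation
slopes <- ice_wt$pred[ice_wt$value == 3] - ice_wt$pred[ice_wt$value == 2]
unique(round(slopes, 6))
```

In this stand-in model the individual slopes take two distinct values, one for each level of `am`. That is the kind of heterogeneity ICE curves reveal and a single PD curve averages away.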
I am going to turn off boundary checking in orsf_ice_oob by setting boundary_checks = FALSE, and this will allow me to generate ICE curves that go beyond the 90th percentile of bili.
pred_spec <- list(bili = seq(1, 10, length.out = 25))
ice_oob <- orsf_ice_oob(fit, pred_spec, boundary_checks = FALSE)
ice_oob
#> pred_horizon id_variable id_row bili pred
#> <num> <int> <int> <num> <num>
#> 1: 1826.25 1 1 1 0.91930989
#> 2: 1826.25 1 2 1 0.10293802
#> 3: 1826.25 1 3 1 0.68484802
#> 4: 1826.25 1 4 1 0.31939985
#> 5: 1826.25 1 5 1 0.07395545
#> ---
#> 6896: 1826.25 25 272 10 0.45183506
#> 6897: 1826.25 25 273 10 0.46421948
#> 6898: 1826.25 25 274 10 0.48277410
#> 6899: 1826.25 25 275 10 0.38452767
#> 6900: 1826.25 25 276 10 0.51664794
id_variable is an identifier for the current value of the variable(s) that are in the data. It is redundant if you only have one variable, but helpful if there are multiple variables.
id_row is an identifier for the observation in the original data. It is used to group an observation’s predictions together in plots.
For plots, it is helpful to scale the ICE data. I subtract the initial value of predicted risk (i.e., when bili = 1) from each observation’s conditional expectation values. So,
Every curve starts at 0
The plot shows change in predicted risk as a function of bili.
ice_oob[, pred := pred - pred[bili==1], by = id_row]
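The same centering step can be sketched in base R on a small made-up data frame (`ice_toy` and its values are hypothetical, chosen only for illustration). This version takes the reference prediction at the smallest value of bili instead of testing bili == 1 exactly, which may be more convenient when the grid does not start at a round number.

```r
# Toy ICE data: two observations, three values of bili (values are made up)
ice_toy <- data.frame(
  id_row = rep(1:2, times = 3),
  bili   = rep(c(1, 2, 3), each = 2),
  pred   = c(0.10, 0.40, 0.15, 0.50, 0.30, 0.55)
)

# Reference row for each observation: its prediction at the smallest bili
ref <- ice_toy[ice_toy$bili == min(ice_toy$bili), c("id_row", "pred")]

# Subtract each observation's reference prediction from all of its predictions
ice_toy$pred_centered <-
  ice_toy$pred - ref$pred[match(ice_toy$id_row, ref$id_row)]

ice_toy
```

After centering, every observation has a value of 0 at the smallest bili, so the curves share a common starting point.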
Now we can visualize the curves.
library(ggplot2)
ggplot(ice_oob, aes(x = bili,
                    y = pred,
                    group = id_row)) +
  geom_line(alpha = 0.15) +
  labs(y = 'Change in predicted risk') +
  geom_smooth(se = FALSE, aes(group = 1))
#> `geom_smooth()` using method = 'gam' and formula 'y ~ s(x, bs = "cs")'
From inspection of the figure,
Most of the individual slopes cluster around the overall trend - Good!
A small number of individual slopes appear to be flat. It may be helpful to investigate this further.