---
title: "[en] An example of recursive partitioning with Titanic data"
author: "Nicolas Robette"
date: "`r Sys.Date()`"
output:
  rmdformats::html_clean:
    thumbnails: FALSE
    use_bookdown: FALSE
vignette: >
  %\VignetteIndexEntry{[en] An example of recursive partitioning with Titanic data}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, echo=FALSE, cache=FALSE}
library(knitr)
library(rmdformats)

## Global options
options(max.print="75")
opts_chunk$set(echo=TRUE,
               cache=FALSE,
               prompt=FALSE,
               tidy=FALSE,
               comment=NA,
               message=FALSE,
               warning=FALSE)
opts_knit$set(width=75)
```

```{r load_res, include=FALSE}
load(url('http://nicolas.robette.free.fr/Docs/results_titanic.RData'))
```

## First steps

First, the necessary packages are loaded into memory.

```{r init, cache=FALSE}
library(dplyr)         # data management
library(caret)         # confusion matrix
library(party)         # conditional inference random forests and trees
library(partykit)      # conditional inference trees
library(pROC)          # ROC curves
library(measures)      # performance measures
library(varImp)        # variable importance
library(pdp)           # partial dependence
library(vip)           # measure of interactions
library(moreparty)     # surrogate trees, accumulated local effects, etc.
library(RColorBrewer)  # color palettes
library(descriptio)    # bivariate analysis
```

We then import the `titanic` data set from [`moreparty`](https://cran.r-project.org/package=moreparty).

```{r import_tita}
data(titanic)
str(titanic)
```

We have 1309 cases, one categorical explained variable, `Survived`, which codes whether or not an individual survived the shipwreck, and four explanatory variables (three categorical and one continuous): gender, age, passenger class, and port of embarkation.

The distribution of the variables is examined.

```{r desc_tita}
summary(titanic)
```

The distribution of the explained variable is unbalanced, survivors being a clear minority. In addition, some explanatory variables have missing values, in particular `Age`.

We examine the bivariate statistical relationships between the variables.

```{r bivar_assoc}
BivariateAssoc(titanic$Survived, titanic[,-1])
```

Survival is primarily associated with gender, and secondarily with passenger class. The explanatory variables are weakly related to each other.

```{r catdesc}
catdesc(titanic$Survived, titanic[,-1], limit = 0.1, robust = FALSE, na.rm.cont = TRUE)
```

Women, 1st class passengers and those who boarded at Cherbourg are over-represented among the survivors. Men, 3rd class passengers and those who boarded at Southampton are over-represented among the non-survivors.

Random forests involve a share of randomness (via resampling and the random drawing of splitting variables), as do some of their interpretation tools (via variable permutations). From one run of the program to the next, the results may therefore differ slightly. If you wish to obtain the same results systematically and ensure reproducibility, use the `set.seed` function.

```{r seed}
set.seed(1912)
```

## Classification tree

In order to build a classification tree with the CTree conditional inference algorithm, we use the [`partykit`](https://cran.r-project.org/package=partykit) package, which allows more flexibility than the [`party`](https://cran.r-project.org/package=party) package, in particular in its handling of missing values.
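In CTree, observations with a missing splitting value (here, mostly `Age`) can be routed down the tree with surrogate splits. A minimal sketch (not evaluated, and not part of the original analysis) of the role of the `maxsurrogate` argument used below:

```{r ctree_nosurro, eval=FALSE}
# Sketch: with the default maxsurrogate=0, no surrogate splits are computed
# and observations with a missing splitting value are randomly assigned to a
# daughter node, following the node distribution. maxsurrogate=Inf, used
# below, routes them with surrogate splits computed on the other covariates.
arbre_nosurro <- partykit::ctree(Survived~., data=titanic,
                                 control=partykit::ctree_control(minbucket=30, maxsurrogate=0, maxdepth=3))
```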
The tree can be displayed in textual or graphical form.

```{r ctree, out.width='100%'}
arbre <- partykit::ctree(Survived~., data=titanic,
                         control=partykit::ctree_control(minbucket=30, maxsurrogate=Inf, maxdepth=3))
print(arbre)
plot(arbre)
```

`Sex` is the first splitting variable, `Pclass` the second, and all the explanatory variables are used in the tree. The proportion of survivors varies greatly from one terminal node to another.

```{r proba_nodes}
nodeapply(as.simpleparty(arbre),
          ids = nodeids(arbre, terminal = TRUE),
          FUN = function(x) round(prop.table(info_node(x)$distribution),3))
```

Thus, 96.5% of women travelling in 1st class survive, compared with only 12.4% of men over 9 years of age travelling in 2nd or 3rd class.

The graphical representation can be parameterized to obtain a simpler and more readable tree.

```{r ctree_plot, out.width='100%'}
plot(arbre,
     inner_panel=node_inner(arbre,id=FALSE,pval=FALSE),
     terminal_panel=node_barplot(arbre,id=FALSE),
     gp=gpar(cex=0.6),
     ep_args=list(justmin=15))
```

Note that the [`ggparty`](https://cran.r-project.org/package=ggparty) package draws CTree trees with the `ggplot2` grammar.

To measure the performance of the tree, the AUC is calculated by comparing predicted and observed survival.

```{r pred_tree}
pred_arbre <- predict(arbre, type='prob')[,'Yes']
auc_arbre <- AUC(pred_arbre, titanic$Survived, positive='Yes')
auc_arbre %>% round(3)
```

The AUC is `r round(auc_arbre,3)`, which is a relatively high level of performance.

To plot the ROC curve of the model:

```{r roc_tree, fig.align="center", fig.width=4, fig.height=4}
pROC::roc(titanic$Survived, pred_arbre) %>%
  ggroc(legacy.axes=TRUE) +
  geom_segment(aes(x=0,xend=1,y=0,yend=1), color="darkgrey", linetype="dashed") +
  theme_bw() +
  xlab("FPR") +
  ylab("TPR")
```

Other performance measures are based on the confusion matrix.

```{r confusion}
ifelse(pred_arbre > .5, "Yes", "No") %>%
  factor %>%
  caret::confusionMatrix(titanic$Survived, positive='Yes')
```

The `GetSplitStats` function makes it possible to examine the outcome of the competition between covariates in the choice of each splitting variable. If, for a given node, the splitting variable is significantly more strongly associated with the explained variable than the other explanatory variables are, we may consider the split stable.

```{r split_stats}
GetSplitStats(arbre)
```

In this case, for each of the nodes, the outcome of the competition is clear-cut (see `criterion`). The tree therefore appears to be stable.

## Random forest

We then build a random forest with the conditional inference algorithm, `mtry`=2 and `ntree`=500.

```{r forest}
foret <- party::cforest(Survived~., data=titanic,
                        controls=party::cforest_unbiased(mtry=2,ntree=500))
```

To compare the performance of the forest with that of the tree, the predictions and then the AUC are calculated.

```{r forest_pred}
pred_foret <- predict(foret, type='prob') %>%
  do.call('rbind.data.frame',.) %>%
  select(2) %>%
  unlist
auc_foret <- AUC(pred_foret, titanic$Survived, positive='Yes')
auc_foret %>% round(3)
```

The performance of the forest is `r round(auc_foret,3)`, slightly better than that of the tree.

The `OOB=TRUE` option allows predictions to be computed from the out-of-bag observations, thus avoiding optimism bias.

```{r forest_pred_OOB}
pred_oob <- predict(foret, type='prob', OOB=TRUE) %>%
  do.call('rbind.data.frame',.) %>%
  select(2) %>%
  unlist
auc_oob <- AUC(pred_oob, titanic$Survived, positive='Yes')
auc_oob %>% round(3)
```

Calculated in this way, the performance is indeed slightly lower (`r round(auc_oob,3)`).
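As with the tree, other performance measures of the forest could be derived from a confusion matrix; a minimal sketch (not evaluated here), reusing the 0.5 cut-off from above on the out-of-bag predictions:

```{r confusion_oob, eval=FALSE}
# Sketch: confusion matrix of the forest, based on out-of-bag predictions
# and the same 0.5 probability cut-off as for the tree.
ifelse(pred_oob > .5, "Yes", "No") %>%
  factor %>%
  caret::confusionMatrix(titanic$Survived, positive='Yes')
```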
## Surrogate tree

The so-called "surrogate tree" can be a way to synthesize a complex model.

```{r surrogate, out.width='100%'}
surro <- SurrogateTree(foret, maxdepth=3)
surro$r.squared %>% round(3)
plot(surro$tree,
     inner_panel=node_inner(surro$tree,id=FALSE,pval=FALSE),
     terminal_panel=node_boxplot(surro$tree,id=FALSE),
     gp=gpar(cex=0.6),
     ep_args=list(justmin=15))
```

Here the surrogate tree reproduces the predictions of the random forest very faithfully (R2 = `r round(surro$r.squared,3)`). It is also very similar to the initial classification tree.

## Variable importance

To go further in the interpretation of the results, permutation variable importance is calculated (using AUC as the performance measure). Since the explanatory variables are weakly correlated with each other, it is not necessary to use the "conditional permutation scheme" (available with the `conditional=TRUE` option of the `varImpAUC` function).

```{r vimp, fig.align="center", fig.width=5, fig.height=3}
importance <- varImpAUC(foret)
importance %>% round(3)
ggVarImp(importance)
```

`Sex` is the most important variable, ahead of `Pclass`, `Age` and `Embarked` (whose importance is close to zero).

## First order effects

The partial dependences of all the covariates can then be calculated using the `GetPartialData` function.

```{r pdp, eval=FALSE}
pdep <- GetPartialData(foret, which.class=2, probs=1:19/20, prob=TRUE)
```

```{r pdp2}
pdep
```

They can be represented graphically.

```{r pdp_plot, fig.align="center", fig.width=5, fig.height=6}
ggForestEffects(pdep, vline=mean(pred_foret), xlab="Probability of survival") +
  xlim(c(0,1))
```

The dashed line represents the average predicted probability of survival. We can see, for example, that the probability of survival is much higher for women, or that it decreases as we move from 1st to 3rd class.

We now zoom in on age to examine its effect more closely. Partial dependences are calculated for 39 quantiles of the age distribution, a number high enough to allow a more precise examination.

```{r pdp_age, eval=FALSE}
pdep_age <- pdp::partial(foret, 'Age', which.class=2, prob=TRUE, quantiles=TRUE, probs=1:39/40)
```

```{r pdp_plot_age, fig.align="center", fig.width=5, fig.height=3}
ggplot(pdep_age, aes(x=Age, y=yhat)) +
  geom_line() +
  geom_hline(aes(yintercept=mean(pred_foret)), size=0.2, linetype='dashed', color='black') +
  ylim(c(0,1)) +
  theme_bw() +
  ylab("Probability of survival")
```

The probability of survival for young children is high, but drops rapidly with age. It is around the average for young adults and declines again after age 30.
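Incidentally (a minimal sketch, not evaluated here, relying on the `ggplot2` methods shipped with `pdp`), a quick version of the same curve can be obtained directly from the `partial` object with `autoplot`:

```{r pdp_age_autoplot, eval=FALSE}
# Sketch: pdp provides an autoplot method for "partial" objects,
# yielding a ggplot2 version of the curve in one call.
autoplot(pdep_age) +
  ylim(c(0,1)) +
  theme_bw() +
  ylab("Probability of survival")
```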
We often limit ourselves to examining *average* survival probabilities, but we can also look at their *distribution*, in numerical form:

```{r pdp_in, eval=FALSE}
pdep_ind <- GetPartialData(foret, which.class=2, probs=1:19/20, prob=TRUE, ice=TRUE)
```

```{r pdp_table}
pdep_ind %>%
  group_by(var, cat) %>%
  summarise(prob = mean(value) %>% round(3),
            Q1 = quantile(value, 0.25) %>% round(3),
            Q3 = quantile(value, 0.75) %>% round(3))
```

Or in graphical form:

```{r pdp_boxplot, fig.align="center", fig.width=5, fig.height=5}
ggplot(pdep_ind, aes(x = value, y = cat, group = cat)) +
  geom_boxplot(aes(fill=var), notch=TRUE) +
  geom_vline(aes(xintercept=median(pred_foret)), size=0.2, linetype='dashed', color='black') +
  facet_grid(var ~ ., scales = "free_y", space = "free_y") +
  theme_bw() +
  theme(panel.grid = element_blank(),
        panel.grid.major.y = element_line(size=.1, color="grey70"),
        legend.position = "none",
        strip.text.y = element_text(angle = 0)) +
  xlim(c(0,1)) +
  xlab("Probability of survival") +
  ylab("")
```

The accumulated local effects (ALE) of all the explanatory variables can be calculated simply with the `GetAleData` function.

```{r ale, eval=FALSE}
ale <- GetAleData(foret)
```

```{r ale2}
ale
```

The effects are then represented graphically. They are very consistent with the partial dependences, which is not surprising given that the explanatory variables are weakly correlated here.

```{r ale_plot, fig.align="center", fig.width=5, fig.height=6}
ggForestEffects(ale)
```

## Interactions

We now wish to identify the main second-order interactions, using the algorithm of Greenwell et al. (2018) available with `GetInteractionStrength` (a simple wrapper for the `vint` function of the [`vip`](https://cran.r-project.org/package=vip) package).

```{r vint, eval=FALSE}
vint <- GetInteractionStrength(foret)
```

```{r vint2}
vint
```

The interaction between `Sex` and `Pclass` is by far the most pronounced.

The main second-order interactions are then studied using partial dependence, with the `partial` function from the [`pdp`](https://cran.r-project.org/package=pdp) package.

```{r pd_inter2_sexclass, eval=FALSE}
pdep_sexclass <- pdp::partial(foret, c('Sex','Pclass'), quantiles=TRUE, probs=1:19/20, which.class=2L, prob=TRUE)
```

```{r pd_plot_inter2_sexclass, fig.align="center", fig.width=5, fig.height=3}
ggplot(pdep_sexclass, aes(Pclass, yhat)) +
  geom_point(aes(color=Sex)) +
  ylim(0,1) +
  theme_bw()
```

`Pclass` has a greater effect on women than on men. For women, it contrasts the 1st and 2nd classes with the 3rd; for men, the 1st class stands apart from the other two.

```{r pd_inter2_sexage, eval=FALSE}
pdep_sexage <- pdp::partial(foret, c('Sex','Age'), quantiles=TRUE, probs=1:19/20, which.class=2L, prob=TRUE)
```

```{r pd_plot_inter2_sexage, fig.align="center", fig.width=5, fig.height=3}
ggplot(pdep_sexage, aes(Age, yhat)) +
  geom_line(aes(color=Sex)) +
  ylim(0,1) +
  theme_bw()
```

Age has little effect on survival for females, while it plays a more pronounced role for males, with young boys having a higher probability of survival than adult men.
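The same approach extends to any other pair of covariates; for instance (a sketch mirroring the chunks above, not evaluated here), `Pclass` and `Age`:

```{r pd_inter2_classage, eval=FALSE}
# Sketch: partial dependence for the remaining Pclass x Age pair,
# built exactly like the Sex x Age example above.
pdep_classage <- pdp::partial(foret, c('Pclass','Age'), quantiles=TRUE, probs=1:19/20, which.class=2L, prob=TRUE)
ggplot(pdep_classage, aes(Age, yhat)) +
  geom_line(aes(color=Pclass)) +
  ylim(0,1) +
  theme_bw()
```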
A third-order interaction, examined with partial dependence:

```{r pd_inter3, eval=FALSE}
pdep_sexclassage <- pdp::partial(foret, c('Sex','Pclass','Age'), quantiles=TRUE, probs=1:19/20, which.class=2L, prob=TRUE)
```

```{r pd_plot_inter3, eval=FALSE}
cols <- c(paste0('dodgerblue',c(4,3,1)),paste0('tomato',c(4,3,1)))
pdep_sexclassage %>%
  data.frame %>%
  mutate(sexclass = interaction(Pclass,Sex)) %>%
  ggplot(aes(x=Age, y=yhat)) +
    geom_line(aes(colour=sexclass)) +
    scale_color_manual(values=cols) +
    ylim(0,1) +
    theme_bw()
```

```{r pd_plot_inter3bis, echo=FALSE, fig.align="center", out.width='70%'}
knitr::include_graphics("http://nicolas.robette.free.fr/Docs/plot_inter3.png")
```

For example, we see that the age effect operates mainly among 2nd and 3rd class males.

Alternatively, second-order interactions can be analyzed using accumulated local effects (excluding interactions between two categorical variables).

```{r ale_inter2, eval=FALSE}
ale_sex_age <- GetAleData(foret, xnames=c("Sex","Age"), order=2)
```

```{r ale_plot_inter2, fig.align="center", fig.width=5, fig.height=3}
ale_sex_age %>%
  ggplot(aes(Age, value)) +
    geom_line(aes(color=Sex)) +
    geom_hline(yintercept=0, linetype=2, color='gray60') +
    theme_bw()
```

The additional effect of the interaction between `Sex` and `Age` (relative to the main effects of these variables) is mainly present for children: being a boy is associated with a higher probability of survival, and being a girl with a lower one.

## Prototypes and outliers

Prototypes are observations that are representative of their class. Their computation is based on the proximity matrix between observations.

```{r prototypes}
prox <- proximity(foret)
proto <- Prototypes(titanic$Survived, titanic[,-1], prox)
proto
```

The prototypes of survivors are all adult women travelling in 1st class. The prototypes of non-survivors are all adult males.

The proximity matrix also makes it possible to identify outliers.

```{r outliers1, fig.align="center", fig.width=4, fig.height=4}
out <- bind_cols(pred=round(pred_foret,2),titanic) %>%
  Outliers(prox, titanic$Survived, .)
boxplot(out$scores)
```

Only a few observations have a score above 10 (the threshold suggested by Breiman).

```{r outliers2}
arrange(out$outliers, Survived, desc(scores)) %>%
  split(.$Survived)
```

Outliers among survivors are women travelling in 1st or 2nd class whose predicted probability of survival is very high. Conversely, outliers among non-survivors are males travelling in 2nd or 3rd class whose predicted probability of survival is low.

## Feature selection

Variable selection is of little use here, where there are only four explanatory variables. As an illustration, however, the "recursive feature elimination" algorithm is applied. Note that this procedure can be time-consuming, even in its parallelized version (especially for the 'ALT' and 'HAPF' algorithms).

```{r featsel, eval=FALSE}
featsel <- FeatureSelection(titanic$Survived, titanic[,-1], method="RFE", positive="Yes")
```

```{r featsel2}
featsel$selection.0se
featsel$selection.1se
```

The algorithm suggests keeping the whole set of variables, or eliminating the `Age` and `Embarked` variables if one is prepared to lose some performance ("1 standard error rule").
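For completeness (a minimal sketch, not evaluated here, assuming one retains the `Sex` and `Pclass` variables kept by the "1 standard error rule"), the forest could then be re-estimated on the reduced set of covariates:

```{r forest_1se, eval=FALSE}
# Sketch: re-estimate the forest with only the variables kept by the
# "1 standard error rule" (Sex and Pclass), lowering mtry accordingly.
foret_1se <- party::cforest(Survived~Sex+Pclass, data=titanic,
                            controls=party::cforest_unbiased(mtry=1,ntree=500))
```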
## Parallelization

The [`moreparty`](https://cran.r-project.org/package=moreparty) package provides parallelized versions of the `cforest`, `varImp` and `varImpAUC` functions, to save computation time. For example, to parallelize the computation of variable importance:

```{r parallel, eval=FALSE}
library(doParallel)
registerDoParallel(cores=2)
fastvarImpAUC(foret)
stopImplicitCluster()
```

This chunk of code works fine on Linux but may have to be adapted for other operating systems. Parallelization is also an option for many of the functions we have used here.
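On Windows, for instance (a minimal sketch, not evaluated here, relying on the standard `doParallel` cluster interface), one would typically register an explicit PSOCK cluster instead:

```{r parallel_psock, eval=FALSE}
# Sketch: an explicit PSOCK cluster, which also works on Windows;
# the cluster must be stopped once the computation is done.
library(doParallel)
cl <- makeCluster(2)
registerDoParallel(cl)
fastvarImpAUC(foret)
stopCluster(cl)
```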