| Literature DB >> 32251471 |
Esra Zihni, Vince Istvan Madai, Michelle Livne, Ivana Galinovic, Ahmed A Khalil, Jochen B Fiebach, Dietmar Frey.
Abstract
State-of-the-art machine learning (ML) artificial intelligence methods are increasingly leveraged in clinical predictive modeling to provide clinical decision support systems to physicians. Modern ML approaches such as artificial neural networks (ANNs) and tree boosting often perform better than more traditional methods like logistic regression. On the other hand, these modern methods yield only a limited understanding of the resulting predictions. However, in the medical domain, understanding of applied models is essential, in particular when informing clinical decision support. Thus, in recent years, interpretability methods for modern ML methods have emerged that potentially allow explainable predictions paired with high performance. To our knowledge, this work presents the first explainability comparison of two modern ML methods, tree boosting and multilayer perceptrons (MLPs), with traditional logistic regression methods, using a stroke outcome prediction paradigm. Here, we used clinical features to predict a dichotomized modified Rankin Scale (mRS) score 90 days post-stroke. For interpretability, we evaluated the importance of clinical features with regard to predictions using deep Taylor decomposition for the MLP, Shapley values for tree boosting, and model coefficients for logistic regression. With regard to performance as measured by Area under the Curve (AUC) values on the test dataset, all models performed comparably: logistic regression AUCs were 0.83, 0.83, and 0.81 for three different regularization schemes; the tree boosting AUC was 0.81; the MLP AUC was 0.83. Importantly, the interpretability analysis demonstrated consistent results across models, consistently rating age and stroke severity amongst the most important predictive features. For less important features, some differences were observed between the methods.
Our analysis suggests that modern machine learning methods can provide explainability that is compatible with domain-knowledge interpretation and traditional method rankings. Future work should focus on replicating these findings in other datasets and on further testing of different explainability methods.
Year: 2020 PMID: 32251471 PMCID: PMC7135268 DOI: 10.1371/journal.pone.0231166
Source DB: PubMed Journal: PLoS One ISSN: 1932-6203 Impact factor: 3.240
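As a minimal sketch of the coefficient-based feature ranking the abstract describes for the regularized logistic regression models, the example below fits an L1-penalized model and ranks features by absolute coefficient magnitude. The feature names mirror the study's covariates, but the data are synthetic; this is an illustration, not the study's actual pipeline.

```python
# Hedged sketch: feature importance via logistic regression coefficients,
# analogous to the paper's coefficient-based ranking. Data are synthetic.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

features = ["age", "initial_NIHSS", "sex", "cardiac_history",
            "diabetes", "hypercholesterolemia", "thrombolysis"]
X, y = make_classification(n_samples=314, n_features=7,
                           n_informative=4, random_state=0)
X = StandardScaler().fit_transform(X)  # comparable coefficient scales

clf = LogisticRegression(penalty="l1", solver="liblinear", C=1.0)
clf.fit(X, y)

# Rank features by absolute coefficient magnitude (descending)
ranking = sorted(zip(features, np.abs(clf.coef_[0])),
                 key=lambda t: -t[1])
for name, weight in ranking:
    print(f"{name:22s} {weight:.3f}")
```

Standardizing the inputs first matters here: without it, coefficient magnitudes reflect feature scales rather than importance.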
Summary of the clinical data.
| Clinical information | Value |
|---|---|
| Median age (IQR) | 72 (15) |
| Sex (females / males) | 196 / 118 |
| Median initial NIHSS (IQR) | 3 (5) |
| Cardiac history (yes / no) | 84 / 230 |
| Diabetes mellitus (yes / no) | 79 / 235 |
| Hypercholesterolemia (yes / no) | 182 / 132 |
| Thrombolysis (yes / no) | 74 / 240 |
The table summarizes the distribution of the selected clinical data covariates acquired in the acute clinical setting. NIHSS stands for National Institutes of Health Stroke Scale; IQR indicates the interquartile range.
Summary of hyperparameter tuning.
| Model | Hyperparameter | Range |
|---|---|---|
| LASSO | C (inverse of regularizer multiplier) | 0.10, 0.12, 0.15, 0.18, 0.21, 0.26, 0.31, 0.37, 0.45, 0.54, 0.66, 0.79, 0.95, 1.15, 1.39, 1.68, 2.02, 2.44, 2.95, 3.56, 4.29, 5.18, 6.25, 7.54, 9.10, 10.9, 13.3, 16.0, 19.3, 23.3, 28.1, 33.9, 40.9, 49.4, 59.6, 72.0, 86.9, 105, 126, 153, 184, 222, 268, 324, 391, 471, 569, 687, 829, 1000 |
| Elastic net | L1 ratio | 0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95 |
| | Alpha | 0.00001, 0.00004, 0.00016, 0.0006, 0.0025, 0.01, 0.04, 0.16, 0.63, 2.5, 10 |
| CatBoost | Tree depth | 2, 4 |
| | Learning rate | 0.03, 0.1, 0.3 |
| | Bagging temperature | 0.6, 0.8, 1.0 |
| | L2 leaf regularization | 3, 10, 100, 500 |
| | Leaf estimation iterations | 1, 2 |
| MLP | Number of hidden neurons | 5, 10, 15, 20 |
| | Learning rate | 0.001, 0.01 |
| | Batch size | 16, 32 |
| | Dropout rate | 0.1, 0.2 |
| | L1 regularization ratio | 0.0001, 0.001 |
The table details the hyperparameters and corresponding ranges that were tuned for each model in the cross-validation process.
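The cross-validated grid search described above can be sketched with scikit-learn's `GridSearchCV`, shown here for the LASSO model's C grid (a log-spaced grid approximating the values in the table). The data are synthetic stand-ins for the clinical covariates; the actual study code may differ.

```python
# Sketch of cross-validated hyperparameter tuning over the LASSO C grid,
# as described in the table. Synthetic data stand in for the cohort.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = make_classification(n_samples=314, n_features=7, random_state=0)

# A log-spaced C grid like the table's (50 values from 0.1 to 1000)
param_grid = {"C": np.logspace(-1, 3, 50)}

search = GridSearchCV(
    LogisticRegression(penalty="l1", solver="liblinear"),
    param_grid, scoring="roc_auc", cv=5)
search.fit(X, y)
print("best C:", search.best_params_["C"])
print("best CV AUC: %.3f" % search.best_score_)
```

Scoring by `roc_auc` inside the search keeps model selection aligned with the AUC metric reported in the results.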
Fig 1. Graphical representation of the model performance results.
The graph illustrates the performance of the different models evaluated on the training (blue) and test (orange) sets: generalized linear model (GLM), Lasso, Elastic net, Tree Boosting and multilayer perceptron (MLP). The markers show the median AUC over 50 shuffles and the error bars represent the interquartile range (IQR). All models showed a similar median AUC of around 0.82. The largest difference in performance between training and test set, indicating potential overfitting, was observed for the CatBoost model.
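The shuffle-based evaluation summarized in the figure can be sketched as follows: repeat a train/test split 50 times and report the median test AUC with its IQR. Synthetic data and a plain logistic regression stand in for the study's models.

```python
# Sketch of the 50-shuffle evaluation: median test AUC and IQR, as in Fig 1.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=314, n_features=7, random_state=0)

aucs = []
for seed in range(50):  # 50 shuffles, as in the figure
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.2, random_state=seed, stratify=y)
    clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
    aucs.append(roc_auc_score(y_te, clf.predict_proba(X_te)[:, 1]))

q1, med, q3 = np.percentile(aucs, [25, 50, 75])
print(f"median test AUC {med:.2f} (IQR {q3 - q1:.2f})")
```

Reporting the median with IQR rather than mean with standard deviation is robust to the occasional unlucky split.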
Fig 2. Graphical representation of the feature importance.
The figure illustrates the feature ratings derived from the model-tailored interpretability methods for the generalized linear model (GLM), Lasso, Elastic net, CatBoost and multilayer perceptron (MLP). All models consistently rated age and initial NIHSS amongst the most important features. For less important features, results were more varied. For the logistic regression techniques the results are given as model weights, for CatBoost as Shapley values, and for the MLP as deep Taylor values, all normalized to the range [0, 1]. The bar heights represent means and the error bars represent standard deviations over samples (shuffles).
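The per-shuffle normalization described in the caption can be sketched as a min-max rescaling of each shuffle's importance scores to [0, 1], followed by aggregation into bar heights (means) and error bars (standard deviations). The raw scores here are hypothetical; the study's actual importance values come from the respective interpretability methods.

```python
# Sketch of the Fig 2 aggregation: per-shuffle importance scores are
# min-max normalized to [0, 1], then summarized as mean +/- std.
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical raw importance scores: 50 shuffles x 7 features
raw = rng.gamma(shape=2.0, size=(50, 7))

# Min-max normalize each shuffle's scores to [0, 1]
lo = raw.min(axis=1, keepdims=True)
hi = raw.max(axis=1, keepdims=True)
norm = (raw - lo) / (hi - lo)

means = norm.mean(axis=0)  # bar heights
stds = norm.std(axis=0)    # error bars
print(np.round(means, 2))
```

Normalizing within each shuffle makes rankings comparable across methods whose raw scores live on different scales (weights, Shapley values, deep Taylor relevances).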