Helard Becerra Martinez, Katryna Cisek, Alejandro García-Rudolph, John D. Kelleher, Andrew Hines.
Abstract
Accurate early prediction of a patient's likely cognitive improvement from a stroke rehabilitation programme can help clinicians assemble more effective therapeutic programmes. In addition, clinicians report that a sufficient level of explainability, justifying these predictions, is a crucial requirement. This article presents a machine learning (ML) model that predicts cognitive improvement after therapy for stroke-surviving patients. The model relies on electronic health records from 201 ischemic stroke survivors containing demographic information, cognitive assessments at admission from 24 different standardized neuropsychological tests (e.g., TMT, WAIS-III, Stroop, RAVLT), and therapy information collected during rehabilitation (72,002 entries collected between March 2007 and September 2019). The study population comprised young-adult patients with a mean age of 49.51 years, and only 4.47% were above 65 years of age at the stroke event (no age filter was applied). Twenty different classification algorithms (from Python's scikit-learn library) were trained and evaluated, varying their hyper-parameters and the number of input features. The best-performing models reached Recall scores around 0.7 and F1 scores of 0.6, showing the model's ability to identify patients with poor cognitive improvement. The study includes a detailed feature-importance report that helps interpret the model's inner decision workings and exposes the most influential factors in the cognitive-improvement prediction. Certain therapy variables (e.g., the proportion of executed memory and orientation tasks) had an important influence on the final prediction of patients' cognitive improvement at both the individual and population levels.
This type of evidence can help clinicians adjust the therapeutic settings (e.g., the type and load of therapy activities) and select those that maximize cognitive improvement.
Keywords: AI explainability; cognitive improvement; cognitive therapy; ischemic stroke; machine learning (ML); predictive models; web-based therapy
Year: 2022 PMID: 35911882 PMCID: PMC9325998 DOI: 10.3389/fneur.2022.886477
Source DB: PubMed Journal: Front Neurol ISSN: 1664-2295 Impact factor: 4.086
Figure 1. Diagram of the methodology used in this study.
List of neuropsychological assessments administered at the Institut Guttmann at admission and discharge.
| Assessment | Cognitive domain | Score range |
|---|---|---|
| TB Personal Orientation | Orientation | [0–7] |
| TB Spatial Orientation | Orientation | [0–5] |
| TB Temporal Orientation | Orientation | [0–23] |
| Digits Span | Attention | [0–9] |
| TMT-A | Attention | [0–Inf] |
| Stroop-Words | Attention | [0–Inf] |
| Stroop-Color | Attention | [0–Inf] |
| Stroop-Words/Colors | Attention | [0–Inf] |
| TB Language Repetition | Language | [0–10] |
| TB Language Denomination | Language | [0–14] |
| TB Language Comprehension | Language | [0–16] |
| Digit Span Backwards WAIS-III | Memory | [0–8] |
| Numbers and Letters WAIS-III | Memory | [0–16] |
| RAVLT Learning | Memory | [0–75] |
| RAVLT Free Recall | Memory | [0–15] |
| RAVLT Recognition | Memory | [0–15] |
| TMT-B | Executive Functions | [0–Inf] |
| WCST Categories | Executive Functions | [0–6] |
| WCST Errors | Executive Functions | [0–Inf] |
| Stroop-Interference | Executive Functions | [Inf] |
| PMR | Executive Functions | [0–Inf] |
| Visuospatial WAIS-III | Visual | [0–Inf] |
| Images WAIS-III | Visual | [0–20] |
| Cubes WAIS-III | Visual | [0–Inf] |
| NIHSS (*) | Overall impairment | [0–42] |
TB, Test Barcelona; TMT, Trail Making Test; WAIS-III, Wechsler Adult Intelligence Scale, 3rd version; RAVLT, Rey Auditory Verbal Learning Test; WCST, Wisconsin Card Sorting Test; NIHSS, National Institutes of Health Stroke Scale; (*), administered only at admission.
Sets of hyper-parameters used in the grid search for the pre-selected classification algorithms.
| Algorithm | Hyper-parameter grid |
|---|---|
| ExtraTreesClassifier | n_estimators = [10, 100, 1000] |
| | max_depth = [3, 7, 9] |
| | min_samples_split = [2, 10, 20] |
| | criterion = ["gini", "entropy"] |
| | min_weight_fraction_leaf = [0, 0.2, 0.3, 0.5] |
| RandomForestClassifier | criterion = ["gini", "entropy"] |
| | n_estimators = [10, 100, 1000] |
| | max_features = ["sqrt", "log2"] |
| | max_depth = [9, 15] |
| | min_samples_split = [2, 10, 20] |
| | min_weight_fraction_leaf = [0, 0.2, 0.5] |
| KNeighborsClassifier | n_neighbors = [2, 10, 21] |
| | weights = ["uniform", "distance"] |
| | metric = ["euclidean", "manhattan", "minkowski"] |
| XGBClassifier | eta = [0.001, 0.01, 0.1, 0.2, 0.3] |
| | gamma = [0.05, 0.5, 1, 1.5] |
| | min_child_weight = [5, 7, 9, 10] |
| | subsample = [0.5, 0.8, 1] |
| | colsample_bytree = [0.6, 0.8, 1] |
| | reg_lambda = [0.1, 0.5, 1] |
| LogisticRegression | solver = ["newton-cg", "lbfgs", "liblinear"] |
| | penalty = ["l2"] |
| | C = [100, 10, 1.0, 0.1, 0.01] |
| | max_iter = [1000] |
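The grids above map directly onto scikit-learn's `GridSearchCV`. A minimal sketch with the LogisticRegression grid, using synthetic data as a stand-in for the study's (non-public) patient records:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in for the 201-patient rehabilitation dataset.
X, y = make_classification(n_samples=201, n_features=20, random_state=42)

# Grid copied from the LogisticRegression row of the table above.
param_grid = {
    "solver": ["newton-cg", "lbfgs", "liblinear"],
    "penalty": ["l2"],
    "C": [100, 10, 1.0, 0.1, 0.01],
    "max_iter": [1000],
}

# Recall as the selection criterion and 5-fold CV, matching the study's setup.
search = GridSearchCV(LogisticRegression(), param_grid, scoring="recall", cv=5)
search.fit(X, y)
print(search.best_params_)
```

The same pattern applies to the other four grids; only the estimator and `param_grid` change.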
Basic statistics of demographic, cognitive, and therapy variables from the rehabilitation dataset.
| Variable | Mean | SD | Min | Max |
|---|---|---|---|---|
| Demographic | | | | |
| Age | 49.70 | 10.18 | 16.74 | 81.92 |
| Age at injury (mc) | 49.52 | 10.18 | 16.65 | 81.84 |
| Time since injury in days | 66.46 | 40.30 | 1.60 | 173.43 |
| Length of therapy | 65.35 | 30.79 | 15.00 | 173.00 |
| Sex (c) | | | | |
| Male | N/A | N/A | N/A | N/A |
| Female | N/A | N/A | N/A | N/A |
| Marital status (c) | | | | |
| Married | N/A | N/A | N/A | N/A |
| Single | N/A | N/A | N/A | N/A |
| Divorced | N/A | N/A | N/A | N/A |
| Separated | N/A | N/A | N/A | N/A |
| Widowed | N/A | N/A | N/A | N/A |
| Cognitive | | | | |
| Admission compliance | 0.90 | 0.11 | 0.67 | 1.00 |
| Discharge compliance (mc) | 0.94 | 0.09 | 0.67 | 1.00 |
| Global improvement (t) | 0.18 | 0.21 | -0.50 | 0.75 |
| Attention improvement (mc) | 0.02 | 0.17 | -0.33 | 0.67 |
| Orientation improvement (mc) | 0.27 | 0.40 | -0.80 | 1.00 |
| Language improvement (mc) | 0.04 | 0.15 | -0.33 | 0.67 |
| Memory improvement (mc) | 0.27 | 0.44 | -1.00 | 1.00 |
| Ex. Functions improvement (mc) | 0.23 | 0.39 | -1.00 | 1.00 |
| Visual improvement | 0.31 | 0.41 | -0.67 | 1.00 |
| NIHSS | 9.86 | 4.65 | 1.00 | 22.00 |
Admission (Adm) and discharge (Dis) scores for the neuropsychological assessments.
| Domain | Assessment | Mean (Adm) | Mean (Dis) | SD (Adm) | SD (Dis) | Min (Adm) | Min (Dis) | Max (Adm) | Max (Dis) |
|---|---|---|---|---|---|---|---|---|---|
| Orientation | TB Personal Orientation | 6.99 | 7.00 | 0.12 | 0.00 | 6 | 7 | 7 | 7 |
| | TB Spatial Orientation | 4.97 | 4.99 | 0.17 | 0.10 | 4 | 4 | 5 | 5 |
| | TB Temporal Orientation | 22.59 | 22.71 | 1.41 | 1.22 | 12 | 11 | 23 | 23 |
| Attention | Digits Span | 5.91 | 6.03 | 1.10 | 1.07 | 3 | 4 | 9 | 9 |
| | TMT-A | 67.44 | 53.42 | 48.43 | 31.70 | 4 | 6 | 289 | 240 |
| | Stroop - Words (md) | 78.43 | 82.05 | 16.37 | 15.51 | 37 | 40 | 123 | 125 |
| | Stroop - Color (md) | 55.51 | 57.70 | 12.45 | 12.54 | 23 | 27 | 89 | 90 |
| | Stroop - Words/Colors (md) | 32.04 | 34.03 | 10.71 | 11.04 | 6 | 8 | 85 | 73 |
| Language | TB Language Repetition | 9.99 | 10.00 | 0.10 | 0.00 | 9 | 10 | 10 | 10 |
| | TB Language Denomination | 13.96 | 14.00 | 0.27 | 0.00 | 11 | 14 | 14 | 14 |
| | TB Language Comprehension | 15.78 | 15.91 | 0.76 | 0.50 | 9 | 12 | 16 | 16 |
| Memory | Digit Span Backwards WAIS-III | 4.19 | 4.35 | 0.98 | 0.98 | 2 | 2 | 7 | 8 |
| | Numbers and Letters WAIS-III | 8.14 | 8.72 | 2.61 | 2.52 | 1 | 3 | 14 | 15 |
| | RAVLT Learning | 42.14 | 46.23 | 10.66 | 11.49 | 21 | 9 | 70 | 70 |
| | RAVLT Free Recall | 8.26 | 9.35 | 3.51 | 3.41 | 0 | 0 | 15 | 15 |
| | RAVLT Recognition | 11.44 | 12.19 | 3.91 | 3.20 | 0 | 1 | 15 | 15 |
| Executive Functions | TMT-B (md) | 141.13 | 112.07 | 82.51 | 48.63 | 30 | 30 | 565 | 300 |
| | WCST Categories (md) | 4.11 | 4.24 | 2.11 | 2.13 | 0 | 0 | 6 | 6 |
| | WCST Errors (md) | 18.36 | 16.25 | 15.07 | 14.64 | 0 | 0 | 63 | 72 |
| | Stroop - Interference (md) | -0.21 | 0.41 | 7.20 | 7.48 | -21 | -22 | 35 | 25 |
| | PMR | 31.93 | 35.10 | 13.25 | 13.30 | 3 | 5 | 72 | 84 |
| Visual | Visuospatial WAIS-III (md) | 41.21 | 46.23 | 15.59 | 16.04 | 10 | 13 | 92 | 92 |
| | Images WAIS-III | 19.30 | 19.67 | 1.49 | 0.94 | 11 | 14 | 20 | 20 |
| | Cubes WAIS-III | 26.96 | 30.31 | 11.97 | 11.94 | 2 | 6 | 66 | 66 |
| Category | Variable | Mean | SD | Min | Max |
|---|---|---|---|---|---|
| Therapy | | | | | |
| | Daily sessions | 12.15 | 7.91 | 1.00 | 52.00 |
| | Total number of tasks (mc) | 111.89 | 85.22 | 2.00 | 480.00 |
| | Total non executed tasks (mc) | 15.84 | 18.93 | 0.00 | 117.00 |
| | Non executed proportion | 0.14 | 0.10 | 0.00 | 0.53 |
| | Total gain proportion (mc) | 0.53 | 0.08 | 0.25 | 0.75 |
| Attention | Number of tasks (mc) | 20.93 | 18.36 | 1.00 | 123.00 |
| | Task proportion | 0.19 | 0.10 | 0.01 | 0.67 |
| | Non executed tasks | 2.01 | 3.85 | 0.00 | 40.00 |
| | Execution gain | 11.25 | 10.22 | 0.00 | 72.00 |
| Memory | Number of tasks (mc) | 45.65 | 40.40 | 2.00 | 294.00 |
| | Task proportion | 0.40 | 0.15 | 0.07 | 1.00 |
| | Non executed tasks | 5.22 | 8.08 | 0.00 | 53.00 |
| | Execution gain | 25.57 | 23.53 | 0.50 | 181.00 |
| Ex. Functions | Number of tasks (mc) | 38.16 | 30.58 | 1.00 | 170.00 |
| | Task proportion | 0.35 | 0.13 | 0.05 | 0.75 |
| | Non executed tasks | 8.12 | 9.13 | 0.00 | 62.00 |
| | Execution gain | 18.08 | 15.17 | 0.00 | 96.50 |
| Language | Number of tasks (mc) | 8.86 | 13.35 | 1.00 | 55.00 |
| | Task proportion | 0.16 | 0.28 | 0.01 | 1.00 |
| | Non executed tasks | 0.71 | 1.10 | 0.00 | 3.00 |
| | Execution gain | 5.04 | 9.10 | 0.50 | 37.00 |
| Orientation | Number of tasks (mc) | 4.31 | 5.27 | 1.00 | 31.00 |
| | Task proportion | 0.03 | 0.04 | 0.00 | 0.20 |
| | Non executed tasks | 0.39 | 1.30 | 0.00 | 10.00 |
| | Execution gain | 2.06 | 2.58 | 0.00 | 15.00 |
| Calculus | Number of tasks (mc) | 11.81 | 11.45 | 1.00 | 62.00 |
| | Task proportion | 0.09 | 0.07 | 0.01 | 0.31 |
| | Non executed tasks | 1.42 | 2.32 | 0.00 | 11.00 |
| | Execution gain | 6.25 | 6.29 | 0.00 | 36.50 |
| Gnosias | Number of tasks (mc) | 8.50 | 14.68 | 1.00 | 81.00 |
| | Task proportion | 0.06 | 0.09 | 0.01 | 0.49 |
| | Non executed tasks | 0.47 | 1.47 | 0.00 | 8.00 |
| | Execution gain | 4.81 | 8.11 | 0.00 | 40.00 |
| Praxias | Number of tasks (mc) | 3.71 | 2.99 | 2.00 | 12.00 |
| | Task proportion | 0.02 | 0.01 | 0.01 | 0.07 |
| | Non executed tasks | 0.24 | 0.64 | 0.00 | 2.00 |
| | Execution gain | 1.82 | 1.74 | 0.00 | 6.50 |
N = 201; Adm, admission; Dis, discharge; c, categorical variable; mc, removed to prevent multicollinearity issues; md, removed to prevent missing data issues; t, target variable.
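The "(mc)" variables above were removed to avoid multicollinearity. The exact procedure is not given here, but a common approach drops one variable from each highly correlated pair; a hedged sketch (the 0.9 threshold and the toy `age`/`age_at_injury` pair are illustrative assumptions):

```python
import numpy as np
import pandas as pd

def drop_collinear(df: pd.DataFrame, threshold: float = 0.9) -> pd.DataFrame:
    """Drop one column from each pair whose |Pearson r| exceeds threshold."""
    corr = df.corr().abs()
    # Look only at the upper triangle so each pair is inspected once.
    upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
    to_drop = [col for col in upper.columns if (upper[col] > threshold).any()]
    return df.drop(columns=to_drop)

# Toy example: 'age_at_injury' is nearly identical to 'age', mirroring the (mc) pairs.
rng = np.random.default_rng(0)
age = rng.uniform(17, 82, size=201)
df = pd.DataFrame({
    "age": age,
    "age_at_injury": age - rng.uniform(0.0, 0.5, size=201),
    "nihss": rng.integers(1, 23, size=201).astype(float),
})
reduced = drop_collinear(df)
print(list(reduced.columns))  # ['age', 'nihss']
```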
Summary report from temporal records of the GNPT platform.
| Category | Records (%) |
|---|---|
| [0] | 11,669 (16.2%) |
| [1–64] | 23,501 (32.6%) |
| [65–85] | 13,573 (18.9%) |
| [86–100] | 23,259 (32.3%) |
| Attention | 14,015 (19.5%) |
| Memory | 28,963 (40.2%) |
| Ex. Functions | 21,172 (29.4%) |
| Language | 1,923 (2.7%) |
| Orientation | 979 (1.4%) |
| Calculus | 2,377 (3.3%) |
| Gnosias | 2,379 (3.3%) |
| Praxias | 158 (0.2%) |
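The task-score counts above can be cross-checked against the 72,002 temporal records mentioned in the abstract:

```python
# Counts taken from the score-range rows of the summary table above.
score_counts = {"[0]": 11_669, "[1-64]": 23_501, "[65-85]": 13_573, "[86-100]": 23_259}

total = sum(score_counts.values())
shares = {band: round(100 * n / total, 1) for band, n in score_counts.items()}

print(total)   # 72002
print(shares)  # {'[0]': 16.2, '[1-64]': 32.6, '[65-85]': 18.9, '[86-100]': 32.3}
```

The four score bands sum exactly to the 72,002 entries reported in the abstract, and the computed shares reproduce the table's percentages.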
Figure 2. Correlation outcomes from demographic + cognitive variables. Red: strong positive correlation (+1), blue: strong negative correlation (–1).
Figure 3. Correlation outcomes from demographic + therapy variables. Red: strong positive correlation (+1), blue: strong negative correlation (–1).
Figure 4. Cluster analysis using PCA and t-SNE over demographic, cognitive, and therapy variables. Data points are labeled according to severity (NIHSS score) and improvement (global improvement). (A) PCA: accumulated explained variance 0.42. (B) t-SNE: perplexity 10, learning rate 200.
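Figure 4's embeddings can be reproduced in outline with scikit-learn; the data below are a synthetic stand-in, and only the quoted t-SNE hyper-parameters (perplexity 10, learning rate 200) come from the caption:

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the 201-patient feature matrix.
rng = np.random.default_rng(42)
X = StandardScaler().fit_transform(rng.normal(size=(201, 30)))

# PCA: cumulative explained variance of the first two components
# (reported as 0.42 for the real data in Figure 4A).
pca = PCA(n_components=2).fit(X)
cumulative = float(pca.explained_variance_ratio_.sum())

# t-SNE with the hyper-parameters quoted in Figure 4B.
embedding = TSNE(perplexity=10, learning_rate=200, random_state=42).fit_transform(X)
print(embedding.shape)  # (201, 2)
```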
Figure 5. Score distribution of global improvement. Class “0”: global improvement ≤ 0; class “1”: global improvement > 0.
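The class split in Figure 5 amounts to thresholding the global-improvement score at zero; the sample scores below are illustrative, not taken from the dataset:

```python
import numpy as np

# Binarize the global-improvement score exactly as in Figure 5:
# class "0" for improvement <= 0, class "1" for improvement > 0.
global_improvement = np.array([0.18, -0.05, 0.00, 0.42, -0.50, 0.75])
labels = (global_improvement > 0).astype(int)
print(labels)  # [1 0 0 1 0 1]
```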
Evaluation of the classification algorithms without hyper-parameter tuning.
| Algorithm | | | |
|---|---|---|---|
| RandomForestClassifier | 0.638 (0.08) | 0.652 (0.08) | 0.526 (0.05) |
| ExtraTreesClassifier | 0.630 (0.08) | 0.640 (0.11) | 0.528 (0.05) |
| KNeighborsClassifier | 0.614 (0.10) | 0.606 (0.13) | 0.512 (0.06) |
| XGBClassifier | 0.636 (0.08) | 0.642 (0.08) | 0.543 (0.07) |
| LogisticRegression | 0.643 (0.06) | 0.646 (0.08) | 0.549 (0.08) |
| RidgeClassifier | 0.633 (0.05) | 0.637 (0.07) | 0.534 (0.07) |
| BaggingClassifier | 0.610 (0.08) | 0.632 (0.06) | 0.533 (0.07) |
| LinearSVC | 0.634 (0.06) | 0.639 (0.07) | 0.544 (0.07) |
| LinearDiscriminantAnalysis | 0.614 (0.06) | 0.616 (0.07) | 0.515 (0.07) |
| BernoulliNB | 0.611 (0.06) | 0.626 (0.07) | 0.516 (0.07) |
Results are sorted by the mean Recall score, with the corresponding SD in parentheses. Re-sampling with k-fold (k = 5) cross-validation and five repetitions.
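The evaluation protocol in the footnote (k = 5 folds, five repetitions) can be sketched with scikit-learn; stratification and the synthetic data are assumptions:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RepeatedStratifiedKFold, cross_validate

# Synthetic stand-in for the 201-patient dataset.
X, y = make_classification(n_samples=201, n_features=20, random_state=42)

# k-fold (k = 5) cross-validation with five repetitions, per the footnote.
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=5, random_state=42)
scores = cross_validate(RandomForestClassifier(random_state=42), X, y,
                        scoring=["recall", "f1"], cv=cv)

# One score per fold: 5 folds x 5 repetitions = 25 estimates.
print(len(scores["test_recall"]))
print(round(scores["test_recall"].mean(), 3), round(scores["test_recall"].std(), 3))
```

The mean and standard deviation of `test_recall` correspond to the "mean (SD)" cells reported in the table.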
Best performing sets of hyper-parameters gathered during the grid-search optimization process.
| Model (input features) | Best hyper-parameters |
|---|---|
| ExtraTreesClassifier_All | criterion: gini, max_depth: 9, min_samples_split: 2, min_weight_fraction_leaf: 0, n_estimators: 1000 |
| ExtraTreesClassifier_20 | criterion: gini, max_depth: 3, min_samples_split: 2, min_weight_fraction_leaf: 0.2, n_estimators: 10 |
| ExtraTreesClassifier_10 | criterion: gini, max_depth: 3, min_samples_split: 2, min_weight_fraction_leaf: 0.2, n_estimators: 10 |
| RandomForestClassifier_All | criterion: entropy, max_depth: 15, max_features: log2, min_samples_split: 10, min_weight_fraction_leaf: 0, n_estimators: 1000 |
| RandomForestClassifier_20 | criterion: entropy, max_depth: 15, max_features: sqrt, min_samples_split: 20, min_weight_fraction_leaf: 0.2, n_estimators: 10 |
| RandomForestClassifier_10 | criterion: gini, max_depth: 15, max_features: log2, min_samples_split: 2, min_weight_fraction_leaf: 0.2, n_estimators: 10 |
| KNeighborsClassifier_All | metric: euclidean, n_neighbors: 17, weights: distance |
| KNeighborsClassifier_20 | metric: manhattan, n_neighbors: 19, weights: uniform |
| KNeighborsClassifier_10 | metric: euclidean, n_neighbors: 19, weights: uniform |
| XGBClassifier_All | colsample_bytree: 0.6, eta: 0.01, gamma: 1, min_child_weight: 5, reg_lambda: 0.5, subsample: 0.8 |
| XGBClassifier_20 | colsample_bytree: 0.8, eta: 0.1, gamma: 0.5, min_child_weight: 9, reg_lambda: 0.1, subsample: 0.5 |
| XGBClassifier_10 | colsample_bytree: 0.6, eta: 0.001, gamma: 0.05, min_child_weight: 5, reg_lambda: 0.1, subsample: 0.5 |
| LogisticRegression_All | C: 0.01, max_iter: 300, penalty: l2, solver: newton-cg |
| LogisticRegression_20 | C: 0.01, max_iter: 300, penalty: l2, solver: newton-cg |
| LogisticRegression_10 | C: 0.01, max_iter: 300, penalty: l2, solver: newton-cg |
Recall scores used as the evaluation criterion. Results reported for different numbers of input features: “All,” no feature selection; “20,” best 20 features; “10,” best 10 features. Re-sampling with k-fold (k = 5) cross-validation and five repetitions.
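The "best 20"/"best 10" subsets suggest a univariate feature-selection step; the tables do not name the scoring function, so the ANOVA F-test below is an assumption:

```python
from sklearn.datasets import make_classification
from sklearn.feature_selection import SelectKBest, f_classif

# Synthetic stand-in with 40 candidate features, 10 of them informative.
X, y = make_classification(n_samples=201, n_features=40, n_informative=10,
                           random_state=42)

# Keep the "best 20" features; f_classif (ANOVA F-test) is an assumed choice.
selector = SelectKBest(score_func=f_classif, k=20).fit(X, y)
X_20 = selector.transform(X)
print(X_20.shape)  # (201, 20)
```

Setting `k=10` yields the "_10" variants in the same way.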
Evaluation of the pre-selected algorithms with hyper-parameters tuning.
| Model (input features) | | | |
|---|---|---|---|
| ExtraTreesClassifier_All | 0.637 (0.09) | 0.640 (0.15) | 0.536 (0.05) |
| ExtraTreesClassifier_20 | 0.593 (0.08) | 0.510 (0.08) | 0.500 (0.00) |
| ExtraTreesClassifier_10 | 0.593 (0.08) | 0.510 (0.08) | 0.500 (0.00) |
| RandomForestClassifier_All | 0.612 (0.08) | 0.610 (0.13) | 0.512 (0.03) |
| RandomForestClassifier_20 | 0.597 (0.08) | 0.515 (0.09) | 0.502 (0.01) |
| RandomForestClassifier_10 | 0.593 (0.08) | 0.508 (0.08) | 0.500 (0.01) |
| KNeighborsClassifier_All | 0.606 (0.07) | 0.596 (0.13) | 0.508 (0.02) |
| KNeighborsClassifier_20 | 0.624 (0.08) | 0.654 (0.12) | 0.522 (0.04) |
| KNeighborsClassifier_10 | 0.609 (0.09) | 0.591 (0.13) | 0.507 (0.04) |
| XGBClassifier_All | 0.617 (0.09) | 0.624 (0.16) | 0.517 (0.04) |
| XGBClassifier_20 | 0.608 (0.10) | 0.556 (0.15) | 0.516 (0.04) |
| XGBClassifier_10 | 0.593 (0.08) | 0.510 (0.08) | 0.500 (0.00) |
| LogisticRegression_All | 0.607 (0.07) | 0.595 (0.13) | 0.507 (0.03) |
| LogisticRegression_20 | 0.607 (0.07) | 0.595 (0.13) | 0.507 (0.03) |
| LogisticRegression_10 | 0.604 (0.08) | 0.567 (0.13) | 0.507 (0.02) |
Recall scores used as the evaluation criterion. Results reported for different numbers of input features: “All,” no feature selection; “20,” best 20 features; “10,” best 10 features. Re-sampling with k-fold (k = 5) cross-validation and five repetitions.
Figure 6. Global feature-importance plot based on mean SHAP values for the sex (A) and age (B) patient cohorts. Results from the optimized XGBClassifier_All model.
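Reproducing Figure 6 exactly requires the `shap` package. As a dependency-light analogue, scikit-learn's permutation importance also yields a global feature ranking; note this is a different technique from SHAP, and the model and data below are stand-ins for the study's XGBClassifier and patient records:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the 201-patient dataset.
X, y = make_classification(n_samples=201, n_features=20, random_state=42)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, random_state=42)

# Gradient boosting as a stand-in for the study's XGBClassifier.
model = GradientBoostingClassifier(random_state=42).fit(X_tr, y_tr)

# Global importance via permutation: shuffle each feature and measure the
# drop in recall, analogous to ranking features by mean |SHAP| value.
result = permutation_importance(model, X_te, y_te, scoring="recall",
                                n_repeats=10, random_state=42)
ranking = result.importances_mean.argsort()[::-1]
print(ranking[:5])  # indices of the five most influential features
```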
Figure 7. Dependence scatter plots showing the effect of six input features on the model predictions. SHAP values (y-axis) push the outcome toward class “0” (no cognitive improvement, negative SHAP values) or class “1” (cognitive improvement, positive SHAP values). Normalized feature instances, with corresponding histograms, are depicted on the x-axis. (A) Time since injury, (B) length of therapy, (C) non executed proportion, (D) memory gain, (E) memory non executed, and (F) memory task proportion.
Figure 8. SHAP waterfall plots explaining individual predictions: illustrative cases of the two patients with the highest global-improvement scores. Positive feature effects are shown in red and negative effects in blue. The plot is read from E[f(x)], the expected value of the model output, toward f(x), the model output. Positive values on the x-axis (in log-odds units) correspond to probabilities above 0.5 of classifying a patient as a cognitive-improvement case (class “1”). (A) Patient_123 and (B) Patient_133.
Figure 9. SHAP waterfall plots explaining individual predictions: illustrative cases of the two patients with the lowest global-improvement scores. Positive feature effects are shown in red and negative effects in blue. The plot is read from E[f(x)], the expected value of the model output, toward f(x), the model output. Negative values on the x-axis (in log-odds units) correspond to probabilities above 0.5 of classifying a patient as a no-cognitive-improvement case (class “0”). (A) Patient_27 and (B) Patient_33.