Amril Nazir, Hyacinth Kwadwo Ampadu.
Abstract
The global healthcare system is being overburdened by an increasing number of COVID-19 patients. Physicians have difficulty allocating resources and focusing their attention on high-risk patients, partly because high-risk patients are hard to identify early. COVID-19 hospitalizations require specialized treatment capabilities and strain healthcare resources, so estimating future hospitalization of COVID-19 patients is crucial to saving lives. In this paper, an interpretable deep learning model is developed to predict intensive care unit (ICU) admission and mortality of COVID-19 patients. The study comprised patients from Stony Brook University Hospital, with patient information such as demographics, comorbidities, symptoms, vital signs, and laboratory tests recorded. The top three predictors of ICU admission were ferritin, diarrhea, and alanine aminotransferase, and the top predictors of mortality were COPD, ferritin, and myalgia. The proposed model predicted ICU admission with an AUC of 88.3% and mortality with an AUC of 96.3%. It was evaluated against an existing model in the literature, which achieved an AUC of 72.8% for ICU admission and 84.4% for mortality; the proposed model is thus clearly superior to existing models. It has the potential to provide frontline doctors with tools to help classify patients in time-bound and resource-limited scenarios.
Keywords: COVID-19; Interpretable deep learning; Prediction of ICU admission; Prediction of mortality
Year: 2022 PMID: 35494832 PMCID: PMC9044277 DOI: 10.7717/peerj-cs.889
Source DB: PubMed Journal: PeerJ Comput Sci ISSN: 2376-5992
Summary of existing works.
| Current work | Data sets | Test period | Objective (COVID-19 patients in the hospital) |
|---|---|---|---|
| | Lombardy, Italy ICU hospital admission | 21 February 2020–27 June 2020 | Predict ICU beds and mortality rate |
| | Chile official COVID-19 data | 20 May 2020–28 July 2020 | Forecast short-term ICU bed availability |
| | Worldwide COVID-19 data from 146 countries | 1 December 2019–5 February 2020 | Predict the mortality risk in patients |
| | São Paulo COVID-19 hospital admission | 1 March 2020–28 June 2020 | Predict the risk of developing critical conditions |
| | Michigan COVID-19 hospital data | 1 February 2020–4 May 2020 | Predict the need for mechanical ventilation and mortality |
| | Montefiore Medical Center COVID-19 data | 1 March 2020–3 July 2020 | Predict patients' chances of surviving SARS-CoV-2 infection |
| | Stony Brook University Hospital COVID-19 hospital data | 7 February 2020–4 May 2020 | Predict ICU admission and in-hospital mortality |
Figure 1. TabNet model architecture.
Figure 2. TabNet decision-making process.
Notations.
| Notations | Definitions |
|---|---|
| x | features |
| σ | sigmoid |
| n(d) | output decision from current step |
| n(a) | input decision to the next step |
| W | weights |
| ⊕ | direct summation |
| ⊗ | tensor product |
| γ | gamma |
| β | beta |
| μ | mean |
| | integral block |
| | product block |
| P[i−1] | prior scales |
| p[i−1] | split layer division |
| | FC layer + BN layer |
| | mask learning process |
| d[i] | final output |
| a[i] | determines mask of next step |
| f(x) | function returning the value of the ReLU function |
| | features selected at |
| | importance of features |
| G | total number of synthetic data examples |
| | minority class |
| | majority class |
| | desired balance level |
| x′ | newly generated synthetic data |
| cov(X, y) | covariance matrix |
| | joint probability distribution of features |
| | t-distribution of features |
| KL | Kullback–Leibler divergence |
| TP | true positive |
| TN | true negative |
| FP | false positive |
| FN | false negative |
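The TP/TN/FP/FN counts defined above drive every metric in the results tables that follow. A minimal sketch (the counts are illustrative, not taken from the study):

```python
# Compute the four headline metrics used throughout the results tables
# from confusion-matrix counts. The counts below are made up.

def classification_metrics(tp: int, tn: int, fp: int, fn: int) -> dict:
    """Accuracy, precision, recall and F1 from TP/TN/FP/FN."""
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}

# Example with illustrative counts:
m = classification_metrics(tp=80, tn=90, fp=20, fn=10)
```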
Description of datasets.
| Dataset | No. patients / No. features | Class labels | Class distribution ratio (Pos:Neg) |
|---|---|---|---|
| ICUMice-ICU | 1,106 / 43 | 1 = ICU, 0 = no-ICU | 24.5:75.5 |
| DEADMice-Mortality | 1,020 / 43 | 1 = death, 0 = non-death | 13.9:86.1 |
Relationship between features and ICU admission.
| Features/variables | ICU (n = 271) | No-ICU (n = 835) |
|---|---|---|
| Demographics | ||
| Age, mean | 59.42 | 62.06 |
| Male | 67.5% (183) | 54% (451) |
| Female | 32.5% (88) | 46% (384) |
| Ethnicity | ||
| Hispanic/Latino | 28.8% (78) | 26.6% (222) |
| Non-Hispanic/Latino | 54.6% (148) | 60.7% (507) |
| Unknown | 16.6% (45) | 12.7% (106) |
| Race | ||
| Caucasian | 45.4% (123) | 54.3% (453) |
| African American | 4.79% (13) | 7.3% (61) |
| American Indian | 0.7% (2) | 0.2% (2) |
| Asian | 7.4% (20) | 3.1% (26) |
| Native Hawaiian | 0 | 0.1% (1) |
| More than one race | 0 | 0.6% (5) |
| Unknown/not reported | 41.7% (113) | 34.4% (287) |
| Comorbidities | ||
| Smoking history | 22.5% (61) | 25.6% (214) |
| Diabetes | 29.5% (80) | 26.3% (220) |
| Hypertension | 46.5% (126) | 49.3% (412) |
| Asthma | 8.5% (23) | 5.1% (43) |
| COPD | 6.3% (17) | 9.1% (76) |
| Coronary artery disease | 14.4% (39) | 15.1% (126) |
| Heart failure | 6.6% (18) | 7.4% (62) |
| Cancer | 5.5% (15) | 10.5% (88) |
| Chronic kidney disease | 7.4% (20) | 9.7% (81) |
| Vital signs | ||
| Systolic blood pressure (mmHg), mean | 124.8 | 128.99 |
| Temperature (degree Celsius), mean | 37.63 | 37.47 |
| Heart rate, mean | 106.1 | 98.2 |
| Respiratory rate (rate/min), mean | 25.28 | 21.77 |
| Laboratory Findings | ||
| Alanine aminotransferase (U/L), mean | 49.62 | 47.03 |
| C-reactive protein (mg/dL), mean | 15.4 | 9.49 |
| D-dimer (ng/mL), mean | 1,101.92 | 1,210.51 |
| Ferritin (ng/mL), mean | 1,469.67 | 1,005.43 |
| Lactate dehydrogenase (U/L), mean | 481.7 | 377.85 |
| Lymphocytes (×1,000/mL), mean | 12.43 | 14.85 |
| Procalcitonin (ng/mL), mean | 2.66 | 0.97 |
| Troponin (ng/mL), mean | 0.038 | 0.03 |
Relationship between symptoms and ICU admission.
| Symptoms | Percentage of patients with symptoms (%) |
|---|---|
| Fever | 70.5 |
| Cough | 70.5 |
| Shortness of Breath (SOB) | 77.5 |
| Fatigue | 79.3 |
| Sputum | 90.77 |
| Myalgia | 77.5 |
| Diarrhea | 77.9 |
| Nausea or vomiting | 83.3 |
| Sore throat | 92.3 |
| Runny nose or Nasal congestion | 94.8 |
| Loss of smell | 95.9 |
| Loss of Taste | 95.6 |
| Headache | 89.7 |
| Chest discomfort or chest pain | 84.1 |
Correlation between symptoms and ICU admission.
| Symptoms | Correlation (Pearson) | p-value |
|---|---|---|
| Fever | 0.046 | 0.122 |
| Cough | 0.028 | 0.348 |
| Shortness of Breath (SOB) | 0.1 | 0.0008 |
| Fatigue | −0.03 | 0.248 |
| Sputum | 0.055 | 0.065 |
| Myalgia | −0.005 | 0.869 |
| Diarrhea | −0.018 | 0.5415 |
| Nausea or vomiting | −0.035 | 0.247 |
| Sore throat | 0.009 | 0.757 |
| Runny nose or Nasal congestion | 0.0177 | 0.556 |
| Loss of smell | −0.0003 | 0.992 |
| Loss of Taste | −0.012 | 0.689 |
| Headache | 0.013 | 0.673 |
| Chest discomfort or chest pain | −0.0007 | 0.98 |
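The Pearson coefficients above correlate a binary symptom indicator with the binary ICU-admission label. A small pure-Python sketch with made-up toy vectors (not the study's data):

```python
# Pearson correlation between two equal-length numeric sequences;
# with 0/1 vectors this is the phi coefficient used in the table above.
import math

def pearson_r(x, y):
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

# Toy data: 1 = symptom present / ICU admitted, 0 = absent / not admitted.
symptom = [1, 1, 0, 0, 1, 0, 1, 0]
icu     = [1, 0, 0, 0, 1, 0, 1, 1]
r = pearson_r(symptom, icu)
```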
Relationship between features and mortality.
| Features/variables | Death (n = 142) | No-Death (n = 878) |
|---|---|---|
| Demographics | ||
| Age, mean | 73 | 59.83 |
| Male | 65.5% (93) | 55% (483) |
| Female | 34.5% (49) | 45% (395) |
| Ethnicity | ||
| Hispanic/Latino | 16.2% (23) | 28.5% (250) |
| Non-Hispanic/Latino | 73.9% (105) | 57.4% (504) |
| Unknown | 9.9% (14) | 14.1% (124) |
| Race | ||
| Caucasian | 64.1% (91) | 51.3% (450) |
| African American | 4.2% (6) | 6.9% (61) |
| Asian | 6.3% (9) | 3.8% (33) |
| American Indian | 0.7% (2) | 0.23% (2) |
| Native Hawaiian | 0 | 0.1% (1) |
| More than one race | 0 | 0.6% (5) |
| Unknown/not reported | 24.6% (35) | 37.1% (326) |
| Comorbidities | ||
| Smoking history | 36.6% (52) | 23.2% (204) |
| Diabetes | 33.8% (48) | 26.08% (229) |
| Hypertension | 64.8% (92) | 45.8% (402) |
| Asthma | 4.22% (6) | 5.8% (51) |
| COPD | 16.2% (23) | 7.5% (66) |
| Coronary artery disease | 27.5% (39) | 13.1% (115) |
| Heart failure | 20.4% (29) | 5.4% (47) |
| Cancer | 13.4% (19) | 8.9% (78) |
| Immunosuppression | 5.6% (8) | 7.4% (65) |
| Chronic kidney disease | 14.08% (20) | 8.5% (75) |
| Vital signs | ||
| Systolic blood pressure (mmHg), mean | 127.45 | 128.57 |
| Temperature (degree Celsius), mean | 37.3 | 37.52 |
| Heart rate, mean | 98.28 | 100.38 |
| Respiratory rate (rate/min), mean | 26.39 | 21.79 |
| Laboratory Findings | ||
| Alanine aminotransferase (U/L), mean | 42.91 | 48.45 |
| C-reactive protein (mg/dL), mean | 16.07 | 9.62 |
| D-dimer (ng/mL), mean | 2,626.27 | 1,016 |
| Ferritin (ng/mL), mean | 1,565 | 1,037.5 |
| Lactate dehydrogenase (U/L), mean | 588.28 | 363.14 |
| Lymphocytes (×1,000/mL), mean | 10.96 | 14.99 |
| Procalcitonin (ng/mL), mean | 5.14 | 0.76 |
| Troponin (ng/mL), mean | 0.07 | 0.0278 |
Relationship between symptoms and mortality.
| Symptoms | Percentage of patients with symptoms (%) |
|---|---|
| Fever | 57 |
| Cough | 51.4 |
| Shortness of Breath (SOB) | 71.8 |
| Fatigue | 86.6 |
| Sputum | 93 |
| Myalgia | 89.44 |
| Diarrhea | 81 |
| Nausea or vomiting | 93 |
| Sore throat | 95.1 |
| Runny nose or Nasal congestion | 97.18 |
| Loss of smell | 98.59 |
| Loss of Taste | 98.59 |
| Headache | 95.07 |
| Chest discomfort or chest pain | 92.96 |
Correlation between symptoms and mortality.
| Symptoms | Correlation (Pearson) | p-value |
|---|---|---|
| Fever | −0.08 | 0.009 |
| Cough | −0.149 | 0.0000069 |
| Shortness of Breath (SOB) | 0.031 | 0.32 |
| Fatigue | −0.09 | 0.003 |
| Sputum | 0.006 | 0.84 |
| Myalgia | −0.119 | 0.00013 |
| Diarrhea | −0.04 | 0.199 |
| Nausea or vomiting | −0.128 | 0.00004116 |
| Sore throat | −0.037 | 0.233 |
| Runny nose or Nasal congestion | −0.03 | 0.318 |
| Loss of smell | −0.05 | 0.0965 |
| Loss of Taste | −0.066 | 0.0377 |
| Headache | 0.062 | 0.048 |
| Chest discomfort or chest pain | −0.1 | 0.002 |
Default hyperparameters of the TabNet model.
| Training hyperparameters | Default values |
|---|---|
| Max epochs | 200 |
| Batch size | 1,024 |
| Masking function | sparsemax |
| Width of decision prediction layer (nd) | 8 |
| Patience | 15 |
| Momentum | 0.02 |
| nshared | 2 |
| nindependent | 2 |
| gamma | 1.3 |
| nsteps | 3 |
| lambda sparse | 1e−3 |
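These defaults map naturally onto the keyword arguments of the open-source pytorch-tabnet package; that mapping is an assumption here, since the paper does not name its code base. A sketch:

```python
# Default TabNet settings from the table above, expressed as keyword
# arguments for a pytorch-tabnet TabNetClassifier (assumed library).
model_params = {
    "n_d": 8,               # width of the decision prediction layer
    "n_steps": 3,           # number of decision steps
    "gamma": 1.3,           # feature-reuse relaxation
    "n_independent": 2,     # independent GLU layers per step
    "n_shared": 2,          # shared GLU layers per step
    "momentum": 0.02,       # batch-norm momentum
    "lambda_sparse": 1e-3,  # sparsity regularization weight
    "mask_type": "sparsemax",
}
# Training-loop settings go to the .fit(...) call instead.
fit_params = {"max_epochs": 200, "batch_size": 1024, "patience": 15}
```

The model settings would be passed as `TabNetClassifier(**model_params)` and the training settings forwarded to its `fit(...)` call.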
Performance of the most optimized TabNet model with corresponding standard deviations across all the runs.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline+Fast ICA+ADASYN | 79.77 | 80.09 | 82.1 | 84.47 | 77.01 |
| TabNet Best+ Fast ICA+ ADASYN | 84.66 | 85.73 | 84.52 | 92.31 | 81.28 |
Figure 3. Varying hyperparameters with respect to AUC score for predicting ICU admission.
Figure 4. Feature importance masks for predicting ICU admission using TabNet (individual interpretability).
Figure 5. Feature importance for predicting ICU admission using TabNet (global interpretability).
Figure 6. Model loss for the best TabNet model.
Figure 7. Training and validation accuracy for the best TabNet model.
Figure 8. Precision-recall curve of the best TabNet model for predicting ICU admission.
Figure 9. Confusion matrix of the best TabNet model for predicting ICU admission.
Comparison of results between proposed method and existing technique.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| Proposed method | 88.4 | 89.7 | 88.7 | 93.3 | 86.4 |
Figure 10. Varying hyperparameters with respect to AUC score for predicting mortality.
Performance of the best final TabNet model with FastICA dimensionality reduction method and ADASYN oversampling method.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline+Fast ICA+ADASYN | 89.03 | 89.12 | 88.92 | 92.98 | 85.75 |
| TabNet Best+Fast ICA+ADASYN | 91.59 | 91.74 | 91.49 | 96.65 | 87.36 |
Figure 11. Feature importance masks for predicting mortality (individual interpretability).
Figure 12. Feature importance for predicting mortality (global interpretability).
Figure 13. Model loss for the best TabNet model.
Figure 14. Training and validation accuracy for the best TabNet model.
Figure 15. Precision-recall curve of the best TabNet model for predicting mortality.
Figure 16. Confusion matrix of the best TabNet model for predicting mortality.
Comparison of results between proposed method and existing technique.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| Proposed method | 96.3 | 95.8 | 96.0 | 99.8 | 91.8 |
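The AUC values reported here can be read as the probability that a randomly chosen positive patient receives a higher risk score than a randomly chosen negative one (the Mann-Whitney formulation). A sketch with invented scores:

```python
# AUC via the Mann-Whitney formulation: fraction of positive/negative
# pairs where the positive case is scored higher (ties count half).

def auc(scores, labels):
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Toy risk scores and labels (1 = event, 0 = no event):
scores = [0.9, 0.4, 0.5, 0.6, 0.2]
labels = [1, 1, 0, 1, 0]
a = auc(scores, labels)
```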
Varying width of decision prediction layer (nd).
| nd | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline (default = 8) | 82.9 | 84.9 | 83.3 | 88.7 | 81.4 |
| TabNet with nd = 2 | 76.6 | 80.8 | 77.4 | 89.9 | 73.39 |
| TabNet with nd = 4 | 74.47 | 80.2 | 75.6 | 93.3 | 70.3 |
| TabNet with nd = 16 | 83.27 | 84.3 | 83.3 | 84.3 | 84.3 |
| TabNet with nd = 32 | 83.9 | 86.6 | 84.5 | 94.4 | 80 |
| TabNet with nd = 64 | 84.8 | 86.5 | 85.1 | 89.9 | 83.3 |
Varying number of steps in the architecture (nsteps).
| nsteps | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline (default nsteps = 3) | 82.9 | 84.9 | 83.3 | 88.7 | 81.4 |
| TabNet with nsteps = 4 | 76.4 | 79.1 | 76.8 | 83.1 | 75.5 |
| TabNet with nsteps = 6 | 50 | 69.3 | 52.9 | 99.5 | 52.9 |
| TabNet with nsteps = 10 | 80.2 | 83.8 | 81 | 93.3 | 76.1 |
Varying gamma.
| gamma | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline (default gamma = 1.3) | 82.9 | 84.9 | 83.3 | 88.7 | 81.4 |
| TabNet with gamma = 1.5 | 79.5 | 83.4 | 80.4 | 93.3 | 75.5 |
| TabNet with gamma = 1.7 | 77.4 | 80.6 | 77.9 | 86.5 | 75.5 |
| TabNet with gamma = 1.9 | 50 | 69.3 | 52.9 | 100 | 52.9 |
Varying nindependent.
| nindependent | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline (nindependent = 2) | 82.9 | 84.9 | 83.3 | 88.7 | 81.4 |
| TabNet with nindependent = 3 | 50 | 69.3 | 52.9 | 100 | 52.9 |
| TabNet with nindependent = 4 | 82.8 | 83.4 | 82.7 | 82.0 | 84.9 |
| TabNet with nindependent = 5 | 78.5 | 82.1 | 79.2 | 89.9 | 75.5 |
Varying nshared.
| nshared | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline (nshared = 2) | 82.9 | 84.9 | 83.3 | 88.7 | 81.4 |
| TabNet with nshared = 3 | 76.8 | 80.2 | 77.4 | 86.5 | 74.8 |
| TabNet with nshared = 4 | 50 | 0 | 47 | 0 | 0 |
| TabNet with nshared = 5 | 82.8 | 85.3 | 83.3 | 91.0 | 80.2 |
Varying momentum.
| momentum | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline (momentum = 0.02) | 82.9 | 84.9 | 83.3 | 88.7 | 81.3 |
| TabNet with momentum = 0.1 | 83.6 | 85.4 | 83.9 | 88.7 | 82.3 |
| TabNet with momentum = 0.2 | 84.6 | 86.9 | 85.1 | 93.3 | 81.4 |
| TabNet with momentum = 0.3 | 84.6 | 86.8 | 85.1 | 93.3 | 81.4 |
Varying lambda sparse.
| lambda sparse | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline (lambda sparse = 1e−3) | 82.9 | 84.9 | 83.3 | 88.7 | 81.4 |
| TabNet with lambda sparse = 0.01 | 76.6 | 80.8 | 77.4 | 89.9 | 73.4 |
| TabNet with lambda sparse = 0.1 | 82.2 | 84.4 | 82.7 | 91.0 | 79.4 |
Varying mask type.
| mask type | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline (mask type = sparsemax) | 82.9 | 84.9 | 83.3 | 88.7 | 81.4 |
| TabNet with mask type = entmax | 87.1 | 88.9 | 87.5 | 94.4 | 84.0 |
Impact of number of epochs and stopping condition on the performance of the TabNet architecture.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline (epoch = 100) | 78.6 | 81.9 | 79.2 | 88.8 | 76.0 |
| TabNet Baseline (epoch = 50) | 55.2 | 71.0 | 57.7 | 97.8 | 55.8 |
| TabNet Baseline (epoch = 150) | 82.6 | 82.4 | 81.9 | 84.8 | 81.3 |
| TabNet Baseline (epoch = 200) | 82.9 | 81.4 | 82.9 | 88.8 | 81.3 |
| TabNet Baseline epoch = 150, patience = 5 | 50 | 69.3 | 52.3 | 100 | 53.0 |
| TabNet Baseline epoch = 150, patience = 15 | 50 | 69.3 | 52.3 | 100 | 53.0 |
| TabNet Baseline epoch = 150, patience = 30 | 82.6 | 83.4 | 83.9 | 88.7 | 81.3 |
| TabNet Baseline epoch = 150 patience = 60 | 83.6 | 85.4 | 83.9 | 88.8 | 82.3 |
| TabNet Baseline epoch = 150, patience = 90 | 83.6 | 85.4 | 83.9 | 88.8 | 82.3 |
| TabNet Baseline epoch = 150, patience = 120 | 83.6 | 85.4 | 83.9 | 88.8 | 82.3 |
Comparison of the prediction performance of our baseline models with different variations of dimensionality reduction methods.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline + PCA | 81.0 | 83.8 | 81.5 | 89.9 | 78.4 |
| TabNet Baseline + Fast ICA | 83.6 | 85.4 | 83.9 | 88.8 | 82.3 |
| TabNet Baseline + Factor Analysis | 72.8 | 75.9 | 73.2 | 79.8 | 72.4 |
| TabNet Baseline + tSNE | 58.2 | 55.9 | 57.7 | 50.6 | 62.5 |
| TabNet Baseline + UMAP | 54.2 | 42.3 | 53.0 | 32.6 | 60.4 |
Comparison of the prediction performance of our best TabNet models with different variations of dimensionality reduction methods.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Best + PCA | 86.2 | 88.8 | 86.9 | 97.8 | 81.3 |
| TabNet Best + Fast ICA | 86.4 | 88.5 | 86.9 | 95.5 | 82.5 |
| TabNet Best + Factor Analysis | 82.0 | 85.3 | 82.7 | 94.4 | 77.8 |
| TabNet Best + tSNE | 60.9 | 64.5 | 61.3 | 66.3 | 62.8 |
| TabNet Best + UMAP | 58.0 | 61.5 | 58.3 | 62.9 | 60.2 |
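For intuition on the dimensionality-reduction step compared above, here is a pure-Python sketch of the core idea behind the PCA variant: power iteration recovering the top principal component of toy 2-D data. Real pipelines would use a library implementation such as scikit-learn; the data points below are invented.

```python
# Find the top principal component of a small dataset by power
# iteration on its covariance matrix.

def top_component(rows, iters=200):
    n = len(rows)
    dim = len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(dim)]
    centred = [[r[j] - means[j] for j in range(dim)] for r in rows]
    # Sample covariance matrix of the centred data.
    cov = [[sum(a[i] * a[j] for a in centred) / (n - 1)
            for j in range(dim)] for i in range(dim)]
    v = [1.0] * dim
    for _ in range(iters):
        # Multiply by the covariance matrix and renormalize.
        w = [sum(cov[i][j] * v[j] for j in range(dim)) for i in range(dim)]
        norm = sum(x * x for x in w) ** 0.5
        v = [x / norm for x in w]
    return v

# Toy points lying roughly along the line y = x:
data = [[2.0, 2.1], [3.0, 2.9], [4.0, 4.2], [5.0, 4.8]]
pc1 = top_component(data)
```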
Comparison of the performance of the TabNet Baseline model with oversampling methods.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Baseline + SMOTE | 79.7 | 80.7 | 79.6 | 85.5 | 76.3 |
| TabNet Baseline + ADASYN | 83.6 | 85.4 | 83.9 | 88.8 | 82.3 |
Comparison of the performance of the TabNet Best model with oversampling methods.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 72.8 | 55.1 | 72.1 | 76.0 | 43.2 |
| TabNet Best + SMOTE | 82.1 | 83.1 | 82.0 | 89.2 | 77.9 |
| TabNet Best + ADASYN | 87.6 | 89.5 | 88.1 | 95.5 | 84.2 |
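A simplified pure-Python sketch of the ADASYN scheme compared above, tying together the notation G (number of synthetic examples) and the desired balance level beta: synthetic minority points are generated preferentially around minority samples whose neighbourhoods are dominated by the majority class. Two simplifications are assumed here: the interpolation partner is drawn from the whole minority set rather than the k nearest minority neighbours, and the data points are invented; the study presumably used a library implementation such as imbalanced-learn.

```python
# Simplified ADASYN: density-weighted synthetic oversampling of the
# minority class (label 0 = minority, 1 = majority internally).
import random

def adasyn(minority, majority, beta=1.0, k=3, seed=0):
    rng = random.Random(seed)
    G = int((len(majority) - len(minority)) * beta)  # notation: G, beta
    data = [(p, 0) for p in minority] + [(p, 1) for p in majority]

    def knn(x):
        # k nearest neighbours of x by squared Euclidean distance.
        others = [(sum((a - b) ** 2 for a, b in zip(x, p)), lab)
                  for p, lab in data if p != x]
        return sorted(others)[:k]

    # r_i: fraction of majority points among the k nearest neighbours.
    r = [sum(lab for _, lab in knn(x)) / k for x in minority]
    total = sum(r) or 1.0
    g = [round(ri / total * G) for ri in r]  # synthetic count per sample

    synthetic = []
    for xi, gi in zip(minority, g):
        for _ in range(gi):
            xz = rng.choice(minority)  # simplified partner choice
            lam = rng.random()
            synthetic.append(tuple(a + lam * (b - a)
                                   for a, b in zip(xi, xz)))
    return synthetic

minority = [(0.0, 0.0), (1.0, 1.0)]
majority = [(5.0, 5.0), (6.0, 5.0), (5.0, 6.0), (6.0, 6.0), (7.0, 7.0)]
new_points = adasyn(minority, majority)
```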
Varying width of decision prediction layer (nd).
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline (default = 8) | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet with nd = 2 | 81.2 | 78.9 | 81.7 | 75.9 | 82.2 |
| TabNet with nd = 4 | 83.3 | 81.8 | 83.4 | 82.3 | 81.3 |
| TabNet with nd = 16 | 90.3 | 89.4 | 89.7 | 96.2 | 83.5 |
| TabNet with nd = 32 | 83.9 | 86.6 | 84.5 | 94.4 | 80 |
| TabNet with nd = 64 | 84.8 | 86.5 | 85.1 | 89.9 | 83.3 |
Varying the number of steps in the architecture (nsteps).
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline (default nsteps = 3) | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet with nsteps = 4 | 85.9 | 84.7 | 85.7 | 87.3 | 82.1 |
| TabNet with nsteps = 6 | 87.8 | 86.6 | 88.0 | 86.1 | 87.2 |
| TabNet with nsteps = 10 | 73.0 | 70.8 | 73.1 | 72.2 | 69.5 |
Varying gamma.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline (default gamma = 1.3) | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet with gamma = 1.5 | 87.7 | 86.5 | 88.0 | 84.8 | 88.2 |
| TabNet with gamma = 1.7 | 92.6 | 91.8 | 92.0 | 98.7 | 85.7 |
| TabNet with gamma = 1.9 | 90.0 | 89.2 | 89.7 | 93.7 | 85.1 |
Varying the number of independent gates (nindependent).
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline (nindependent = 2) | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet with nindependent = 3 | 86.5 | 85.7 | 85.7 | 94.9 | 78.2 |
| TabNet with nindependent = 4 | 82.8 | 83.4 | 82.7 | 82.0 | 84.9 |
| TabNet with nindependent = 5 | 85.6 | 84.3 | 85.7 | 84.8 | 83.8 |
Varying the number of shared gates (nshared).
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline (nshared = 2) | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet with nshared = 3 | 88.9 | 87.7 | 89.1 | 86.1 | 89.5 |
| TabNet with nshared = 4 | 87.5 | 86.4 | 87.4 | 88.6 | 84.3 |
| TabNet with nshared = 5 | 90.2 | 89.3 | 90.3 | 89.9 | 88.8 |
Varying momentum.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline (momentum = 0.02) | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet with momentum = 0.1 | 89.9 | 89.0 | 89.1 | 97.5 | 81.9 |
| TabNet with momentum = 0.2 | 87.1 | 86.2 | 86.3 | 94.5 | 78.9 |
| TabNet with momentum = 0.3 | 88.8 | 86.0 | 88.0 | 97.5 | 80.2 |
Varying lambda sparse.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline (lambda sparse = 1e−3) | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet with lambda sparse = 1e−2 | 88.5 | 87.6 | 88.0 | 93.7 | 82.2 |
| TabNet with lambda sparse = 1e−1 | 88.6 | 87.5 | 88.6 | 88.6 | 86.4 |
Varying mask type.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline (mask type = sparsemax) | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet with mask type = entmax | 87.0 | 86.3 | 85.7 | 99.8 | 76.0 |
Impact of number of epochs and stopping condition on the performance of the TabNet architecture.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline (epoch = 100) | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet Baseline (epoch = 50) | 85.9 | 85.1 | 85.1 | 93.7 | 77.9 |
| TabNet Baseline (epoch = 150) | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet Baseline (epoch = 200) | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet Baseline maximum epoch = 150 with early stopping, patience = 5 | 85.6 | 84.5 | 85.1 | 89.9 | 78.0 |
| TabNet Baseline maximum epoch = 150 with early stopping, patience = 15 | 85.5 | 84.7 | 84.6 | 94.5 | 76.5 |
| TabNet Baseline epoch = 150 with early stopping, patience = 30 | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet Baseline epoch = 150 with early stopping, patience = 60 | 91.2 | 90.2 | 90.1 | 97.6 | 84.1 |
| TabNet Baseline epoch = 150 with early stopping, patience = 90 | 91.2 | 90.2 | 90.1 | 97.6 | 84.1 |
| TabNet Baseline epoch = 150 with early stopping, patience = 120 | 91.2 | 90.2 | 90.1 | 97.6 | 84.1 |
Comparison of the prediction performance of our baseline models with different variations of dimensionality reduction methods.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline + PCA | 94.0 | 93.3 | 93.7 | 97.5 | 89.5 |
| TabNet Baseline + Fast ICA | 93.8 | 92.9 | 93.1 | 99.7 | 86.8 |
| TabNet Baseline + Factor Analysis | 90.4 | 89.5 | 89.7 | 97.5 | 82.3 |
| TabNet Baseline + tSNE | 72.7 | 73.2 | 71.4 | 86.1 | 63.6 |
| TabNet Baseline + UMAP | 69.2 | 68.6 | 68.6 | 75.9 | 62.5 |
Comparison of the prediction performance of our best TabNet models with different variations of dimensionality reduction methods.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Best + PCA | 94.2 | 93.4 | 93.7 | 98.7 | 88.6 |
| TabNet Best + Fast ICA | 95.3 | 94.6 | 94.9 | 99.8 | 89.8 |
| TabNet Best + Factor Analysis | 93.1 | 92.3 | 92.3 | 98.7 | 86.7 |
| TabNet Best + tSNE | 73.7 | 73.6 | 72.6 | 84.8 | 65.0 |
| TabNet Best + UMAP | 66.7 | 69.0 | 64.6 | 88.6 | 56.9 |
Comparison of the performance of the TabNet Baseline model with oversampling methods.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Baseline + SMOTE | 93.2 | 92.3 | 92.1 | 97.2 | 88.7 |
| TabNet Baseline + ADASYN | 93.8 | 92.9 | 93.1 | 99.3 | 86.8 |
Comparison of the performance of the TabNet Best model with oversampling methods.
| Model | AUC | F1Score | Accuracy | Recall | Precision |
|---|---|---|---|---|---|
| Li, Xiaoran et al. (baseline) | 84.4 | 61.6 | 85.3 | 70.6 | 52.2 |
| TabNet Best + SMOTE | 94.3 | 94.6 | 94.3 | 98.8 | 90.6 |
| TabNet Best + ADASYN | 96.3 | 95.8 | 96.0 | 99.8 | 91.8 |