Literature DB >> 36059892

Subpopulation-specific machine learning prognosis for underrepresented patients with double prioritized bias correction.

Sharmin Afrose1, Wenjia Song1, Charles B Nemeroff2, Chang Lu3, Danfeng Daphne Yao1.   

Abstract

Background: Many clinical datasets are intrinsically imbalanced, dominated by overwhelming majority groups. Off-the-shelf machine learning models that optimize the prognosis of majority patient types (e.g., healthy class) may cause substantial errors on the minority prediction class (e.g., disease class) and demographic subgroups (e.g., Black or young patients). In the typical one-machine-learning-model-fits-all paradigm, racial and age disparities are likely to exist, but unreported. In addition, some widely used whole-population metrics give misleading results.
Methods: We design a double prioritized (DP) bias correction technique to mitigate representational biases in machine learning-based prognosis. Our method trains customized machine learning models for specific ethnicity or age groups, a substantial departure from the one-model-predicts-all convention. We compare with other sampling and reweighting techniques in mortality and cancer survivability prediction tasks.
Results: We first provide empirical evidence showing various prediction deficiencies in a typical machine learning setting without bias correction. For example, missed death cases are 3.14 times higher than missed survival cases for mortality prediction. Then, we show DP consistently boosts the minority class recall for underrepresented groups, by up to 38.0%. DP also reduces relative disparities across race and age groups, e.g., up to 88.0% better than the 8 existing sampling solutions in terms of the relative disparity of minority class recall. Cross-race and cross-age-group evaluation also suggests the need for subpopulation-specific machine learning models.
Conclusions: Biases exist in the widely accepted one-machine-learning-model-fits-all-population approach. We invent a bias correction method that produces specialized machine learning prognostication models for underrepresented racial and age groups. This technique may reduce potentially life-threatening prediction mistakes for minority populations.
© The Author(s) 2022.

Entities:  

Keywords:  Cancer; Prognosis

Year:  2022        PMID: 36059892      PMCID: PMC9436942          DOI: 10.1038/s43856-022-00165-w

Source DB:  PubMed          Journal:  Commun Med (Lond)        ISSN: 2730-664X


Introduction

Researchers have trained machine learning models to predict many diseases and conditions, including Alzheimer’s disease[1], heart disease[2], risk of developing diabetic retinopathy[3], cancer risk[4] and survivability[5], genetic testing for diseases[6], hypertrophic cardiomyopathy diagnosis[7], psychosis[8], posttraumatic stress disorder (PTSD)[9], and COVID-19[10]. Neural network-powered automatic image analysis has also been shown useful for fast disease detection, e.g., breast cancer[11] and lung cancer[12]. A study showed that deep learning algorithms diagnose breast cancer more accurately (AUC = 0.994) than 11 pathologists[11]. Hospitals (e.g., Cleveland Clinic partnering with Microsoft[13], Johns Hopkins Hospital partnering with GE Healthcare[14]) are reported to use predictive analytics for monitoring patients’ health status and preventing emergencies[15-18].
However, clinical datasets are intrinsically imbalanced due to the naturally occurring frequencies of data[19]. The data is not evenly distributed across prediction classes (e.g., disease class vs. healthy class), race, age, or other subgroups. Data imbalance is a major cause of biased prediction results[19], and biased prediction results may have serious consequences for some patients. For example, a recent study showed that automatic enrollment of high-risk patients into a health program favors white patients, although Black patients had 26.3% more chronic health conditions than equally ranked white patients[20]. Similarly, algorithmic osteoarthritis pain prediction shows 43% racial disparities[21]. The design of widely used case-control studies has been shown to have a temporal bias that reduces predictive accuracy[22]. For non-medical applications, researchers have also identified serious biases in high-profile machine learning applications, e.g., a widely deployed recidivism prediction tool[23-25], an online advertisement system[26], Amazon’s recruiting engine[27], and face recognition systems[28].
The lack of external validation and overclaimed causal effects in machine learning also raise concerns[29].
A widely used bias-correction approach to the data imbalance problem is sampling. Oversampling, e.g., replicated oversampling (ROS), balances a dataset by adding samples of the minority class; undersampling, e.g., random undersampling (RUS), balances a dataset by removing samples of the majority class[30]. An improvement is the K-nearest neighbor (K-NN) classifier-based undersampling technique[31] (e.g., NearMiss1, NearMiss2, NearMiss3, Distant), which selects samples from the majority class based on their distance from minority class samples. State-of-the-art solutions are all oversampling methods, including the Synthetic Minority Oversampling Technique (SMOTE)[32], Adaptive Synthetic Sampling (ADASYN)[33], and Gamma[34]. All three methods generate new minority points based on existing minority samples, via linear interpolation[32], a gamma distribution[34], or generation at the class border[33]. However, existing sampling techniques are not designed to address subgroup biases, as they sample the entire minority class without differentiating demographic subgroups (e.g., Black patients or young patients under 30). Thus, it is unclear how well existing sampling solutions reduce accuracy disparity.
We present two categories of contributions to machine learning prognosis for underrepresented patients. One contribution is empirical evidence showing severe racial and age prediction disparities and the deceptive nature of some common metrics. The other contribution is evaluating the bias-correction ability of sampling methods, including a new double prioritized (DP) bias correction technique. In our first contribution, we use two large medical datasets (MIMIC III and SEER) to show multiple types of prediction deficiencies, some due to the choice of metrics.
Poor prediction performance on minority samples is not reflected in widely used whole-population metrics. For imbalanced datasets, conventional metrics such as overall accuracy and AUC-ROC are largely influenced by the performance on the majority of samples, which machine learning models aim to fit. Unfortunately, this serious deficiency is not well discussed or reported in the medical literature. For example, a study showed that 66.7% of 33 medical-related machine learning papers used AUC-ROC to evaluate models trained on imbalanced datasets[35].
In our second contribution, we present a new technique, double prioritized (DP) bias correction, that aims to improve the prediction accuracy for specific demographic groups through sample enrichment. DP trains customized prediction models for specific subpopulations, a departure from the existing one-model-predicts-all-demographics paradigm. DP prioritizes specific underrepresented groups, as opposed to sampling across the entire patient population. From our experiments, we report racial, age, and metric disparities in machine learning models trained on a clinical prediction benchmark[17] on MIMIC III and cancer survival prediction[5] on the SEER cancer dataset. Both training datasets are imbalanced in terms of race and age distributions. For example, for the in-hospital mortality (IHM) prediction with MIMIC III, 70.6% of the data represents white patients, whereas only 9.6% represents Black patients. MIMIC III and SEER also have data imbalance between the two class labels (e.g., death vs. survival): for the IHM prediction, only 13.5% of the data belongs to patients who died in the hospital. These data imbalances result in serious prediction biases. A typical neural network-based machine learning model[17] that we tested correctly predicts 87.6% of non-death cases but only 60.9% of death cases.
Meanwhile, overall accuracy (computed over all patients) is relatively high (0.85), and AUC-ROC is 0.86, because of the good performance on the majority class. These high scores are misleading. Our study also reveals that accuracy differs among age and race subgroups. For example, the mortality prediction precision (i.e., the fraction of actual deaths among predicted deaths) for young patients under 30 is 0.09, substantially lower than for the whole population (0.40). Recognizing these accuracy challenges will help advance AI-based technologies to better serve underrepresented patients.
Our results show that DP is effective in boosting the minority class recall for underrepresented groups by up to 38.0%. DP also reduces the disparity among age and race groups. For the in-hospital mortality (IHM) and 5-year breast cancer survivability (BCS) predictions, DP shows a 14.8% to 23.9% improvement over the original model and a 5.6% to 88.0% improvement over eight existing sampling techniques in terms of the relative disparity of minority class recall. Our cross-race and cross-age-group results also suggest the need for training specialized machine learning models for different demographic subgroups. Sampling techniques, including DP, are not designed to address biases caused by underdiagnosis, measurement, or sources of disparity other than data representation. In what follows, DP assumes that noise is the same across all demographic subgroups and that the only source of bias it aims to correct is representational.

Methods

Double prioritized (DP) bias correction method

DP prioritizes a specific demographic subgroup (e.g., Black patients) that suffers from data imbalance by replicating minority prediction class (C1) cases from this group (e.g., Black in-hospital deaths). DP incrementally increases the number of duplicated units and chooses the optimal unit number based on the resulting models’ performance. Figure 1 shows the machine learning workflow with DP bias correction. The main steps are described next.
Fig. 1

Workflow for improving data balance in machine learning prognosis prediction using double prioritized (DP) bias correction.

Sample Enrichment prepares a number of new training datasets by incrementally enriching a specific demographic subgroup; Candidate Training is where each of the n + 1 datasets is used for training a candidate machine learning model; Model Selection identifies the optimal model; Prediction applies the selected model on new patient data. AUC-PR represents the area under the curve of the precision-recall curve.

Sample Enrichment replicates minority class C1 samples in the training dataset for a target demographic group g up to n times. Each time, the duplicated samples are merged with the original training dataset, forming a new training dataset. Thus, we obtain n + 1 training datasets, including the original one. Our experiments set n to 19; the value of n can be determined empirically based on prediction performance.
Candidate Training generates a set of candidate machine learning models: each of the n + 1 datasets is used to train one candidate model. Two types of neural networks are used, the long short-term memory (LSTM) model and the multilayer perceptron (MLP) model. Following Harutyunyan et al.[17], for the hospital record prediction tasks, patients’ data is preprocessed into time-series records and fed into an LSTM model. Cancer survivability prediction uses an MLP model, following Hegselmann et al.[5]. Prediction and data analysis code is written in Python. The hospital record prediction tasks were executed on a virtual machine with Ubuntu 18.04, x86-64 architecture, 8 cores, 40 GB RAM, and 1 GPU; the cancer survivability prediction tasks ran on Ubuntu 21.04, x86-64 architecture, 16 cores, 40 GB RAM, and 1 GPU. Model parameters remain constant across the different bias correction techniques (Supplementary Table 1).
Model Selection identifies the optimal machine learning model among the n + 1 candidates.
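The Sample Enrichment step can be sketched in a few lines of Python. This is an illustrative sketch, not the authors' released code; the dict-based sample representation and the field names "label" and "group" are assumptions.

```python
import random

def enrich(train_data, group, n_units, seed=0):
    """Merge n_units copies of the target group's minority-class (C1)
    samples into the original training set."""
    targets = [s for s in train_data if s["label"] == 1 and s["group"] == group]
    enriched = train_data + targets * n_units
    random.Random(seed).shuffle(enriched)  # avoid runs of identical duplicates
    return enriched

def candidate_datasets(train_data, group, n=19):
    """Build the n + 1 training sets (unit counts 0..n), including the original."""
    return [enrich(train_data, group, k) for k in range(n + 1)]
```

Each of the returned datasets then trains one candidate model, and Model Selection picks among them.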
We choose the final machine learning model, M*, after evaluating all candidate models’ performance as follows. For each model, we first calibrate the predicted probabilities on the validation set. Calibration adjusts the distribution of probabilities before mapping probabilities to labels; we calibrate the output probabilities using the isotonic regression technique. We then perform threshold tuning to find the optimal threshold based on balanced accuracy and the F1_C1 score. Specifically, we first identify the top three thresholds that give the highest F1_C1 scores and then select, among them, the threshold that gives the highest balanced accuracy over all samples. For some subgroups, there are only a couple hundred samples in the validation set, and selecting the threshold based on subgroup data may cause overfitting to the validation set; therefore, we choose thresholds based on whole-group performance. Given a threshold, we then identify the top three machine learning models with the highest balanced accuracy (i.e., the average recall of the C0 and C1 classes, Supplementary Equation 6) and select the model that gives the highest PR_C1 (the area under the minority class C1’s precision-recall curve, denoted by AUC-PR_C1 or PR_C1) for demographic group g. In this step, no enrichment is applied to the validation dataset. When deciding thresholds, AUC-PR cannot be used, as it is a threshold-free metric; thus, we use balanced accuracy and F1_C1.
Prediction applies model M* to new patients’ records of minority group g′ and obtains a binary class label. At deployment, the demographic group g of the samples duplicated during Sample Enrichment and the test group g′ should be the same, e.g., the DP model trained with duplicated Black samples is used to predict new Black patients.
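The threshold-tuning rule (top three thresholds by F1_C1, then the one with the best balanced accuracy) can be sketched as follows. The metric helpers are standard definitions written out for self-containment, and the threshold grid granularity is an assumption.

```python
def _counts(y_true, y_pred):
    """Confusion-matrix counts for binary labels (C1 = 1, C0 = 0)."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    tn = sum(t == 0 and p == 0 for t, p in zip(y_true, y_pred))
    return tp, fp, fn, tn

def f1_c1(y_true, y_pred):
    tp, fp, fn, _ = _counts(y_true, y_pred)
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

def balanced_accuracy(y_true, y_pred):
    """Average recall of the C1 and C0 classes."""
    tp, fp, fn, tn = _counts(y_true, y_pred)
    rec1 = tp / (tp + fn) if (tp + fn) else 0.0
    rec0 = tn / (tn + fp) if (tn + fp) else 0.0
    return (rec1 + rec0) / 2

def select_threshold(y_true, probs, grid=None):
    """Top-3 thresholds by F1_C1, then the one with the best balanced accuracy."""
    grid = grid or [i / 100 for i in range(1, 100)]
    scored = [(f1_c1(y_true, [int(p >= t) for p in probs]),
               balanced_accuracy(y_true, [int(p >= t) for p in probs]), t)
              for t in grid]
    top3 = sorted(scored, reverse=True)[:3]
    return max(top3, key=lambda s: s[1])[2]
```

In the paper's pipeline this tuning runs on calibrated validation-set probabilities, not raw model outputs.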
Evaluation metrics include accuracy, balanced accuracy, Matthews correlation coefficient (MCC), AUC-ROC score, precision, recall, AUC-PR, and F1 score of the minority and majority prediction classes, computed for the whole population and for various demographic subgroups, including gender (male, female), race (white, Black, Hispanic, Asian), and 8 age groups. Minority class C1 precision is the fraction of actual minority C1 class cases among predicted ones; C1 recall is the fraction of C1 cases that are successfully predicted by a machine learning model. We use the relative disparity metric to capture the disparity among race groups or age groups, defined in Equation (1):

Relative disparity = R_max / R_min    (1)

where R_max is the highest and R_min is the lowest evaluation metric value among the subgroups being compared. All other metrics are defined in supplementary equations. Similar to other studies[34,36], our workflow does not sample the test dataset, because the ground truth (i.e., a new patient’s disease or health label) is unknown in the real world. Relative disparity values are greater than or equal to 1. MCC values are in the range [−1, 1]; the other metric values are in the range [0, 1]. When comparing datasets that have different percentages of minority class C1 samples, we avoid metrics (e.g., AUC-PR) whose baselines (i.e., the performance of a random classifier) depend on the C1 percentage[35].
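Under this definition, relative disparity reduces to a one-line function (the subgroup metric values in the example are illustrative, not taken from the paper's results):

```python
def relative_disparity(subgroup_values):
    """Ratio of the highest to the lowest metric value across subgroups.
    Always >= 1; a value of 1 means no disparity."""
    return max(subgroup_values) / min(subgroup_values)

# e.g., minority class recall of 0.60 in one subgroup vs. 0.25 in another
# gives a relative disparity of 2.4
```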

Other bias correction techniques being compared

The eight existing sampling approaches being compared include four undersampling techniques (random undersampling, NearMiss1, NearMiss3, and the distant method) and four oversampling techniques (replicated oversampling, SMOTE, ADASYN, and Gamma). Undersampling balances the distribution of the two prediction classes by selecting only a subset of the majority class cases; oversampling balances the dataset by populating the minority class. We also use MLP models with different structures (i.e., different numbers of layers, neurons per layer, and dropout rates).
Reweighting is an alternative bias correction approach to sampling[37,38]. It assigns different importance to samples in the training data so that some minority class samples have a greater impact on training outcomes. We compare DP with two methods, the standard reweighting method and a new prioritized reweighting method. Standard reweighting aims to balance the weights of the two prediction classes: new weights are applied to the entire class population such that each majority sample is weighted less than 1 and each minority sample is weighted more than 1, under the constraint that the total weight of each prediction class is equal. In our standard reweighting experiment, the minority class has a weight of 3.94 and the majority class has a weight of 0.57 for BCS prediction; the weights are 3.12 and 0.60 for the minority and majority classes, respectively, for LCS prediction.
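The standard class weights quoted above are consistent with the usual balanced-weighting formula w_c = N / (K * N_c), where N is the total sample count, K the number of classes, and N_c the count of class c. A minimal sketch, using the BCS class proportions described in the Methods (87.3% majority, 12.7% minority):

```python
from collections import Counter

def balanced_class_weights(labels):
    """w_c = N / (K * N_c): each class's total weight becomes N / K,
    so the two prediction classes carry equal total weight."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {c: n / (k * nc) for c, nc in counts.items()}

# BCS-like proportions: 873 majority (class 0) and 127 minority (class 1) samples
weights = balanced_class_weights([0] * 873 + [1] * 127)
# weights[1] ≈ 3.94 and weights[0] ≈ 0.57, matching the values reported above
```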

Prioritized reweighting

Following our DP design, we also invent a new prioritized reweighting approach. Prioritized reweighting selectively reweights specific subgroup minority samples, as opposed to reweighting all minority class C1 samples as in the standard reweighting. In the new prioritized reweighting method, we dynamically reweight minority class samples of selected demographic subgroups and choose the optimal machine learning model using the same metrics and procedure as in DP. Specifically, in each round of prioritized reweighting experiments, we multiply the selected samples’ default weight by a unit number n, where n ranges from 1 to 20. The weights of samples in other subgroups and majority class samples in the selected subgroup remain the default value, i.e., 1. These weights are used to train a machine learning model. Once the n machine learning models are trained, we follow DP’s Model Selection operation for calibration and threshold selection.
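Prioritized reweighting touches only the targeted subgroup's minority-class samples; a sketch, with the same assumed dict-based record layout as before:

```python
def prioritized_weights(samples, group, unit):
    """Per-sample training weights: minority-class (C1) samples of the
    target subgroup are scaled by `unit`; all other samples keep weight 1."""
    return [unit if s["label"] == 1 and s["group"] == group else 1.0
            for s in samples]
```

One model is trained per unit value (1 through 20), and the resulting models go through the same calibration and threshold selection as in DP's Model Selection step.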

Cross-racial-group and cross-age-group experiments

We also perform a series of cross-group experiments, where enriched samples and test samples are from different demographic groups, i.e., group g used for Sample Enrichment and test group g′ are different. The purpose is to assess the impact of different machine learning models on prediction outcomes.

Whole-group vs. subgroup-based threshold tuning

When analyzing the performance of the original model without bias correction, we evaluate two settings. The first selects an optimal threshold based on all samples in the validation set; we refer to this as the whole-group threshold. The second selects an optimal threshold for each demographic subgroup based on that subgroup’s performance in the validation set; we refer to these as the subgroup thresholds. In both settings, we calibrate the prediction on all samples (i.e., the whole group), select the thresholds with the top 3 highest F1_C1 scores, and choose the one with the best balanced accuracy.

SHAP-sum and SHAP-avg feature importance

We calculate the feature importance for all four tasks (i.e., IHM, Decompensation, BCS, and LCS) using Shapley Additive exPlanations (SHAP). One-hot encoded categorical variables are each represented by multiple columns in the input data, and SHAP is not designed for such one-hot encoded features: the standard SHAP method calculates an importance for each column. Thus, we post-process the column importances using two approaches, SHAP-avg and SHAP-sum. In the SHAP-avg approach, the importances of columns representing the same feature are averaged; in the SHAP-sum approach, they are added up.
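The two post-processing rules can be sketched as follows; the column-to-feature mapping is an assumed input (e.g., derived from one-hot column names such as "race_white" → "race"), and the importance values shown are illustrative.

```python
from collections import defaultdict

def aggregate_shap(column_importance, column_to_feature, mode="avg"):
    """Combine per-column SHAP importances of one-hot encoded columns:
    'sum' adds them up (SHAP-sum); 'avg' averages them (SHAP-avg)."""
    groups = defaultdict(list)
    for col, imp in column_importance.items():
        groups[column_to_feature[col]].append(imp)
    agg = sum if mode == "sum" else (lambda v: sum(v) / len(v))
    return {feat: agg(vals) for feat, vals in groups.items()}

# e.g., two one-hot columns of "race" with importances 0.2 and 0.4
# yield 0.6 under SHAP-sum and 0.3 under SHAP-avg
```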

Clinical datasets

We use the MIMIC III[17,39] and SEER[40] cancer datasets, both collected in the US. We test existing machine learning models in a clinical prediction benchmark[17] for MIMIC III and cancer survival prediction[5] for SEER. We study a total of four binary classification tasks: in-hospital mortality (IHM) prediction and decompensation prediction from the clinical prediction benchmark[17], and 5-year breast cancer survivability (BCS) and 5-year lung cancer survivability (LCS) predictions. In what follows, we denote the minority prediction class as Class 1 (or C1) and the majority class as Class 0 (or C0).
Figure 2a–d shows the composition of the IHM training data, which contains 14,681 time-series samples from MIMIC III. The majority of the records (86.5%) belong to Class 0 (i.e., patients who do not die in the hospital); the rest (13.5%) belong to Class 1 (i.e., patients who die in the hospital). The percentage of Class 1 samples within each subgroup varies slightly but is consistently low. 70.6% of the patients are white, and 76% belong to the age range [50, 90). In our [X, Y) age range notation, the square opening bracket means the beginning value X is included; the round closing bracket means the ending value Y is excluded. 45.1% of the patients are female and 54.9% are male. The training set contains insufficient data for the young adult population. Distributions of the decompensation training dataset (of size 2,377,768) are similar (Supplementary Fig. 1a–d).
Figure 2e–h shows the subgroup percentages for the training dataset used in BCS prediction. The BCS training set contains 199,000 samples, of which 87.3% are in Class 0 (i.e., patients who were diagnosed with breast cancer and survived more than 5 years). Only 0.6% of the samples are male. The percentage of Class 1 samples is low in most groups, with the exception of the age 90+ subgroup, which has a high mortality rate. The majority race group (81%) is white. When categorized by age, 70% of the patients are between 40 and 70. The LCS training dataset (of size 164,443) follows similarly imbalanced distributions (Supplementary Fig. 1e–h).
Fig. 2

Recall values for both classes C0 and C1 and training data statistics for the in-hospital mortality (IHM) and the 5-year breast cancer survivability (BCS) tasks.

a Percentage of the minority class C1, Recall C0, and Recall C1 of each subgroup of the MIMIC dataset for the IHM task. Statistics of b prediction class distribution, c racial group distribution, and d age group distribution for the MIMIC IHM dataset. The MIMIC IHM training set consists of 45.1% female samples and 54.8% male samples. e Percentage of the minority class C1, Recall C0, and Recall C1 of each subgroup of the SEER dataset for the BCS task. Statistics of f prediction class distribution, g racial group distribution, and h age group distribution for the SEER BCS dataset. The SEER BCS training set consists of 99.4% female samples and 0.6% male samples.

To compute standard deviations, we repeat the machine learning training process multiple times, each time producing a machine learning model. Specifically, for the BCS and LCS prediction tasks, we repeat the experiments five times; for the in-hospital mortality task, three times. Under these settings, average values and standard deviations are computed for all results except SHAP. Tables show only average results, without error bars. All SHAP feature importance results (in the Supplementary Section) are based on a randomly selected machine learning model. For the decompensation prediction task, due to its high time complexity, we run the experiments once. We use publicly available clinical datasets; this does not meet the criteria for human subjects research, so ethical approval is not required for this study.
References:  21 in total

1.  Analysis of sampling techniques for imbalanced data: An n = 648 ADNI study.

Authors:  Rashmi Dubey; Jiayu Zhou; Yalin Wang; Paul M Thompson; Jieping Ye
Journal:  Neuroimage       Date:  2013-10-29       Impact factor: 6.556

2.  Quantitative forecasting of PTSD from early trauma responses: a Machine Learning application.

Authors:  Isaac R Galatzer-Levy; Karen-Inge Karstoft; Alexander Statnikov; Arieh Y Shalev
Journal:  J Psychiatr Res       Date:  2014-09-16       Impact factor: 4.791

3.  Dynamic ElecTronic hEalth reCord deTection (DETECT) of individuals at risk of a first episode of psychosis: a case-control development and validation study.

Authors:  Lars Lau Raket; Jörn Jaskolowski; Bruce J Kinon; Jens Christian Brasen; Linus Jönsson; Allan Wehnert; Paolo Fusar-Poli
Journal:  Lancet Digit Health       Date:  2020-03-26

4.  Early hospital mortality prediction of intensive care unit patients using an ensemble learning approach.

Authors:  Aya Awad; Mohamed Bader-El-Den; James McNicholas; Jim Briggs
Journal:  Int J Med Inform       Date:  2017-10-05       Impact factor: 4.046

5.  Predicting the risk of developing diabetic retinopathy using deep learning.

Authors:  Ashish Bora; Siva Balasubramanian; Boris Babenko; Sunny Virmani; Subhashini Venugopalan; Akinori Mitani; Guilherme de Oliveira Marinho; Jorge Cuadros; Paisan Ruamviboonsuk; Greg S Corrado; Lily Peng; Dale R Webster; Avinash V Varadarajan; Naama Hammel; Yun Liu; Pinal Bavishi
Journal:  Lancet Digit Health       Date:  2020-11-26

6.  The precision-recall plot is more informative than the ROC plot when evaluating binary classifiers on imbalanced datasets.

Authors:  Takaya Saito; Marc Rehmsmeier
Journal:  PLoS One       Date:  2015-03-04       Impact factor: 3.240

7.  Risk prediction models for selection of lung cancer screening candidates: A retrospective validation study.

Authors:  Kevin Ten Haaf; Jihyoun Jeon; Martin C Tammemägi; Summer S Han; Chung Yin Kong; Sylvia K Plevritis; Eric J Feuer; Harry J de Koning; Ewout W Steyerberg; Rafael Meza
Journal:  PLoS Med       Date:  2017-04-04       Impact factor: 11.069

8.  Differences in youngest-old, middle-old, and oldest-old patients who visit the emergency department.

Authors:  Sang Bum Lee; Jae Hun Oh; Jeong Ho Park; Seung Pill Choi; Jung Hee Wee
Journal:  Clin Exp Emerg Med       Date:  2018-12-31

9.  MIMIC-III, a freely accessible critical care database.

Authors:  Alistair E W Johnson; Tom J Pollard; Lu Shen; Li-Wei H Lehman; Mengling Feng; Mohammad Ghassemi; Benjamin Moody; Peter Szolovits; Leo Anthony Celi; Roger G Mark
Journal:  Sci Data       Date:  2016-05-24       Impact factor: 6.444

10.  The accuracy, fairness, and limits of predicting recidivism.

Authors:  Julia Dressel; Hany Farid
Journal:  Sci Adv       Date:  2018-01-17       Impact factor: 14.136
