Fengyi Tang, Ikechukwu Uchendu, Fei Wang, Hiroko H Dodge, Jiayu Zhou.
Abstract
The search for early biomarkers of mild cognitive impairment (MCI) has been central to the Alzheimer's Disease (AD) and dementia research community in recent years. To identify MCI status at the earliest possible point, recent studies have shown that linguistic markers such as word choice, utterance, and sentence structure can potentially serve as preclinical behavioral markers. Here we present an adaptive dialogue algorithm (an AI-enabled dialogue agent) that identifies sequences of questions (a dialogue policy) distinguishing MCI from normal (NL) cognitive status. Our AI agent adapts its questioning strategy based on the user's previous responses to reach an individualized conversational strategy for each user. Because the AI agent is adaptive and scales favorably with additional data, our method provides a potential avenue for large-scale preclinical screening of neurocognitive decline as a new digital biomarker, as well as for longitudinal tracking of aging patterns in the outpatient setting.
Year: 2020 PMID: 32235884 PMCID: PMC7109153 DOI: 10.1038/s41598-020-61994-0
Source DB: PubMed Journal: Sci Rep ISSN: 2045-2322 Impact factor: 4.379
Figure 1. Feedback loop of the reinforcement learning environment for training the MCI diagnosis agent. The user simulator, trained from the original dialogue corpus, generates simulated user responses to new questions from the MCI diagnosis agent (i.e., the reinforcement learning agent). At each conversational turn, the "user state" of the simulated patient is updated based on the questions asked by the MCI diagnosis agent. We designed a Dialogue Manager that produces a reward signal to the MCI diagnosis agent based on the quality of the questions asked.
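To make the loop concrete, here is a minimal, self-contained sketch in the spirit of Figure 1, assuming a toy tabular Q-learning agent, a hypothetical four-question action set, and a dummy simulator that plays the roles of both the user simulator and the Dialogue Manager's reward signal; none of these stand-ins are the paper's actual implementation.

```python
import random
from collections import defaultdict

# Toy stand-ins for the Figure 1 components: a hypothetical question set,
# a dummy user simulator, and a reward that scores question quality.
QUESTIONS = ["open_ended", "yes_no", "follow_up", "topic_shift"]

def simulate_turn(state, question):
    """Dummy simulator + Dialogue Manager: informative questions earn more reward."""
    informativeness = {"open_ended": 1.0, "follow_up": 0.6,
                       "topic_shift": 0.3, "yes_no": 0.1}
    reward = informativeness[question] + random.gauss(0, 0.1)
    next_state = (state + hash(question)) % 50  # coarse "user state" update
    return next_state, reward

Q = defaultdict(float)               # tabular Q-values over (state, question)
alpha, gamma, eps = 0.1, 0.95, 0.2   # learning rate, discount, exploration

for episode in range(500):
    state = 0
    for turn in range(35):           # T = 35 turn budget, as in the tables below
        if random.random() < eps:    # epsilon-greedy question selection
            question = random.choice(QUESTIONS)
        else:
            question = max(QUESTIONS, key=lambda a: Q[(state, a)])
        next_state, reward = simulate_turn(state, question)
        best_next = max(Q[(next_state, a)] for a in QUESTIONS)
        # Standard Q-learning update driven by the Dialogue Manager's reward.
        Q[(state, question)] += alpha * (reward + gamma * best_next
                                         - Q[(state, question)])
        state = next_state
```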
Figure 2. Overview of the proposed algorithm for conversation generation and linguistic marker identification using an RL pipeline. The supervised learning pipeline denotes the classical approach of Asgari et al. [9]. Our approach is summarized in the RL pipeline and involves a feedback loop in which the MCI diagnosis agent generates questions for new users in order to predict their MCI status with a trained ML classifier.
Classification of MCI based on complete transcripts vs. simulated conversations.
| Model | AUC | F1-Score | Sensitivity | Specificity |
|---|---|---|---|---|
| SVM (LIWC) | 0.712 (0.612–0.811) | 0.631 (0.500–0.761) | 0.680 (0.476–0.886) | 0.744 (0.563–0.922) |
| Supervised DL (LIWC) | 0.689 (0.560–0.818) | 0.182 (0.055–0.370) | 0.300 (0.010–0.758) | 0.767 (0.364–0.970) |
| SVM (SKP) | 0.797 (0.719–0.879) | 0.719 (0.591–0.846) | 0.654 (0.473–0.835) | 0.939 (0.855–1.0) |
| Supervised DL (SKP) | 0.811 (0.715–0.907) | 0.642 (0.469–0.813) | 0.600 (0.366–0.833) | 0.911 (0.838–0.984) |
| RL (T = 5) | 0.633 (0.535–0.703) | 0.486 (0.288–0.680) | 0.459 (0.280–0.630) | 0.811 (0.661–0.936) |
| RL (T = 10) | 0.741 (0.631–0.852) | 0.590 (0.352–0.829) | 0.560 (0.309–0.811) | 0.922 (0.823–0.969) |
| RL (T = 15) | 0.721 (0.618–0.827) | 0.595 (0.399–0.790) | 0.500 (0.327–0.713) | 0.922 (0.856–0.987) |
| RL (T = 20) | 0.809 (0.706–0.914) | 0.726 (0.551–0.901) | 0.620 (0.413–0.827) | 0.988 (0.953–1.0) |
| RL (T = 30) | 0.853 (0.796–0.914) | 0.801 (0.733–0.880) | 0.818 (0.678–0.958) | 0.898 (0.828–0.969) |
| RL (T = 35) | 0.859 (0.787–0.952) | 0.808 (0.735–0.883) | 0.818 (0.677–0.958) | 0.911 (0.839–1.0) |
| Difference | 0.0616 (−0.049–0.172) | 0.089 (−0.078–0.259) | 0.163 (−0.083–0.410) | −0.040 (−0.130–0.050) |
Abbreviations: Parentheses denote the confidence interval (CI) for each metric. SVM denotes a support vector machine classifier, and Supervised DL denotes a 2-layer feed-forward neural network classifier. RL denotes the reinforcement learning agent. For the feature representation of the corpus, LIWC is the original word-level embedding used in Asgari et al. [8], while SKP denotes a 4800-dimensional Skip-Thought vector embedding used to represent each conversational turn; a dialogue summary is obtained by averaging across all turn-based responses for each user. We evaluate the performance of our RL agent across 10 stratified shuffle splits, each using 65% of the data for training and 35% for testing.
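The evaluation protocol in the note above can be sketched as follows; random vectors stand in for the 4800-dimensional Skip-Thought turn embeddings and for the MCI/NL labels, while the split-and-score loop mirrors the 10 stratified shuffle splits with a 65/35 train/test partition.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.svm import SVC
from sklearn.metrics import roc_auc_score, f1_score

rng = np.random.default_rng(0)
n_users, n_turns, dim = 80, 35, 4800          # illustrative sizes only
turn_embeddings = rng.normal(size=(n_users, n_turns, dim))
X = turn_embeddings.mean(axis=1)              # dialogue summary per user
y = rng.integers(0, 2, size=n_users)          # stand-in MCI (1) vs NL (0) labels

# 10 stratified shuffle splits, 65% train / 35% test, as in the note.
splitter = StratifiedShuffleSplit(n_splits=10, test_size=0.35, random_state=0)
aucs, f1s = [], []
for train_idx, test_idx in splitter.split(X, y):
    clf = SVC(probability=True).fit(X[train_idx], y[train_idx])
    prob = clf.predict_proba(X[test_idx])[:, 1]
    aucs.append(roc_auc_score(y[test_idx], prob))
    f1s.append(f1_score(y[test_idx], prob > 0.5))
print(f"AUC {np.mean(aucs):.3f} ± {np.std(aucs):.3f}, "
      f"F1 {np.mean(f1s):.3f} ± {np.std(f1s):.3f}")
```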
Figure 3. Conversational efficiency of AI agents. The x-axis represents the number of dialogue turns elapsed; the y-axis represents the performance metrics. Baseline refers to the performance of the MCI classifier using all responses from the original dataset. By contrast, RL refers to the performance of the MCI classifier using responses generated by the user simulator in reply to the agent-generated questions at test time.
MCI prediction from transcripts and simulated conversations under turn restrictions.
| Model (Turns) | AUC | F1-Score | Sensitivity | Specificity |
|---|---|---|---|---|
| SVM (T = 5) | 0.493 (0.439–0.547) | 0.169 (0.061–0.275) | 0.120 (0.046–0.193) | 0.860 (0.776–0.950) |
| SVM (T = 10) | 0.550 (0.479–0.620) | 0.275 (0.113–0.428) | 0.200 (0.083–0.319) | 0.900 (0.820–0.970) |
| SVM (T = 20) | 0.624 (0.563–0.685) | 0.405 (0.232–0.578) | 0.360 (0.171–0.548) | 0.888 (0.789–0.989) |
| SVM (T = 30) | 0.633 (0.557–0.707) | 0.424 (0.247–0.601) | 0.320 (0.187–0.458) | 0.944 (0.882–1.0) |
| SVM (T = 35) | 0.714 (0.627–0.801) | 0.576 (0.420–0.732) | 0.440 (0.277–0.602) | 0.968 (0.944–1.0) |
| Supervised DL (T = 5) | 0.497 (0.392–0.603) | 0.104 (0.015–0.182) | 0.111 (0.091–0.129) | 0.880 (0.812–0.980) |
| Supervised DL (T = 10) | 0.527 (0.459–0.594) | 0.278 (0.123–0.433) | 0.200 (0.088–0.316) | 0.933 (0.856–1.0) |
| Supervised DL (T = 20) | 0.673 (0.588–0.758) | 0.399 (0.212–0.583) | 0.320 (0.139–0.500) | 0.945 (0.888–1.0) |
| Supervised DL (T = 30) | 0.720 (0.643–0.796) | 0.477 (0.317–0.638) | 0.360 (0.228–0.491) | 0.955 (0.914–0.996) |
| Supervised DL (T = 35) | 0.780 (0.695–0.864) | 0.490 (0.327–0.654) | 0.366 (0.229–0.500) | 0.966 (0.928–1.0) |
| RL (T = 5) | 0.633 (0.535–0.703) | 0.486 (0.288–0.680) | 0.459 (0.280–0.630) | 0.811 (0.661–0.936) |
| RL (T = 35) | 0.859 (0.787–0.952) | 0.808 (0.735–0.883) | 0.818 (0.677–0.958) | 0.911 (0.839–1.0) |
Abbreviations: SVM denotes a support vector machine classifier, Supervised DL denotes a 2-layer feed-forward neural network, and RL denotes the reinforcement learning agent. In all three cases, conversations were cut off at various turn lengths (T), and the classifier was evaluated at each cutoff to obtain AUC, F1-score, sensitivity, and specificity. Confidence intervals were obtained over 10 randomized shuffle splits for all experiments.
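A minimal sketch of this turn-restriction protocol, reusing the assumed embedding layout from the previous sketch: each conversation is truncated to its first T turns before the dialogue summary is computed, and each truncated summary matrix is then fed through the same split-and-score loop.

```python
import numpy as np

# Assumed layout: users x turns x embedding dim, as in the sketch above.
rng = np.random.default_rng(1)
turn_embeddings = rng.normal(size=(80, 35, 4800))

def summarize(embeddings, T):
    """Dialogue summary using only the first T turns (the cutoff in the table)."""
    return embeddings[:, :T, :].mean(axis=1)

# One summary matrix per cutoff; each feeds the split-and-score loop above.
summaries = {T: summarize(turn_embeddings, T) for T in (5, 10, 20, 30, 35)}
for T, X_T in summaries.items():
    print(f"T = {T}: summary matrix {X_T.shape}")
```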
Comparison of AI and interviewer strategies using off-policy evaluation.
| Policy | Avg. Reward/Turn | WIS Score | DR Estimate/Turn | DR Score |
|---|---|---|---|---|
| RL (T = 35) | 11.68 (2.06–21.35) | 408.29 (72.41–744.17) | 13.10 (12.91–13.35) | 458.64 (452.40–464.87) |
| Expert Policy | 2.62 (−7.28–12.51) | 91.71 (−255.12–438.68) | 10.82 (10.51–11.14) | 379.07 (367.89–390.25) |
| Advantage | 8.68 (7.16–10.13) | 302.67 (250.78–354.58) | — | — |
Weighted importance sampling (WIS) performs off-policy evaluation of a given policy using trajectories sampled from the original dialogue corpus [25]; DR denotes the doubly robust estimator. For the expert policy, no importance weights are needed, and cumulative rewards are computed over entire conversational episodes. For the AI agent, a cutoff of 35 turns is again used to bound the length of off-policy trajectories. Average reward per turn assesses the average expected reward for the agent under the reward function used to train the RL agent.
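For reference, a self-normalized weighted importance sampling estimator can be sketched as below; the per-turn (target probability, behavior probability, reward) trajectory format is an illustrative assumption, not the paper's data schema.

```python
import numpy as np

def wis_estimate(trajectories, horizon=35):
    """Self-normalized WIS estimate of the target policy's episodic return.

    Each trajectory is a list of (pi_prob, mu_prob, reward) triples: the
    probability of the logged question under the target (RL) policy, its
    probability under the behavior (interviewer) policy, and the reward.
    """
    weights, returns = [], []
    for traj in trajectories:
        w, G = 1.0, 0.0
        for pi_prob, mu_prob, reward in traj[:horizon]:  # cut off at T = 35
            w *= pi_prob / mu_prob        # cumulative importance ratio
            G += reward                   # undiscounted episodic return
        weights.append(w)
        returns.append(G)
    weights, returns = np.asarray(weights), np.asarray(returns)
    # Normalizing by the weight sum gives lower variance than ordinary IS.
    return float(np.sum(weights * returns) / np.sum(weights))

# Toy usage: three trajectories of (target prob, behavior prob, reward) triples.
demo = [[(0.5, 0.25, 1.0), (0.8, 0.5, 2.0)],
        [(0.2, 0.5, 0.5)],
        [(0.6, 0.4, 1.5), (0.3, 0.3, 1.0)]]
print(wis_estimate(demo))
```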