Chuhan Wu, Fangzhao Wu, Lingjuan Lyu, Yongfeng Huang, Xing Xie.
Abstract
Federated learning is a privacy-preserving machine learning technique for training intelligent models on decentralized data. It exploits private data by communicating local model updates, rather than the raw data, in each iteration of model learning. However, model updates can be extremely large when the model contains numerous parameters, and many rounds of communication are needed for model training. The resulting communication cost places heavy overheads on clients and a high burden on the environment. Here, we present a federated learning method named FedKD that is both communication-efficient and effective, based on adaptive mutual knowledge distillation and dynamic gradient compression techniques. FedKD is validated on three different scenarios that need privacy protection, showing that it can reduce communication cost by up to 94.89% while achieving results competitive with centralized model learning. FedKD offers a way to efficiently deploy privacy-preserving intelligent systems in many scenarios, such as intelligent healthcare and personalization.
Year: 2022 PMID: 35440643 PMCID: PMC9018897 DOI: 10.1038/s41467-022-29763-x
Source DB: PubMed Journal: Nat Commun ISSN: 2041-1723 Impact factor: 17.694
Performance (with standard deviations) and communication cost per client of different methods on MIND and ADR.
| Methods | MIND AUC | MIND MRR | MIND nDCG@5 | MIND nDCG@10 | MIND comm. cost per client | ADR Precision | ADR Recall | ADR Fscore | ADR comm. cost per client |
|---|---|---|---|---|---|---|---|---|---|
| UniLM (Local) | 68.8 ± 0.5 | 33.5 ± 0.4 | 36.6 ± 0.5 | 42.4 ± 0.6 | – | 53.2 ± 1.3 | 54.6 ± 1.4 | 53.9 ± 1.1 | – |
| UniLM (Cen) | 71.0 ± 0.1 | 35.8 ± 0.1 | 39.0 ± 0.1 | 44.8 ± 0.1 | – | 60.3 ± 0.7 | 61.6 ± 0.8 | 60.8 ± 0.4 | – |
| UniLM (Fed) | 70.9 ± 0.3 | 35.7 ± 0.2 | 38.9 ± 0.3 | 44.7 ± 0.4 | 2.05GB (1.0×) | 59.1 ± 0.6 | 62.3 ± 0.6 | 60.6 ± 0.4 | 1.37GB (1.0×) |
| DistilBERT6 | 69.3 ± 0.2 | 34.0 ± 0.2 | 37.5 ± 0.2 | 43.0 ± 0.1 | 1.03GB (2.0×) | 56.8 ± 0.8 | 59.2 ± 0.8 | 57.9 ± 0.5 | 0.69GB (2.0×) |
| DistilBERT4 | 69.0 ± 0.2 | 33.7 ± 0.1 | 37.0 ± 0.1 | 42.6 ± 0.2 | 0.69GB (3.0×) | 56.5 ± 0.9 | 58.4 ± 1.1 | 57.1 ± 0.7 | 0.46GB (3.0×) |
| BERT-PKD6 | 69.6 ± 0.2 | 34.4 ± 0.3 | 37.7 ± 0.3 | 43.4 ± 0.2 | 1.03GB (2.0×) | 56.9 ± 0.9 | 60.4 ± 0.8 | 58.4 ± 0.6 | 0.69GB (2.0×) |
| BERT-PKD4 | 69.2 ± 0.2 | 33.8 ± 0.2 | 37.1 ± 0.3 | 42.9 ± 0.3 | 0.69GB (3.0×) | 56.3 ± 1.1 | 59.9 ± 0.7 | 58.0 ± 0.6 | 0.46GB (3.0×) |
| TinyBERT6 | 69.7 ± 0.2 | 34.5 ± 0.2 | 37.9 ± 0.1 | 43.5 ± 0.2 | 1.03GB (2.0×) | 57.4 ± 0.8 | 60.5 ± 0.6 | 58.6 ± 0.5 | 0.69GB (2.0×) |
| TinyBERT4 | 69.4 ± 0.3 | 33.9 ± 0.3 | 37.5 ± 0.2 | 43.1 ± 0.2 | 0.17GB ( | 57.0 ± 0.7 | 59.9 ± 1.2 | 58.3 ± 0.7 | 0.12GB ( |
| MiniLM6 | 70.0 ± 0.1 | 34.9 ± 0.1 | 38.1 ± 0.1 | 43.8 ± 0.2 | 1.03GB (2.0×) | 55.9 ± 0.9 | 62.1 ± 0.8 | 58.8 ± 0.6 | 0.69GB (2.0×) |
| MiniLM4 | 69.6 ± 0.2 | 34.0 ± 0.2 | 37.6 ± 0.2 | 43.2 ± 0.3 | 0.17GB ( | 56.8 ± 0.9 | 60.5 ± 1.0 | 58.6 ± 0.6 | 0.12GB ( |
| UniLM4 | 69.6 ± 0.1 | 34.4 ± 0.2 | 37.7 ± 0.1 | 43.4 ± 0.2 | 0.69GB (3.0×) | 56.1 ± 0.9 | 60.6 ± 0.9 | 58.2 ± 0.5 | 0.46GB (3.0×) |
| UniLM2 | 68.9 ± 0.2 | 33.6 ± 0.2 | 36.8 ± 0.2 | 42.5 ± 0.1 | 0.35GB (5.9×) | 53.8 ± 0.8 | 59.1 ± 1.0 | 56.3 ± 0.6 | 0.24GB (5.7×) |
| FetchSGD | 70.5 ± 0.4 | 35.2 ± 0.3 | 38.2 ± 0.3 | 44.0 ± 0.4 | 0.51GB (4.0×) | 57.5 ± 0.9 | 60.4 ± 1.1 | 59.0 ± 0.8 | 0.34GB (4.0×) |
| FedDropout | 70.5 ± 0.2 | 35.1 ± 0.2 | 38.3 ± 0.3 | 44.2 ± 0.3 | 1.23GB (1.7×) | 57.8 ± 1.0 | 61.0 ± 0.8 | 59.4 ± 0.6 | 0.82GB (1.7×) |
| SCAFFOLD | 70.7 ± 0.1 | 35.4 ± 0.2 | 38.7 ± 0.1 | 44.5 ± 0.2 | 2.73GB (0.8×) | 61.9 ± 0.9 | 60.3 ± 0.5 | 2.74GB (0.5×) | |
| FedPAQ (16-bit) | | | | | 1.03GB (2.0×) | 58.4 ± 1.1 | 61.2 ± 0.8 | 59.7 ± 0.7 | 0.69GB (2.0×) |
| FedPAQ (8-bit) | 70.2 ± 0.3 | 35.0 ± 0.3 | 38.1 ± 0.3 | 44.0 ± 0.4 | 0.51GB (4.0×) | 56.5 ± 1.2 | 59.4 ± 0.9 | 57.9 ± 0.8 | 0.34GB (4.0×) |
| FedKD4 | | | | | 0.19GB (10.8×) | | | | 0.12GB ( |
| FedKD2 | 70.5 ± 0.1 | 35.3 ± 0.2 | 38.6 ± 0.1 | 44.3 ± 0.2 | 0.11GB ( | 58.2 ± 0.7 | 0.07GB ( | ||
Local: the model is trained only on local data. Cen: the model is trained on the centralized dataset. Fed: the standard federated learning method FedAvg. The number after each method name (a subscript in the original table) is the number of hidden layers in the model. The best Fscore (bold) of federated methods on ADR is significantly better than the second best one (underline) at the level of p < 0.05. The results show that the standard federated learning method can lead to heavy communication overheads, and that FedKD achieves promising results at much lower communication cost than FedAvg and other communication-efficient federated learning methods.
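The multipliers in parentheses can be read as the FedAvg cost divided by the method's cost on the same dataset. The short Python sketch below recomputes two figures from the rounded values in the table; the function names are illustrative, and because the inputs are already rounded, the results are only approximate checks.

```python
def compression_factor(baseline_gb: float, method_gb: float) -> float:
    """How many times smaller the method's traffic is than the baseline's."""
    return baseline_gb / method_gb


def percent_reduction(baseline_gb: float, method_gb: float) -> float:
    """Share of the baseline's traffic that the method avoids, in percent."""
    return 100.0 * (1.0 - method_gb / baseline_gb)


# Per-client costs in GB, copied from the table above.
fedavg_mind, fedkd4_mind = 2.05, 0.19
fedavg_adr, fedkd2_adr = 1.37, 0.07

print(round(compression_factor(fedavg_mind, fedkd4_mind), 1))  # 10.8
print(round(percent_reduction(fedavg_adr, fedkd2_adr), 2))     # 94.89
```

The second figure agrees with the maximum reduction quoted in the abstract.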
Fig. 1 Performance comparison of different federated learning methods in the medical NER task.
The height of each bar is the Fscore of the corresponding method. Error bars show 95% confidence intervals around the mean (n = 5 independent experiments). The results show that FedKD is more effective than the other compared federated learning methods in handling non-IID data.
Fig. 2 Influence of mutual distillation on the mentee and mentor models.
The mean values of AUC scores on MIND and F1 scores on ADR as well as their 95% confidence intervals are illustrated (n = 5 independent experiments). We compare the performance of mentors and the four-layer or two-layer mentee models when mutual distillation (MD) is used or not. The results show that mutual distillation can improve the performance of both mentee and mentor models, which is because useful knowledge can be reciprocally transferred between the mentee and mentor.
Fig. 3 Effectiveness of the adaptive mutual distillation techniques in FedKD.
The mean results with 95% confidence intervals are presented (n = 5 independent experiments). We compare the results obtained by removing the adaptive hidden loss, the adaptive mutual distillation loss or the adaptive loss weighting method from FedKD. Adaptive hidden loss: the distillation loss function that aims to transfer knowledge encoded by the hidden states and intermediate results of models, where the loss intensity is weighted by the prediction loss of the mentee and mentor models. Adaptive MD loss: the distillation loss function that aims to distill knowledge from the output soft labels of models; its intensity is also controlled by the prediction loss. Adaptive loss weighting: the mechanism that weights the two distillation losses based on the sum of the cross-entropy losses of the mentor and mentee models. Performance drops when any one of the three is removed, which verifies their contributions to federated model learning and distillation.
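As a rough illustration of the adaptive MD loss described in this caption, the Python sketch below scales a soft-label distillation term by the summed cross-entropy losses of the two models. The exact functional form is an assumption: the use of KL divergence, its direction, the division by the summed losses, and the `eps` guard are not stated in the text here, and the hidden-state loss is left out entirely.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, label):
    """Prediction (task) loss for one example with an integer label."""
    return -math.log(probs[label])


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def adaptive_md_losses(mentor_logits, mentee_logits, label, eps=1e-12):
    """Return (mentor_md_loss, mentee_md_loss) for one example.

    Assumed form: each model distills from the other's soft labels, and
    both terms are divided by the sum of the two task losses, so that
    distillation is damped when the predictions are unreliable.
    """
    p_mentor = softmax(mentor_logits)
    p_mentee = softmax(mentee_logits)
    scale = cross_entropy(p_mentor, label) + cross_entropy(p_mentee, label) + eps
    mentee_md = kl_divergence(p_mentor, p_mentee) / scale  # mentee learns from mentor
    mentor_md = kl_divergence(p_mentee, p_mentor) / scale  # mentor learns from mentee
    return mentor_md, mentee_md


# Same logits, different gold label: the distillation term is larger
# when both models predict the label well (small task losses).
print(adaptive_md_losses([2.0, 0.0, 0.0], [1.0, 0.0, 0.0], label=0))
print(adaptive_md_losses([2.0, 0.0, 0.0], [1.0, 0.0, 0.0], label=1))
```

With this weighting, a pair of models that both get an example wrong exchange little knowledge on it, which is one plausible reading of "intensity controlled by the prediction loss".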
Fig. 4 Analysis of gradient energy distribution.
a Cumulative energy distributions of singular values of different types of parameter gradient matrices at the beginning of model training. b Cumulative energy distributions of singular values at the end of model training. c Evolution of the number of required singular values during model training under a singular value energy cutoff threshold T = 0.95. WQ: query parameters, WK: key parameters, WV: value parameters, W: feed-forward network parameters. The results show that the gradients are usually low-rank, and they have more high-frequency components after more rounds of model training. Thus, a relatively higher energy threshold needs to be used to keep higher gradient precision at the end of model training for better model accuracy.
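The quantity plotted in panel c, the number of singular values needed to reach an energy threshold T, can be sketched in Python with NumPy as follows. This assumes that "energy" means the squared singular values, which the caption does not state explicitly; `rank_for_energy` is an illustrative name.

```python
import numpy as np


def rank_for_energy(grad: np.ndarray, threshold: float = 0.95) -> int:
    """Smallest number of singular values whose cumulative energy share
    reaches `threshold`. Energy is taken as the squared singular value
    (an assumption of this sketch)."""
    s = np.linalg.svd(grad, compute_uv=False)  # sorted in descending order
    energy = s ** 2
    total = energy.sum()
    if total == 0:
        return 0
    cumulative = np.cumsum(energy) / total
    k = int(np.searchsorted(cumulative, threshold)) + 1
    return min(k, len(s))  # guard against round-off when threshold is 1.0


# Energies are 9 and 1, so one singular value covers 90% and two cover 100%.
print(rank_for_energy(np.diag([3.0, 1.0]), threshold=0.95))  # 2

# A nearly low-rank "gradient": rank-2 signal plus small noise.
rng = np.random.default_rng(0)
g = rng.normal(size=(64, 2)) @ rng.normal(size=(2, 64)) + 0.01 * rng.normal(size=(64, 64))
print(rank_for_energy(g, threshold=0.95))
```

Raising the threshold late in training, as the caption suggests, simply means calling the function with a larger `threshold`, which keeps more singular values.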
Fig. 5 The framework of our FedKD approach.
The local data is used to train the local mentor model and global mentee model. Both models are learned from local labeled data as well as the prediction and hidden results of each other. The local gradients are decomposed before uploading to the server, and then reconstructed on the server for aggregation. The aggregated global gradients are further decomposed and distributed to clients for local updates.
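The decompose, upload, reconstruct and aggregate path described in this caption can be sketched in Python with NumPy using a truncated SVD. This is a minimal illustration under stated assumptions: the rank `k` is fixed by hand rather than chosen dynamically, aggregation is a plain mean, encryption is omitted, and the helper names `compress`, `reconstruct` and `n_values` are hypothetical.

```python
import numpy as np


def compress(grad: np.ndarray, k: int):
    """Keep the top-k singular triplets of a gradient matrix."""
    u, s, vt = np.linalg.svd(grad, full_matrices=False)
    return u[:, :k], s[:k], vt[:k, :]


def reconstruct(u: np.ndarray, s: np.ndarray, vt: np.ndarray) -> np.ndarray:
    """Rebuild the (approximate) gradient from its factors."""
    return (u * s) @ vt  # scales each column of u by its singular value


def n_values(*arrays: np.ndarray) -> int:
    """Number of scalars that would be transmitted."""
    return sum(a.size for a in arrays)


rng = np.random.default_rng(0)
# Three clients, each with an exactly rank-4 gradient of shape 64 x 64.
client_grads = [rng.normal(size=(64, 4)) @ rng.normal(size=(4, 64)) for _ in range(3)]

# Client side: decompose before uploading.
uploads = [compress(g, k=4) for g in client_grads]

# Server side: reconstruct each local gradient, then aggregate.
global_grad = np.mean([reconstruct(*f) for f in uploads], axis=0)

# Server side: decompose the aggregate again before distributing it.
# The mean of three rank-4 matrices has rank at most 12.
download = compress(global_grad, k=12)

print(n_values(*uploads[0]), "values uploaded instead of", client_grads[0].size)
```

In this toy setting the factors of one client take 516 scalars against 4096 for the full matrix; real gradients are only approximately low-rank, so the reconstruction there is lossy.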
| Step | Operation |
|---|---|
| 1 | Set the mentor learning rate … |
| 2 | Set the hyperparameters … |
| 3 | … |
| 4 | Initialize parameters … |
| 5–8 | … |
| 9 | Clients encrypt … |
| 10 | Clients upload … |
| 11 | Server decrypts … |
| 12 | Server reconstructs … |
| 13 | Global gradients … |
| 14–17 | … |
| 18 | Server encrypts … |
| 19 | Server distributes … |
| 20 | Clients decrypt … |
| 21 | Clients reconstruct … |
| 22 | Θ … |
| 23–24 | … |
| 25 | Compute task losses … |
| 26 | Compute losses … |
| 27–28 | … |
| 29 | Compute local mentor gradients … |
| 30 | Compute local mentee gradients … |
| 31 | … |