Yubo Shao, Kaikai Zhao, Zhiwen Cao, Zhehao Peng, Xingang Peng, Pan Li, Yijie Wang, Jianzhu Ma.
Abstract
It is hard to deploy deep learning models directly on today's smartphones because of the substantial computational cost introduced by their millions of parameters. To compress such models, we develop an ℓ0-based sparse group lasso model called MobilePrune, which can generate extremely compact neural network models for both desktop and mobile platforms. We adopt a group lasso penalty to enforce sparsity at the group level, which benefits computation based on General Matrix Multiply (GEMM), and we develop the very first algorithm that can optimize the ℓ0 norm in an exact manner and achieve a global convergence guarantee in the deep learning context. MobilePrune also allows complicated group structures (e.g., trees and overlapping groups) to be used in the group penalty, to suit DNN models with more complex architectures. Empirically, compared with state-of-the-art methods, we observe substantial gains in compression ratio and reductions in computational cost for various popular deep learning models on multiple benchmark datasets. More importantly, the compressed models are deployed on the Android system to confirm that our approach reduces response delay and battery consumption on mobile phones.
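The abstract pairs an exact ℓ0 penalty with a group penalty. As a minimal sketch of why such a combination can be handled exactly, the Python/NumPy code below computes the proximal operator of λ0‖w‖0 + λg‖w‖2 for a single, non-overlapping group. It is not the authors' algorithm (the abstract does not spell that out); the name `prox_l0_group_l2` and the proximal-gradient framing are assumptions. The fact it relies on: for a support S, the best achievable value of ½‖w − v‖² + λg‖w‖2 + λ0|S| is λ0|S| + ½‖v‖² − ½·max(‖v_S‖2 − λg, 0)², so for each support size k the k largest-magnitude entries are optimal, and scanning k = 0, …, n gives the exact minimizer.

```python
import numpy as np


def prox_l0_group_l2(v, lam0, lam_g):
    """Hypothetical helper: exact minimizer over a single group w of
    0.5*||w - v||^2 + lam0*||w||_0 + lam_g*||w||_2."""
    v = np.asarray(v, dtype=float)
    order = np.argsort(-np.abs(v))       # entries by decreasing magnitude
    top_sq = np.cumsum(v[order] ** 2)    # ||v_S||^2 when S = the k largest entries
    total = float(np.sum(v ** 2))

    # k = 0 zeroes the whole group; its cost is 0.5*||v||^2.
    best_k, best_cost = 0, 0.5 * total
    for k in range(1, v.size + 1):
        t = np.sqrt(top_sq[k - 1])
        # With the support fixed, the optimal w_S is the group soft-threshold
        # of v_S, which yields this closed-form cost.
        cost = lam0 * k + 0.5 * total - 0.5 * max(t - lam_g, 0.0) ** 2
        if cost < best_cost:
            best_k, best_cost = k, cost

    w = np.zeros_like(v)
    if best_k > 0:
        support = order[:best_k]
        t = np.sqrt(top_sq[best_k - 1])
        w[support] = max(1.0 - lam_g / t, 0.0) * v[support]  # group soft-threshold
    return w


w = prox_l0_group_l2([3.0, 0.1, -2.0], lam0=0.5, lam_g=1.0)
print(w)  # the small middle entry is set exactly to zero
```

Overlapping or tree-structured groups, which the abstract also mentions, are not covered by this single-group sketch.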
Keywords: convolutional neural network; deep learning; mobile computing; model compression; pruning network
Year: 2022 PMID: 35684708 PMCID: PMC9185446 DOI: 10.3390/s22114081
Source DB: PubMed Journal: Sensors (Basel) ISSN: 1424-8220 Impact factor: 3.847
Figure 1. Observations of how different strategies for pruning the filter matrix affect hardware acceleration when convolution is implemented in software with cuDNN. (a) General Matrix Multiply (GEMM) as applied in cuDNN. (b) Different strategies for pruning the filter matrix: no pruning, individual sparsity, column-wise group sparsity, and both individual and column-wise group sparsity. (c) The pruned filter matrix as implemented in cuDNN, and whether it can be used for hardware acceleration.
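To illustrate Figure 1 in code, the NumPy sketch below assumes convolution has already been lowered to a GEMM between an F × K filter matrix (one row per filter, K = C·k·k) and a K × P im2col patch matrix. Columns that are zero for every filter (column-wise group sparsity) can be dropped from both operands, which shrinks the dense GEMM; scattered individual zeros cannot be dropped this way. `compact_gemm` is a hypothetical helper, not a cuDNN API.

```python
import numpy as np


def compact_gemm(W, X):
    """Hypothetical helper: multiply an (F, K) filter matrix by a (K, P) im2col
    matrix after dropping the columns of W that are zero for every filter."""
    keep = np.any(W != 0, axis=0)        # a column survives if any filter uses it
    return W[:, keep] @ X[keep, :], int(keep.sum())


rng = np.random.default_rng(0)
W = rng.standard_normal((4, 6))          # 4 filters, K = C*k*k = 6 here
W[:, [1, 4]] = 0.0                        # column-wise group sparsity: removable
W[0, 2] = 0.0                             # individual sparsity: column 2 must stay
X = rng.standard_normal((6, 10))         # 10 output positions

Y_small, inner_dim = compact_gemm(W, X)
print(inner_dim)                          # 4: the GEMM inner dimension shrank from 6
print(np.allclose(Y_small, W @ X))        # True: the result is unchanged
```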
Figure 2. Overview of the proposed MobilePrune method. (a) Group sparsity on the weights of a neuron in fully connected layers. (b) Sparsity on individual weights in fully connected layers. (c) Pruning strategy for fully connected layers and its effect when sparsity is induced on both neuron-wise groups and individual weights. (d) Group and individual sparsity for convolutional layers.
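For the fully connected case in Figure 2(a) to (c), a neuron whose whole weight group is zero can be removed outright, shrinking both adjacent weight matrices. The sketch below assumes the group is a hidden neuron's outgoing weights (a column of the next layer's matrix) in a ReLU layer `y = W2 @ relu(W1 @ x + b1) + b2`; the paper's exact grouping may differ, and `drop_dead_neurons` is a hypothetical name.

```python
import numpy as np


def drop_dead_neurons(W1, b1, W2):
    """Hypothetical helper: remove hidden neurons whose outgoing weights
    (columns of W2) are all zero, shrinking both adjacent layers."""
    alive = np.any(W2 != 0, axis=0)      # neuron j matters iff column j of W2 is nonzero
    return W1[alive, :], b1[alive], W2[:, alive]


def forward(W1, b1, W2, b2, x):
    return W2 @ np.maximum(W1 @ x + b1, 0.0) + b2


rng = np.random.default_rng(1)
W1, b1 = rng.standard_normal((5, 3)), rng.standard_normal(5)   # 3 inputs -> 5 hidden
W2, b2 = rng.standard_normal((2, 5)), rng.standard_normal(2)   # 5 hidden -> 2 outputs
W2[:, [0, 3]] = 0.0                                            # neurons 0 and 3 are group-sparse

W1s, b1s, W2s = drop_dead_neurons(W1, b1, W2)
x = rng.standard_normal(3)
print(W1s.shape, W2s.shape)                                    # (3, 3) (2, 3)
print(np.allclose(forward(W1s, b1s, W2s, b2, x), forward(W1, b1, W2, b2, x)))  # True
```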
Comparison of pruned models with state-of-the-art methods on MNIST, CIFAR-10, and Tiny-ImageNet. (We highlight our MobilePrune results and mark the best performance in blue among the methods for each model on each dataset.)
| Dataset | Model | Method | Base/Pruned Accuracy (%) | Original/Remaining Parameters (K on MNIST; Mil otherwise) | FLOPs (K on MNIST; Mil otherwise) |
|---|---|---|---|---|---|
| MNIST | LeNet-300-100 | BC-GNJ | 98.40/98.20 | 267.00/28.73 | 28.64 |
| MNIST | LeNet-300-100 | BC-GHS | 98.40/98.20 | 267.00/28.17 | 28.09 |
| MNIST | LeNet-300-100 | L0 | -/ | - | 69.27 |
| MNIST | LeNet-300-100 | L0-sep | -/98.20 | - | 26.64 |
| MNIST | LeNet-5 | SBP | -/ | - | 212.80 |
| MNIST | LeNet-5 | BC-GNJ | 99.10/99.00 | 431.00/3.88 | 282.87 |
| MNIST | LeNet-5 | BC-GHS | 99.10/99.00 | 431.00/2.59 | 153.38 |
| MNIST | LeNet-5 | L0 | -/99.10 | - | 1113.40 |
| MNIST | LeNet-5 | L0-sep | -/99.00 | - | 390.68 |
| CIFAR-10 | VGG-like | Original | -/92.45 | 15.00/- | 313.5 |
| CIFAR-10 | VGG-like | PF | -/93.40 | 15.00/5.4 | 206.3 |
| CIFAR-10 | VGG-like | SBP | 92.80/92.50 | 15.00/- | 136.0 |
| CIFAR-10 | VGG-like | SBPa | 92.80/91.00 | 15.00/- | 99.20 |
| CIFAR-10 | VGG-like | VIBNet | -/ | 15.00/0.87 | 86.82 |
| CIFAR-10 | ResNet32 | C-OBD | 95.30/95.27 | 7.42/2.92 | 488.85 |
| CIFAR-10 | ResNet32 | C-OBS | 95.30/95.30 | 7.42/3.04 | 378.22 |
| CIFAR-10 | ResNet32 | Kron-OBD | 95.30/95.30 | 7.42/3.26 | 526.17 |
| CIFAR-10 | ResNet32 | Kron-OBS | 95.30/95.46 | 7.42/3.23 | 524.52 |
| CIFAR-10 | ResNet32 | EigenDamage | 95.30/95.28 | 7.42/2.99 | 457.46 |
| Tiny-ImageNet | VGG-19 | NN slimming | 61.56/40.05 | 20.12/5.83 | 158.62 |
| Tiny-ImageNet | VGG-19 | C-OBD | 61.56/47.36 | 20.12/4.21 | 481.90 |
| Tiny-ImageNet | VGG-19 | C-OBS | 61.56/39.80 | 20.12/6.55 | 210.05 |
| Tiny-ImageNet | VGG-19 | Kron-OBD | 61.56/44.41 | 20.12/4.72 | 298.28 |
| Tiny-ImageNet | VGG-19 | Kron-OBS | 61.56/44.54 | 20.12/5.26 | |
| Tiny-ImageNet | VGG-19 | EigenDamage | 61.56/ | 20.12/5.21 | 408.17 |
Results of learning filter shapes in LeNet-5. (We highlight our MobilePrune results.)
| Method | Base/Pruned Accuracy (%) | Filter Size | Remaining Filters | Remaining Parameters | FLOPs (K) |
|---|---|---|---|---|---|
| Baseline | - | 25–500 | 20–50 | 500–25,000 | 2464 |
| SSL | 99.10/99.00 | 7–14 | 1–50 | - | 63.82 |
Comparison of pruning methods on the desktop with state-of-the-art methods in terms of pruned accuracy, pruning rate, and response delay on the HAR datasets WISDM, UCI-HAR, and PAMAP2. (We highlight our MobilePrune results and mark the best performance in blue among the penalties for each dataset.)
| Dataset | Penalty | Base/Pruned Accuracy (%) | Parameter Nonzero (%) | Parameter Remaining (%) | Node Remaining (%) | Base/Pruned Response Delay (s) | Time Saving Percentage (%) |
|---|---|---|---|---|---|---|---|
| WISDM | | 94.72/94.79 | 63.36 | 100.00 | 100.00 | 0.38/0.39 | 0.00 |
| WISDM | | 94.30/93.84 | 13.58 | 46.26 | 68.16 | 0.38/0.24 | 36.84 |
| WISDM | | 94.61/94.54 | 56.28 | 90.46 | 95.12 | 0.38/0.35 | 7.89 |
| WISDM | Group lasso | 94.68/94.32 | 48.23 | 89.73 | 94.73 | | 7.89 |
| WISDM | | 94.81/ | 17.91 | 53.41 | 73.83 | 0.41/0.26 | 36.59 |
| UCI-HAR | | | 88.49 | 100.00 | 100.00 | 0.84/0.80 | 4.76 |
| UCI-HAR | | 90.46/90.33 | 81.58 | 98.47 | 99.22 | 0.81/0.82 | 0.00 |
| UCI-HAR | | 91.01/90.94 | 88.35 | 100.00 | 100.00 | | 0.00 |
| UCI-HAR | Group lasso | 90.80/90.84 | 82.91 | 100.00 | 100.00 | 0.83/0.78 | 6.02 |
| UCI-HAR | | 91.11/91.04 | 81.21 | 97.70 | 98.83 | 0.84/0.80 | 4.76 |
| PAMAP2 | | 93.15/93.07 | 69.27 | 100.00 | 100.00 | 0.41/0.41 | 0.00 |
| PAMAP2 | | 95.22/95.29 | 1.46 | 7.28 | 19.73 | 0.40/0.08 | 80.00 |
| PAMAP2 | | 92.08/92.09 | 65.32 | 94.93 | 97.27 | 0.41/0.39 | 4.88 |
| PAMAP2 | Group lasso | 93.30/93.28 | 61.78 | 100.00 | 100.00 | 0.41/0.41 | 0.00 |
| PAMAP2 | | 96.87/ | 2.67 | 9.72 | 26.17 | | 75.00 |
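The Time Saving Percentage column appears consistent with (base − pruned) / base × 100, floored at zero when the pruned model is slower; for example, WISDM's 0.38 s to 0.24 s gives 36.84%, and UCI-HAR's 0.81 s to 0.82 s is reported as 0.00. The helper below is a hypothetical sketch of that arithmetic, inferred from the reported numbers rather than taken from the paper; the same formula also appears to match the battery-saving column of the mobile comparison (e.g. 0.74 to 0.65 %/h gives 12.16%).

```python
def saving_percentage(base, pruned):
    """Hypothetical helper: relative saving in percent, floored at 0."""
    if base <= 0:
        raise ValueError("base must be positive")
    return max(0.0, (base - pruned) / base * 100.0)


# (base, pruned, reported) triples taken from the desktop table above
rows = [
    (0.38, 0.24, 36.84),   # WISDM
    (0.41, 0.26, 36.59),   # WISDM
    (0.83, 0.78, 6.02),    # UCI-HAR, group lasso
    (0.81, 0.82, 0.00),    # UCI-HAR: pruned model is slower, so no saving
    (0.40, 0.08, 80.00),   # PAMAP2
]
for base, pruned, reported in rows:
    print(f"{saving_percentage(base, pruned):.2f} vs reported {reported:.2f}")
```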
Comparison of pruning methods on mobile devices with other state-of-the-art pruning methods in terms of computational cost and battery usage on the HAR datasets WISDM, UCI-HAR, and PAMAP2. (We highlight our MobilePrune results and mark the best performance in blue among the penalties for each device on each dataset.)
| Dataset | Device | Penalty | Base/Pruned Response Delay (s) | Time Saving Percentage (%) | Base/Pruned Estimated Battery Use (%/h) | Battery Saving Percentage (%) |
|---|---|---|---|---|---|---|
| WISDM | Huawei P20 | | 1.40/1.27 | 9.29 | | 1.41 |
| WISDM | Huawei P20 | | 1.33/0.71 | 46.62 | 0.74/0.65 | 12.16 |
| WISDM | Huawei P20 | | 1.28/1.21 | 5.47 | 0.74/0.77 | 0.00 |
| WISDM | Huawei P20 | Group lasso | 1.27/1.27 | 0.00 | 0.74/0.77 | 0.00 |
| WISDM | Huawei P20 | | | 35.20 | 0.74/0.68 | 8.11 |
| WISDM | OnePlus 8 Pro | | 0.57/0.49 | 14.04 | 0.34/0.32 | 5.88 |
| WISDM | OnePlus 8 Pro | | 0.48/0.34 | 29.17 | 0.35/0.30 | 14.29 |
| WISDM | OnePlus 8 Pro | | 0.48/0.40 | 16.67 | 0.34/0.34 | 0.00 |
| WISDM | OnePlus 8 Pro | Group lasso | 0.49/0.45 | 8.16 | 0.34/0.35 | 0.00 |
| WISDM | OnePlus 8 Pro | | 0.48/0.33 | 31.25 | 0.35/0.30 | 14.29 |
| UCI-HAR | Huawei P20 | | 1.43/1.43 | 0.00 | 0.84/0.84 | 0.00 |
| UCI-HAR | Huawei P20 | | 1.42/1.42 | 0.00 | 0.85/0.84 | 1.18 |
| UCI-HAR | Huawei P20 | | 1.43/1.43 | 0.00 | 0.84/0.84 | 0.00 |
| UCI-HAR | Huawei P20 | Group lasso | 1.43/1.43 | 0.00 | 0.84/0.82 | 2.38 |
| UCI-HAR | Huawei P20 | | 1.42/1.41 | 0.70 | 0.85/0.82 | 3.53 |
| UCI-HAR | OnePlus 8 Pro | | | 0.00 | 0.35/0.35 | 0.00 |
| UCI-HAR | OnePlus 8 Pro | | 0.54/0.51 | 5.56 | 0.37/0.36 | 2.70 |
| UCI-HAR | OnePlus 8 Pro | | 0.54/0.53 | 1.85 | 0.37/0.37 | 0.00 |
| UCI-HAR | OnePlus 8 Pro | Group lasso | 0.53/0.52 | 1.89 | 0.36/0.36 | 0.00 |
| UCI-HAR | OnePlus 8 Pro | | 0.53/0.52 | 1.89 | 0.36/0.36 | 0.00 |
| PAMAP2 | Huawei P20 | | | 0.00 | | 0.00 |
| PAMAP2 | Huawei P20 | | 2.74/0.45 | 83.58 | 0.79/0.53 | 32.91 |
| PAMAP2 | Huawei P20 | | 2.67/2.56 | 4.12 | 0.78/0.78 | 0.00 |
| PAMAP2 | Huawei P20 | Group lasso | 2.67/2.68 | 0.00 | 0.78/0.78 | 0.00 |
| PAMAP2 | Huawei P20 | | 2.69/0.55 | 79.55 | 0.79/0.57 | 27.85 |
| PAMAP2 | OnePlus 8 Pro | | 0.94/0.93 | 1.06 | 0.88/0.88 | 0.00 |
| PAMAP2 | OnePlus 8 Pro | | | 73.12 | 0.87/0.55 | 36.78 |
| PAMAP2 | OnePlus 8 Pro | | 0.93/0.91 | 2.15 | 0.88/0.87 | 1.14 |
| PAMAP2 | OnePlus 8 Pro | Group lasso | 0.94/0.95 | 0.00 | 0.89/0.89 | 0.00 |
| PAMAP2 | OnePlus 8 Pro | | 0.95/0.29 | 69.47 | 0.88/0.59 | 32.95 |
Ablation studies on various network models. (We mark the best performance in blue among the penalties for each model.)
| Network Model | Penalty | Base/Pruned Accuracy (%) | Original/Remaining Parameters | FLOPs (K for LeNet models; Mil otherwise) | Sparsity (%) |
|---|---|---|---|---|---|
| LeNet-300-100 | | 98.24/ | 267 K/57.45 K | 143.20 | 21.55 |
| LeNet-300-100 | Group lasso | 98.24/98.17 | 267 K/32.06 K | 39.70 | 12.01 |
| LeNet-300-100 | | 98.24/98.00 | 267 K/15.80 K | 25.88 | 5.93 |
| LeNet-300-100 | | 98.24/98.23 | 267 K/ | | |
| LeNet-5 | | 99.12/99.20 | 431 K/321.0 K | 2293.0 | 74.48 |
| LeNet-5 | Group lasso | 99.12/99.11 | 431 K/8.81 K | 187.00 | 2.04 |
| LeNet-5 | | 99.12/99.03 | 431 K/9.98 K | 183.83 | 2.32 |
| LeNet-5 | | 99.12/99.11 | 431 K/2.31 K | 113.50 | 1.97 |
| VGG-like | | 92.96/93.40 | 15 M/3.39 M | 210.94 | 22.6 |
| VGG-like | Group lasso | 92.96/92.47 | 15 M/0.84 M | 78.07 | 5.60 |
| VGG-like | | 92.96/92.90 | 15 M/0.61 M | 134.35 | 4.06 |
| VGG-like | | 92.96/92.94 | 15 M/0.60 M | 77.83 | |
| ResNet-32 | | 95.29/95.68 | 7.42 M/6.74 M | 993.11 | 90.84 |
| ResNet-32 | Group lasso | 95.29/95.30 | 7.42 M/3.03 M | 373.09 | 40.84 |
| ResNet-32 | | 95.29/95.04 | 7.42 M/5.66 M | 735.12 | 76.28 |
| ResNet-32 | | 95.29/95.47 | 7.42 M/2.93 M | 371.30 | 39.49 |
| VGG-19 | | 61.56/61.99 | 138 M/19.29 M | 1519.23 | 13.98 |
| VGG-19 | Group lasso | 61.56/53.25 | 138 M/5.93 M | 683.99 | 4.30 |
| VGG-19 | | 61.56/53.97 | 138 M/0.21 M | 1282.82 | |
| VGG-19 | | 61.56/56.27 | 138 M/4.05 M | 407.37 | 2.93 |
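The Sparsity column of this table appears to be the remaining parameter count as a percentage of the original count (e.g. 6.74 M / 7.42 M ≈ 90.84% for ResNet-32). A short sketch of that ratio, together with the corresponding compression ratio (original / remaining), is below; the function names are hypothetical and the formula is inferred from the reported values, not stated in the source.

```python
def sparsity_percent(original, remaining):
    """Hypothetical helper: remaining parameters as a percentage of the original."""
    return remaining / original * 100.0


def compression_ratio(original, remaining):
    """Hypothetical helper: how many times smaller the pruned model is."""
    return original / remaining


# (model, original, remaining, reported sparsity %) from the ablation table above
rows = [
    ("LeNet-5, group lasso", 431.0, 8.81, 2.04),
    ("ResNet-32", 7.42, 6.74, 90.84),
    ("ResNet-32, group lasso", 7.42, 3.03, 40.84),
    ("VGG-19", 138.0, 19.29, 13.98),
]
for name, orig, rem, reported in rows:
    print(f"{name}: sparsity {sparsity_percent(orig, rem):.2f}% "
          f"(reported {reported}), {compression_ratio(orig, rem):.1f}x smaller")
```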
List of hyper-parameters and their values (“-” denotes “not applicable”).
| Hyper-Parameter | LeNet300 | LeNet5 | VGG-Like | ResNet-32 | VGG-19 | Description |
|---|---|---|---|---|---|---|
| learning rate | 1 × 10⁻³ | 1 × 10⁻³ | 1 × 10⁻³ | 1 × 10⁻³ | 1 × 10⁻³ | The learning rate used in the retraining process |
| gradient momentum | 0.9 | 0.9 | 0.9 | 0.9 | 0.9 | The gradient momentum used in the retraining process |
| weight decay | 1 × 10⁻⁴ | 1 × 10⁻⁵ | 5 × 10⁻⁴ | 1 × 10⁻⁴ | 1 × 10⁻⁴ | The weight decay factor used in the retraining process |
| minibatch size | 1 × 10² | 6 × 10² | 1 × 10³ | 3 × 10² | 4 × 10² | The number of training samples over which each SGD update is computed during retraining |
| ℓ0 factor | 4 × 10⁻⁴ | 2 × 10⁻⁴ | 1 × 10⁻⁶ | 1 × 10⁻⁸ | 1 × 10⁻¹⁰ | The shrinkage coefficient for the ℓ0 norm |
| channel factor | - | 1 × 10⁻³ | 1 × 10⁻³–1 × 10⁻² ¹ | 5 × 10⁻² | 5 × 10⁻² | The shrinkage coefficient of channels for group lasso |
| neuron factor | 2 × 10⁻⁴ | 2 × 10⁻⁴ | 1 × 10⁻⁴ | 0 | 1 × 10⁻² | The shrinkage coefficient of neurons for group lasso |
| filter size factor | - | 1 × 10⁻³ | 1 × 10⁻⁴ | 1 × 10⁻⁴ | 1 × 10⁻⁴ | The shrinkage coefficient of filter shapes for group lasso |
| pruning frequency (epochs/minibatches) | 10 | 10 | 1 | 2 | 1 | Number of epochs (LeNet) or minibatches (VGGNet/ResNet) of pruning before retraining |
| retraining epochs | 30 | 30 | 20 | 30 | 15 | The number of retraining epochs after pruning |
| iterations | 74 | 102 | 63 | 2 | 66 | The number of iterations for obtaining the final results |
¹ On VGG-like, the channel factor is adaptive: it is increased by 0.001 whenever the cross-entropy loss after pruning is not greater than the loss before pruning on the current mini-batch. Its range is [0.001, 0.01].
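The adaptive channel-factor schedule described in the footnote can be written as a one-line update. This is a sketch under stated assumptions: the name `update_channel_factor` is hypothetical, and because the footnote does not say what happens when the loss increases, the sketch simply leaves the factor unchanged in that case.

```python
def update_channel_factor(factor, loss_after, loss_before,
                          step=0.001, lo=0.001, hi=0.01):
    """Hypothetical helper for the VGG-like schedule: raise the channel factor by
    `step` when pruning on the current mini-batch did not increase the
    cross-entropy loss, keeping the factor inside [lo, hi]."""
    if loss_after <= loss_before:
        factor += step
    # Assumption: an increased loss leaves the factor unchanged.
    return min(max(factor, lo), hi)


factor = 0.001
# (loss before pruning, loss after pruning) for a few made-up mini-batches
for before, after in [(2.3, 2.1), (2.1, 2.2), (2.2, 2.2)]:
    factor = update_channel_factor(factor, after, before)
    print(round(factor, 3))
```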
Ablation study on LeNet-5 (architecture: 20-50-800-500).
| Penalty | Base/Pruned Accuracy (%) | Original/Remaining Parameters (K) | Pruned Architecture | Filter Size | FLOPs (K) | Sparsity (%) |
|---|---|---|---|---|---|---|
| | 99.12/99.20 | 431/321.00 | 20-50-800-500 | 25–500 | 2293.0 | 74.48 |
| Group Lasso | 99.12/99.11 | 431/8.81 | 4-19-301-29 | 25–99 | 187.00 | 2.04 |
| | 99.12/99.03 | 431/9.98 | 4-17-271-82 | 23–99 | 183.83 | 2.32 |
| | 99.12/99.11 | 431/2.31 | 5-14-151-57 | 16–65 | 113.50 | 1.97 |
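The FLOPs column of this table can be approximately reproduced by counting multiply-accumulates. The sketch below rests on assumptions that are ours, not the paper's: a Caffe-style LeNet-5 with 28 × 28 inputs, two valid 5 × 5 convolutions each followed by 2 × 2 pooling (so 24 × 24 and 8 × 8 convolution outputs), no cost for biases, ReLU or pooling, and a single filter size per convolutional layer taken from the Filter Size column. The function names are hypothetical.

```python
def lenet5_flops(arch, filter_sizes, conv_positions=(24 * 24, 8 * 8), n_classes=10):
    """Hypothetical helper: multiply-accumulate count of a (pruned) LeNet-5.

    arch:           (conv1 filters, conv2 filters, fc1 inputs, fc1 units), e.g. 20-50-800-500
    filter_sizes:   (weights per conv1 filter, weights per conv2 filter), e.g. 25-500
    conv_positions: output positions of conv1 (24x24) and conv2 (8x8) for 28x28 inputs
    Biases, ReLU and pooling are not counted.
    """
    c1, c2, fc_in, fc_units = arch
    s1, s2 = filter_sizes
    conv = c1 * s1 * conv_positions[0] + c2 * s2 * conv_positions[1]
    fc = fc_in * fc_units + fc_units * n_classes
    return conv + fc


def lenet5_dense_params(arch=(20, 50, 800, 500), filter_sizes=(25, 500), n_classes=10):
    """Hypothetical helper: parameter count (weights plus biases) of a dense LeNet-5."""
    c1, c2, fc_in, fc_units = arch
    s1, s2 = filter_sizes
    return ((c1 * s1 + c1) + (c2 * s2 + c2)
            + (fc_in * fc_units + fc_units) + (fc_units * n_classes + n_classes))


print(lenet5_dense_params())  # 431080, i.e. the 431 K baseline
for arch, sizes in [((20, 50, 800, 500), (25, 500)),   # dense baseline
                    ((4, 19, 301, 29), (25, 99)),      # group lasso row
                    ((5, 14, 151, 57), (16, 65))]:     # last row
    print(arch, lenet5_flops(arch, sizes) / 1000, "K")
```

This gives 2293.0, 187.003 and 113.497 K, in line with the reported 2293.0, 187.00 and 113.50. The 4-17-271-82 row comes out at about 183.75 K versus the reported 183.83 K, so treat the single-filter-size assumption as an approximation.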
Ablation study on VGG-like.
| Penalty | Base/Pruned Accuracy (%) | Original/Remaining Parameters (Mil) | Pruned Architecture | FLOPs (Mil) |
|---|---|---|---|---|
| | 92.96/93.40 | 15/3.39 | 18-43-92-99-229-240-246-507-504-486-241-114-428-168 | 210.94 |
| Group Lasso | 92.96/92.47 | 15/0.84 | 17-43-89-99-213-162-93-42-32-28-8-5-429-168 | 78.07 |
| | 92.96/92.90 | 15/0.61 | 17-43-92-99-229-240-246-323-148-111-41-39-159-161 | 134.35 |
| | 92.96/92.94 | 15/0.60 | 17-43-87-99-201-185-80-37-27-25-9-4-368-167 | 77.83 |
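A similar count can be sketched for the VGG-like network from its channel list. The layout is an assumption, not stated in the source: 13 bias-free 3 × 3 convolutions (stride 1, padding 1) on 32 × 32 CIFAR-10 inputs with 2 × 2 max pooling after convolutions 2, 4, 7, 10 and 13, then one hidden fully connected layer and a 10-way classifier, so that the 14 numbers in the Pruned Architecture column are the 13 convolution widths plus the hidden FC width. Under this assumption the dense configuration gives about 14.98 M weights and 313.46 M multiply-accumulates, close to the 15.00 M parameters and 313.5 M FLOPs listed for the original VGG-like model. `vgg_like_cost` is a hypothetical name.

```python
# Output positions of each of the 13 conv layers on 32x32 inputs, assuming
# 2x2 max pooling after conv layers 2, 4, 7, 10 and 13.
SPATIAL = [32 * 32] * 2 + [16 * 16] * 2 + [8 * 8] * 3 + [4 * 4] * 3 + [2 * 2] * 3


def vgg_like_cost(widths, in_channels=3, kernel=3, n_classes=10):
    """Hypothetical helper: (weights, multiply-accumulates) of a VGG-like net.

    widths: 13 conv widths followed by the hidden FC width (14 numbers),
            as in the Pruned Architecture column. Biases are ignored.
    """
    *conv_widths, fc_hidden = widths
    assert len(conv_widths) == len(SPATIAL) == 13
    params = macs = 0
    prev = in_channels
    for width, positions in zip(conv_widths, SPATIAL):
        layer = prev * width * kernel * kernel
        params += layer
        macs += layer * positions
        prev = width
    # After the last pooling the feature map is 1x1, so the FC input is `prev`.
    fc = prev * fc_hidden + fc_hidden * n_classes
    return params + fc, macs + fc


dense = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512, 512]
params, macs = vgg_like_cost(dense)
print(f"{params / 1e6:.2f} M weights, {macs / 1e6:.2f} M MACs")
```

For pruned channel lists the sketch is only approximate: the last row (17-43-87-99-201-185-80-37-27-25-9-4-368-167) gives about 77.87 M versus the reported 77.83 M.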
Ablation study on ResNet-32.
| Penalty | Base/Pruned Accuracy (%) | Original/Remaining Parameters (Mil) | FLOPs (Mil) | Sparsity (%) |
|---|---|---|---|---|
| | 95.29/95.68 | 7.42/6.74 | 993.11 | 90.84 |
| Group Lasso | 95.29/95.30 | 7.42/3.43 | 393.09 | 45.95 |
| | 95.29/95.04 | 7.42/5.66 | 735.12 | 76.28 |
| | 95.29/95.47 | 7.42/2.93 | 371.30 | 39.49 |
Ablation study on VGG-19.
| Penalty | Test Accuracy (%) | Remaining Parameters (Mil) | Pruned Architecture | FLOPs (Mil) |
|---|---|---|---|---|
| Baseline | 61.56 | 20.12 | 64-64-128-128-256-256-256-256-512-512-512-512-512-512-512-512 | 1592.53 |
| | 61.99 | 19.29 | 45-64-114-128-256-256-256-256-512-511-512-509-512-512-512-512 | 1519.23 |
| Group Lasso | 53.25 | 5.93 | 23-61-80-128-122-114-164-253-255-322-412-462-23-93-129-512 | 683.99 |
| | 53.97 | 0.21 | 29-64-109-128-254-246-254-256-510-509-509-509-512-512-484-512 | 1282.82 |
| | 56.27 | 4.05 | 19-48-57-102-79-83-100-179-219-273-317-341-256-158-116-512 | 407.37 |
The effect of the coefficient of the ℓ0-norm penalty on VGG-like.
| Coefficient | Base/Pruned Accuracy (%) | Original/Remaining Parameters (Mil) | Pruned Architecture | FLOPs (Mil) |
|---|---|---|---|---|
| 1 × 10⁻⁴ | 92.96/89.77 | 15/0.06 | 17-43-83-99-161-105-57-28-24-15-11-4-104-157 | 56.43 |
| 1 × 10⁻⁵ | 92.96/92.19 | 15/0.30 | 16-43-85-99-171-155-75-33-23-18-10-3-264-167 | 66.82 |
| 1 × 10⁻⁶ | 92.96/92.94 | 15/0.60 | 17-43-87-99-201-185-80-37-27-25-9-4-368-167 | 77.83 |
| 1 × 10⁻⁷ | 92.96/92.54 | 15/0.74 | 17-43-87-99-213-188-91-40-26-27-9-4-400-168 | 81.64 |
The effect of the coefficient of the ℓ0-norm penalty on ResNet-32.
| Coefficient | Base/Pruned Accuracy (%) | Original/Remaining Parameters (Mil) | FLOPs (Mil) | Sparsity (%) |
|---|---|---|---|---|
| 1 × 10⁻⁶ | 95.29/95.11 | 7.42/2.06 | 330.90 | 27.76 |
| 1 × 10⁻⁷ | 95.29/95.33 | 7.42/2.72 | 369.36 | 36.66 |
| 1 × 10⁻⁸ | 95.29/95.47 | 7.42/2.93 | 371.30 | 39.49 |
| 1 × 10⁻⁹ | 95.29/95.44 | 7.42/3.02 | 372.98 | 40.70 |
Impact of different cross-validation fold numbers and learning rates on the proposed sparse group lasso approach for each HAR dataset (WISDM, UCI-HAR, and PAMAP2). (We highlight our selection of both fold number and learning rate for each dataset.)
| Dataset | Type | Value | Base/Pruned Accuracy (%) | Parameter Nonzero (%) | Parameter Remaining (%) | Node Remaining (%) |
|---|---|---|---|---|---|---|
| WISDM | Fold Number | 1 | 93.52/92.68 | 11.64 | 32.49 | 57.42 |
| WISDM | Fold Number | 2 | 94.88/93.70 | 10.03 | 30.35 | 55.08 |
| WISDM | Fold Number | 3 | 94.45/93.48 | 9.45 | 27.97 | 52.13 |
| WISDM | Fold Number | 5 | 93.52/92.68 | 11.64 | 32.49 | 57.42 |
| WISDM | Learning Rate | 1.0 | 89.55/86.72 | 27.09 | 93.50 | 96.68 |
| WISDM | Learning Rate | 5.0 | 92.93/84.36 | 9.41 | 40.44 | 64.06 |
| WISDM | Learning Rate | 1.5 | 94.96/94.88 | 10.38 | 27.26 | 52.54 |
| WISDM | Learning Rate | 1.0 | 94.65/94.57 | 10.54 | 32.38 | 56.84 |
| UCI-HAR | Fold Number | 1 | 78.42/78.08 | 15.53 | 31.99 | 56.64 |
| UCI-HAR | Fold Number | 2 | 89.89/89.28 | 32.49 | 64.29 | 80.27 |
| UCI-HAR | Fold Number | 3 | 79.13/79.37 | 16.02 | 32.25 | 56.84 |
| UCI-HAR | Fold Number | 4 | 78.22/78.22 | 18.69 | 40.02 | 63.48 |
| UCI-HAR | Learning Rate | 1.0 | 85.27/85.51 | 77.98 | 94.66 | 97.27 |
| UCI-HAR | Learning Rate | 5.0 | 89.38/89.24 | 16.69 | 85.77 | 92.58 |
| UCI-HAR | Learning Rate | 1.5 | 90.94/90.91 | 16.69 | 31.04 | 56.45 |
| UCI-HAR | Learning Rate | 2.0 | 90.40/90.43 | 13.24 | 29.10 | 54.10 |
| PAMAP2 | Fold Number | 2 | 92.29/92.28 | 1.27 | 3.15 | 10.35 |
| PAMAP2 | Fold Number | 3 | 96.49/96.28 | 1.81 | 4.74 | 14.84 |
| PAMAP2 | Fold Number | 4 | 95.08/94.99 | 1.20 | 3.42 | 10.55 |
| PAMAP2 | Fold Number | 5 | 94.81/94.81 | 1.46 | 3.71 | 11.52 |
| PAMAP2 | Learning Rate | 1.0 | 93.63/85.80 | 7.93 | 28.61 | 49.22 |
| PAMAP2 | Learning Rate | 5.0 | 94.25/93.89 | 3.90 | 11.81 | 28.32 |
| PAMAP2 | Learning Rate | 1.5 | 96.57/96.62 | 1.12 | 2.36 | 7.62 |
| PAMAP2 | Learning Rate | 2.0 | 94.89/94.99 | 0.68 | 2.02 | 7.62 |