Xili Dai, Shengbang Tong, Mingyang Li, Ziyang Wu, Michael Psenka, Kwan Ho Ryan Chan, Pengyuan Zhai, Yaodong Yu, Xiaojun Yuan, Heung-Yeung Shum, Yi Ma.
Abstract
This work proposes a new computational framework for learning a structured generative model for real-world datasets. In particular, we propose to learn a Closed-loop Transcription between a multi-class, multi-dimensional data distribution and a Linear discriminative representation (CTRL) in the feature space that consists of multiple independent multi-dimensional linear subspaces. We argue that the optimal encoding and decoding mappings sought can be formulated as a two-player minimax game between the encoder and decoder for the learned representation. A natural utility function for this game is the so-called rate reduction, a simple information-theoretic measure for distances between mixtures of subspace-like Gaussians in the feature space. Our formulation draws inspiration from closed-loop error feedback in control systems and avoids the expensive evaluation and minimization of approximated distances between arbitrary distributions in either the data space or the feature space. To a large extent, this new formulation unifies the concepts and benefits of Auto-Encoding and GAN and naturally extends them to the setting of learning a representation that is both discriminative and generative for multi-class, multi-dimensional real-world data. Our extensive experiments on many benchmark imagery datasets demonstrate the tremendous potential of this new closed-loop formulation: under fair comparison, the visual quality of the learned decoder and the classification performance of the encoder are competitive and arguably better than those of existing methods based on GAN, VAE, or a combination of both. Unlike existing generative models, the so-learned features of the multiple classes are structured instead of hidden: different classes are explicitly mapped onto corresponding independent principal subspaces in the feature space, and diverse visual attributes within each class are modeled by the independent principal components within each subspace.
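For context, the rate-reduction measure invoked in the abstract comes from the MCR² (maximal coding rate reduction) framework by the same group; the following is a sketch of its standard definition there, not a verbatim formula from this paper. Here Z collects n feature vectors of dimension d as columns, ε is the coding precision, and Π = {Π_j} are diagonal membership matrices for the k classes:

```latex
% Coding rate of features Z, up to precision \varepsilon:
R(Z, \varepsilon) \;=\; \tfrac{1}{2}\,\log\det\!\Big(I \;+\; \tfrac{d}{n\varepsilon^{2}}\, Z Z^{\top}\Big).

% Rate reduction: rate of the whole set minus the average rate of its k groups:
\Delta R(Z, \Pi, \varepsilon) \;=\; R(Z, \varepsilon)
\;-\; \sum_{j=1}^{k} \frac{\operatorname{tr}(\Pi_j)}{2n}\,
\log\det\!\Big(I \;+\; \tfrac{d}{\operatorname{tr}(\Pi_j)\,\varepsilon^{2}}\, Z \Pi_j Z^{\top}\Big).
```

Maximizing ΔR expands the volume of all features while compressing each class, which is what drives the features toward independent linear subspaces.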
Keywords: closed-loop transcription; linear discriminative representation; minimax game; rate reduction
Year: 2022 PMID: 35455120 PMCID: PMC9031319 DOI: 10.3390/e24040456
Source DB: PubMed Journal: Entropy (Basel) ISSN: 1099-4300 Impact factor: 2.738
Figure 1. CTRL: a Closed-loop Transcription to an LDR. The encoder f has dual roles: it learns an LDR for the data by maximizing the rate reduction of the learned features, and it also acts as a “feedback sensor” for any discrepancy between the data and its decoded reconstruction. The decoder g also has dual roles: it is a “controller” that corrects the discrepancy between the data and the reconstruction, and it also aims to minimize the overall coding rate of the learned LDR.
Figure 2. Embeddings of Low-Dimensional Submanifolds in High-Dimensional Spaces. The blue curve is the submanifold of the original data; the red curve is its image under the encoding mapping f, representing the learned features; and the green curve is the image of those features under the decoding mapping g.
Figure 3. Qualitative comparison on (a) MNIST, (b) CIFAR-10, and (c) ImageNet. First row: original samples; other rows: reconstructions by different methods.
Quantitative comparison on MNIST and CIFAR-10: average Inception score (IS) [65] and FID score [66]. ↑ means higher is better; ↓ means lower is better.

| | Metric | GAN | GAN (CTRL-Binary) | VAE-GAN | CTRL-Binary | CTRL-Multi |
|---|---|---|---|---|---|---|
| MNIST | IS ↑ | 2.08 | 1.95 | | 2.02 | 2.07 |
| | FID ↓ | 24.78 | 20.15 | 33.65 | | 16.47 |
| CIFAR-10 | IS ↑ | 7.32 | 7.23 | 7.11 | | 7.13 |
| | FID ↓ | 26.06 | 22.16 | 43.25 | | 23.91 |
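As background for the FID numbers reported throughout: FID is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images. The sketch below covers only the simplified case of diagonal covariances (the general case needs a matrix square root of the covariance product); the names `mu`/`var` are illustrative, not from the paper:

```python
import math

def fid_diagonal(mu1, var1, mu2, var2):
    """Frechet distance between two Gaussians with diagonal covariances.

    FID = ||mu1 - mu2||^2 + sum_i (var1_i + var2_i - 2*sqrt(var1_i * var2_i))
    """
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    cov_term = sum(v1 + v2 - 2.0 * math.sqrt(v1 * v2)
                   for v1, v2 in zip(var1, var2))
    return mean_term + cov_term

# Identical feature distributions give a distance of 0.
print(fid_diagonal([0.0, 1.0], [1.0, 2.0], [0.0, 1.0], [1.0, 2.0]))  # 0.0
```

Lower FID therefore means the generated-feature Gaussian is closer to the real-feature Gaussian in both mean and spread.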
IS and FID scores of images reconstructed by LDR models learned with different feature dimensions. ↑ means higher is better; ↓ means lower is better.

| CIFAR-10 | CTRL-Binary (dim = 128) | CTRL-Multi (dim = 128) | CTRL-Binary (dim = 512) | CTRL-Multi (dim = 512) |
|---|---|---|---|---|
| IS ↑ | 8.1 | 7.1 | 8.4 | 8.2 |
| FID ↓ | 19.6 | 23.6 | 18.7 | 20.5 |
Figure 4. Visualizing the alignment between the features of the data and the features of the decoded samples in the feature space, for (a) MNIST, (b) CIFAR-10, and (c) ImageNet-10-Class.
Figure 5. Visualizing the auto-encoding property of the learned closed-loop transcription on MNIST, CIFAR-10, and ImageNet (zoom in for better visualization).
Comparison on CIFAR-10 and STL-10. SNGAN, CSGAN, and LOGAN are GAN-based methods; VAE-GAN, NVAE, and DC-VAE are VAE/GAN-based methods. Comparison with more existing methods and on ImageNet can be found in Table A10 in Appendix A. ↑ means higher is better; ↓ means lower is better.

| | Metric | SNGAN | CSGAN | LOGAN | VAE-GAN | NVAE | DC-VAE | CTRL-Binary | CTRL-Multi |
|---|---|---|---|---|---|---|---|---|---|
| CIFAR-10 | IS ↑ | 7.4 | 8.1 | | 7.4 | - | | 8.1 | 7.1 |
| | FID ↓ | 29.3 | 19.6 | | 39.8 | 50.8 | | 19.6 | 23.9 |
| STL-10 | IS ↑ | | - | - | - | - | 8.1 | 8.4 | 7.7 |
| | FID ↓ | 40.1 | - | - | - | - | 41.9 | | 45.7 |
Comparison on CIFAR-10, STL-10, and ImageNet. ↑ means higher is better; ↓ means lower is better.

| Method | CIFAR-10 IS ↑ | CIFAR-10 FID ↓ | STL-10 IS ↑ | STL-10 FID ↓ | ImageNet IS ↑ | ImageNet FID ↓ |
|---|---|---|---|---|---|---|
| **GAN-based methods** | | | | | | |
| DCGAN | 6.6 | - | 7.8 | - | - | - |
| SNGAN | 7.4 | 29.3 | | 40.1 | - | 48.73 |
| CSGAN | 8.1 | 19.6 | - | - | - | - |
| LOGAN | | | - | - | - | - |
| **VAE/GAN-based methods** | | | | | | |
| VAE | 3.8 | 115.8 | - | - | - | - |
| VAE/GAN | 7.4 | 39.8 | - | - | - | - |
| NVAE | - | 50.8 | - | - | - | - |
| DC-VAE | | | 8.1 | 41.9 | - | - |
| CTRL-Binary (ours) | 8.1 | 19.6 | 8.4 | | 7.74 | |
| CTRL-Multi (ours) | 7.1 | 23.9 | 7.7 | 45.7 | 6.44 | 55.51 |
Figure A22. Reconstruction results by LDR models learned with different feature dimensions.
Figure 6. CIFAR-10 dataset. Visualization of the top-5 reconstructed samples whose features are closest to each of the top-4 principal components of the learned data representations for class 7 (‘Horse’) and class 8 (‘Ship’).
Figure A7. Reconstructed images from features close to the principal components learned for the 10 classes of CIFAR-10.
Figure 7. CelebA dataset. (a) Sampling along three principal components that seem to correspond to different visual attributes; (b) samples decoded by interpolating along the line between the features of two distinct samples.
Figure A3. Images generated by interpolating between samples from different classes.
Classification accuracy on MNIST compared to classifier-based VAE methods [42]. Most of these VAE-based methods require auxiliary classifiers to boost classification performance.
| Method | VAE | Factor VAE | Guide-VAE | DC-VAE | CTRL-Binary | CTRL-Multi |
|---|---|---|---|---|---|---|
| MNIST | 97.12% | 93.65% | 98.51% | 98.71% | 89.12% | 98.30% |
Decoder for MNIST.
| 4 × 4, stride = 1, pad = 0 deconv. BN 256 ReLU |
| 4 × 4, stride = 2, pad = 1 deconv. BN 128 ReLU |
| 4 × 4, stride = 2, pad = 1 deconv. BN 64 ReLU |
| 4 × 4, stride = 2, pad = 1 deconv. 1 Tanh |
Encoder for MNIST.
| Gray image |
| 4 × 4, stride = 2, pad = 1 conv 64 lReLU |
| 4 × 4, stride = 2, pad = 1 conv. BN 128 lReLU |
| 4 × 4, stride = 2, pad = 1 conv. BN 256 lReLU |
| 4 × 4, stride = 1, pad = 0 conv 128 |
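The spatial sizes implied by the MNIST encoder/decoder tables can be checked with standard convolution arithmetic, out = floor((in + 2·pad − kernel)/stride) + 1. The sketch below assumes the MNIST images are resized to 32 × 32, which is consistent with three 4 × 4/stride-2 convolutions followed by a 4 × 4/stride-1/pad-0 convolution ending in a 1 × 1 map:

```python
def conv_out(size, kernel, stride, pad):
    """Output spatial size of one convolution layer."""
    return (size + 2 * pad - kernel) // stride + 1

# Encoder for MNIST, per the table above (assumed 32x32 input):
size = 32
for kernel, stride, pad in [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]:
    size = conv_out(size, kernel, stride, pad)
    print(size)  # 16, 8, 4, 1
```

The decoder table is the same stack in reverse: transposed convolutions take the 1 × 1 feature back to 4 × 4, 8 × 8, 16 × 16, and 32 × 32.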
Decoder for CIFAR-10.
| dense |
| ResBlock up 256 |
| ResBlock up 256 |
| ResBlock up 256 |
| BN, ReLU, 3 × 3 conv, 3 Tanh |
Encoder for CIFAR-10.
| RGB image |
| ResBlock down 128 |
| ResBlock down 128 |
| ResBlock 128 |
| ResBlock 128 |
| ReLU |
| Global sum pooling |
| dense |
Decoder for STL-10.
| dense |
| ResBlock up 256 |
| ResBlock up 128 |
| ResBlock up 64 |
| BN, ReLU, 3 × 3 conv, 3 Tanh |
Encoder for STL-10.
| RGB image |
| ResBlock down 64 |
| ResBlock down 128 |
| ResBlock down 256 |
| ResBlock down 512 |
| ResBlock 1024 |
| ReLU |
| Global sum pooling |
| dense |
Decoder for CelebA-128, LSUN-bedroom-128, and ImageNet-128.
| dense |
| ResBlock up 1024 |
| ResBlock up 512 |
| ResBlock up 256 |
| ResBlock up 128 |
| ResBlock up 64 |
| BN, ReLU, 3 × 3 conv, 3 Tanh |
Encoder for CelebA-128, LSUN-bedroom-128, and ImageNet-128.
| RGB image |
| ResBlock down 64 |
| ResBlock down 128 |
| ResBlock down 256 |
| ResBlock down 512 |
| ResBlock down 1024 |
| ResBlock 1024 |
| ReLU |
| Global sum pooling |
| dense |
ID and corresponding category for the 10 classes of ImageNet.
| ID | Category |
|---|---|
| n02930766 | cab, hack, taxi, taxicab |
| n04596742 | wok |
| n02974003 | car wheel |
| n01491361 | tiger shark, Galeocerdo cuvieri |
| n01514859 | hen |
| n09472597 | volcano |
| n07749582 | lemon |
| n09428293 | seashore, coast, seacoast, sea-coast |
| n02504458 | African elephant, Loxodonta africana |
| n04285008 | sports car, sport car |
Three different objective functions for CTRL.

| Objective I: |
| Objective II: |
| Objective III: |
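The formulas in the objectives table above are not legible in this record. Based on the closed-loop rate-reduction formulation described in the abstract, objectives of this family take a minimax form roughly like the following sketch (θ and η denote the encoder and decoder parameters; this is a reconstruction from the described formulation, not a verbatim copy of the table):

```latex
% Two-player game between encoder f_\theta and decoder g_\eta,
% played over the rate reduction of features and re-encoded features:
Z = f_{\theta}(X), \qquad \hat{Z} = f_{\theta}\big(g_{\eta}(Z)\big),
\qquad \max_{\theta}\ \min_{\eta}\ \Delta R\big(Z,\, \hat{Z}\big).

% Multi-class variant: also expand the rates of Z and \hat{Z}
% and align the features class by class:
\max_{\theta}\ \min_{\eta}\ \ \Delta R(Z) \;+\; \Delta R(\hat{Z})
\;+\; \sum_{j=1}^{k} \Delta R\big(Z_j,\, \hat{Z}_j\big).
```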
Ablation study of the influence of spectral normalization. ↑ means higher is better. ↓ means lower is better.
| Backbone = SNGAN | CTRL-Binary (SN = True) | CTRL-Binary (SN = False) | CTRL-Multi (SN = True) | CTRL-Multi (SN = False) |
|---|---|---|---|---|
| CIFAR-10, IS ↑ | 8.1 | 6.6 | 7.1 | 5.8 |
| CIFAR-10, FID ↓ | 19.6 | 27.8 | 23.9 | 41.5 |
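The ablation above varies spectral normalization (SN), which rescales each weight matrix by an estimate of its largest singular value, typically obtained by power iteration. A minimal NumPy sketch of the idea (not the paper's implementation; `n_iters` is an illustrative parameter):

```python
import numpy as np

def spectral_normalize(w, n_iters=50):
    """Divide matrix w by a power-iteration estimate of its top singular value."""
    v = np.ones(w.shape[1]) / np.sqrt(w.shape[1])  # initial right singular vector
    for _ in range(n_iters):
        u = w @ v
        u /= np.linalg.norm(u)
        v = w.T @ u
        v /= np.linalg.norm(v)
    sigma = u @ w @ v  # estimated largest singular value
    return w / sigma

w = np.array([[3.0, 0.0], [0.0, 1.0]])   # spectral norm 3
w_sn = spectral_normalize(w)
print(np.linalg.norm(w_sn, 2))           # approx. 1.0 after normalization
```

Constraining every layer's spectral norm this way bounds the Lipschitz constant of the network, which is consistent with the large IS/FID gap the table shows between SN = True and SN = False.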
Ablation study on ImageNet of the trade-off between batch size (BS) and network width (Channel #).

| | Channel # = 1024 | Channel # = 512 | Channel # = 256 |
|---|---|---|---|
| BS = 1800 | success | success | success |
| BS = 1600 | success | success | success |
| BS = 1024 | failure | success | success |
| BS = 800 | failure | failure | success |
| BS = 400 | failure | failure | failure |