---
tags:
- pytorch
- vae
- image-generation
- cc3m
license: mit
datasets:
- pixparse/cc3m-wds
library_name: custom
model-index:
- name: vae-256px-32z
  results:
  - task:
      type: image-generation
    dataset:
      name: google-research-datasets/conceptual-captions
      type: image
    metrics:
    - type: FID
      value: 2.557451009750366
    - type: LPIPS
      value: 0.05679609028979091
    - type: ID-similarity
      value: 0.000406394264995487
---
|
|
|
|
|
# UNet-Style VAE for 256x256 Image Reconstruction
|
|
|
|
|
This model is a UNet-style Variational Autoencoder (VAE) trained on the [CC3M](https://huggingface.co/datasets/pixparse/cc3m-wds) dataset for high-quality image reconstruction and generation. It integrates adversarial, perceptual, and identity-preserving loss terms to improve semantic and visual fidelity.
|
|
|
|
|
## Architecture
|
|
|
|
|
- **Encoder/Decoder**: Multi-scale UNet architecture
- **Latent Space**: **32**-channel latent bottleneck with reparameterization (mu, logvar); see the sketch after this list
- **Losses**:
  - L1 reconstruction loss
  - KL divergence with annealing
  - LPIPS perceptual loss (VGG backbone)
  - Identity loss via MoCo-v2 embeddings
  - Adversarial loss via a patch discriminator with spectral normalization
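The latent bottleneck samples via the standard VAE reparameterization trick. A minimal sketch, assuming a 32x32 spatial latent for 256px inputs (the spatial size is an assumption, not something the card specifies):

```python
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z ~ N(mu, sigma^2) while keeping gradients w.r.t. mu and logvar."""
    std = torch.exp(0.5 * logvar)  # logvar -> standard deviation
    eps = torch.randn_like(std)    # noise drawn from N(0, I)
    return mu + eps * std

# Example with a 32-channel latent map (batch 1, assumed 32x32 spatial size):
mu, logvar = torch.zeros(1, 32, 32, 32), torch.zeros(1, 32, 32, 32)
z = reparameterize(mu, logvar)  # shape: (1, 32, 32, 32)
```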
|
|
|
|
|
$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{recon}} + \mathcal{L}_{\text{LPIPS}} + 0.5\,\mathcal{L}_{\text{GAN}} + 0.1\,\mathcal{L}_{\text{ID}} + 10^{-6}\,\mathcal{L}_{\text{KL}}
$$
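A hedged sketch of how these terms might be combined with the weights above; the `lpips` package provides the VGG-backbone perceptual loss, while the exact forms of the GAN and identity terms below are assumptions, not the confirmed implementation:

```python
import torch
import torch.nn.functional as F
import lpips  # pip install lpips

lpips_vgg = lpips.LPIPS(net='vgg')  # perceptual loss with a VGG backbone

def compute_total_loss(x, x_hat, mu, logvar, d_fake_logits, emb_x, emb_x_hat,
                       kl_weight=1e-6):
    l_recon = F.l1_loss(x_hat, x)              # L1 reconstruction
    l_lpips = lpips_vgg(x_hat, x).mean()       # LPIPS expects inputs in [-1, 1]
    l_gan = F.softplus(-d_fake_logits).mean()  # non-saturating generator loss (assumed form)
    # Identity loss as cosine distance between MoCo-v2 embeddings (assumed form)
    l_id = 1.0 - F.cosine_similarity(emb_x, emb_x_hat, dim=-1).mean()
    # KL(q(z|x) || N(0, I)) in closed form
    l_kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return l_recon + l_lpips + 0.5 * l_gan + 0.1 * l_id + kl_weight * l_kl
```

Here `kl_weight` is the final annealed value of 10^-6; during the annealing phase it would be smaller.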
|
|
|
|
|
## Reconstructions
|
|
|
|
|
| Input | Output |
| -------------------------- | --------------------------- |
|  |  |
|
|
|
|
|
## Training Config
|
|
|
|
|
| Hyperparameter   | Value                              |
| ---------------- | ---------------------------------- |
| Dataset          | CC3M (850k images)                 |
| Image Resolution | 256 x 256                          |
| Batch Size       | 16                                 |
| Optimizer        | AdamW                              |
| Learning Rate    | 5e-5                               |
| Precision        | bf16 (mixed precision)             |
| Total Steps      | 210,000                            |
| GAN Start Step   | 50,000                             |
| KL Annealing     | Yes (10% of training)              |
| Augmentations    | Crop, flip, jitter, blur, rotation |
|
|
|
|
Training used a cosine learning-rate schedule with gradient clipping and automatic mixed precision (`torch.cuda.amp`).
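A minimal sketch of one such training step under these settings; `vae`, `disc`, `moco`, and `train_loader` are illustrative placeholders, and the linear KL warm-up and clipping threshold are assumptions about details the card does not specify:

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

TOTAL_STEPS = 210_000
KL_ANNEAL_STEPS = TOTAL_STEPS // 10  # "10% of training"

optimizer = AdamW(vae.parameters(), lr=5e-5)  # `vae` stands in for the real model class
scheduler = CosineAnnealingLR(optimizer, T_max=TOTAL_STEPS)

for step, images in enumerate(train_loader):  # `train_loader` yields augmented CC3M batches
    images = images.cuda()
    # Linear warm-up of the KL weight toward its final 1e-6 (annealing form assumed)
    kl_weight = 1e-6 * min(1.0, step / KL_ANNEAL_STEPS)
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):  # bf16 mixed precision
        x_hat, mu, logvar = vae(images)
        d_fake_logits = disc(x_hat)                    # patch discriminator, active from step 50,000
        emb_x, emb_x_hat = moco(images), moco(x_hat)   # frozen MoCo-v2 encoder
        loss = compute_total_loss(images, x_hat, mu, logvar,
                                  d_fake_logits, emb_x, emb_x_hat,
                                  kl_weight=kl_weight)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()  # bf16 does not require a GradScaler
    torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)  # threshold assumed
    optimizer.step()
    scheduler.step()
```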
|
|
|
|
|
## Note

The model class will be made available once the original repository becomes open source (work in progress).