---
tags:
- pytorch
- vae
- image-generation
- cc3m
license: mit
datasets:
- pixparse/cc3m-wds
library_name: custom
model-index:
- name: vae-256px-32z
results:
- task:
type: image-generation
dataset:
name: google-research-datasets/conceptual-captions
type: image
metrics:
- type: FID
value: 2.557451009750366
- type: LPIPS
value: 0.05679609028979091
- type: ID-similarity
value: 0.000406394264995487
---
# UNet-Style VAE for 256x256 Image Reconstruction
This model is a UNet-style Variational Autoencoder (VAE) trained on the [CC3M](https://huggingface.co/datasets/pixparse/cc3m-wds) dataset for high-quality image reconstruction and generation. It integrates adversarial, perceptual, and identity-preserving loss terms to improve semantic and visual fidelity.
## Architecture
- **Encoder/Decoder**: Multi-scale UNet architecture
- **Latent Space**: **32**-channel latent bottleneck with reparameterization over (mu, logvar); see the sketch after the loss equation below
- **Losses**:
- L1 reconstruction loss
- KL divergence with annealing
- LPIPS perceptual loss (VGG backbone)
- Identity loss via MoCo-v2 embeddings
- Adversarial loss via Patch Discriminator w/ Spectral Norm
$$
\mathcal{L}_{total} = \mathcal{L}_{recon} + \mathcal{L}_{LPIPS} + 0.5\,\mathcal{L}_{GAN} + 0.1\,\mathcal{L}_{ID} + 10^{-6}\,\mathcal{L}_{KL}
$$
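The training code is not yet public, but the loss combination above is straightforward to express in PyTorch. The sketch below is illustrative, not the released implementation: `reparameterize` and `total_loss` are hypothetical helpers, and the LPIPS, GAN, and identity terms are assumed to be scalar tensors produced upstream by their respective networks.

```python
import torch
import torch.nn.functional as F

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z = mu + sigma * eps, eps ~ N(0, I) (the reparameterization trick)."""
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)

def total_loss(
    recon: torch.Tensor,       # decoder output, (B, 3, 256, 256)
    target: torch.Tensor,      # input images, same shape
    mu: torch.Tensor,          # latent mean, (B, 32, h, w)
    logvar: torch.Tensor,      # latent log-variance, same shape as mu
    lpips_term: torch.Tensor,  # LPIPS(recon, target) from a VGG-based LPIPS net
    gan_term: torch.Tensor,    # generator loss from the patch discriminator
    id_term: torch.Tensor,     # distance between MoCo-v2 embeddings of recon and target
    kl_weight: float = 1e-6,   # final KL weight (reached after annealing)
) -> torch.Tensor:
    recon_term = F.l1_loss(recon, target)
    # Closed-form KL between the diagonal Gaussian N(mu, sigma^2) and N(0, I).
    kl_term = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_term + lpips_term + 0.5 * gan_term + 0.1 * id_term + kl_weight * kl_term
```

The `kl_weight` default matches the $10^{-6}$ coefficient in the equation; during the annealing phase it would be scaled down from this value (see the training sketch further below).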
## Reconstructions
*(Reconstruction samples: input image on the left, VAE output on the right.)*
## Training Config
| Hyperparameter | Value |
| ---------------- | ---------------------------------- |
| Dataset | CC3M (850k images) |
| Image Resolution | 256 x 256 |
| Batch Size | 16 |
| Optimizer | AdamW |
| Learning Rate | 5e-5 |
| Precision | bf16 (mixed precision) |
| Total Steps | 210,000 |
| GAN Start Step | 50,000 |
| KL Annealing | Yes (10% of training) |
| Augmentations | Crop, flip, jitter, blur, rotation |
The model was trained with a cosine learning-rate schedule, gradient clipping, and automatic mixed precision (`torch.cuda.amp`).
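The exact training loop is unreleased; the sketch below only wires together the scaffolding implied by the table, with a throwaway linear module and random tensors standing in for the VAE and CC3M batches. The linear shape of the KL annealing schedule and the placement of the GAN start check are assumptions.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

TOTAL_STEPS = 210_000
GAN_START_STEP = 50_000
KL_ANNEAL_STEPS = TOTAL_STEPS // 10  # KL weight ramps up over the first 10% of steps

def kl_weight_at(step: int, final_weight: float = 1e-6) -> float:
    """Assumed linear annealing: 0 -> final_weight over the first 10% of training."""
    return final_weight * min(step / KL_ANNEAL_STEPS, 1.0)

device_type = "cuda" if torch.cuda.is_available() else "cpu"
vae = torch.nn.Linear(8, 8)  # stand-in for the unreleased UNet-style VAE
optimizer = torch.optim.AdamW(vae.parameters(), lr=5e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=TOTAL_STEPS)

for step in range(TOTAL_STEPS):
    batch = torch.randn(16, 8)  # stand-in for an augmented 256x256 CC3M batch
    # bf16 autocast needs no GradScaler, unlike fp16.
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        loss = vae(batch).pow(2).mean()  # stand-in for the composite loss above
        # In the real setup, the adversarial term would only be added once
        # step >= GAN_START_STEP, and the KL term weighted by kl_weight_at(step).
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)  # gradient clipping
    optimizer.step()
    scheduler.step()  # cosine decay of the 5e-5 base learning rate
```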
## Note
The model class will be available once the original training repository is open-sourced (work in progress).