---
tags:
  - pytorch
  - vae
  - image-generation
  - cc3m
license: mit
datasets:
  - pixparse/cc3m-wds
library_name: custom
model-index:
  - name: vae-256px-32z
    results:
      - task:
          type: image-generation
        dataset:
          name: google-research-datasets/conceptual-captions
          type: image
        metrics:
          - type: FID
            value: 2.557451009750366
          - type: LPIPS
            value: 0.05679609028979091
          - type: ID-similarity
            value: 0.000406394264995487
---

# UNet-Style VAE for 256x256 Image Reconstruction

This model is a UNet-style Variational Autoencoder (VAE) trained on the [CC3M](https://huggingface.co/datasets/pixparse/cc3m-wds) dataset for high-quality image reconstruction and generation. It integrates adversarial, perceptual, and identity-preserving loss terms to improve semantic and visual fidelity.

## Architecture

- **Encoder/Decoder**: Multi-scale UNet architecture
- **Latent Space**: **32**-channel latent bottleneck using the reparameterization trick (mean `mu`, log-variance `logvar`)
- **Losses**:
  - L1 reconstruction loss
  - KL divergence with annealing
  - LPIPS perceptual loss (VGG backbone)
  - Identity loss via MoCo-v2 embeddings
  - Adversarial loss from a patch discriminator with spectral normalization

$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{recon}} + \mathcal{L}_{\text{LPIPS}} + 0.5\,\mathcal{L}_{\text{GAN}} + 0.1\,\mathcal{L}_{\text{ID}} + 10^{-6}\,\mathcal{L}_{\text{KL}}
$$
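The reparameterization step and the weighted objective above can be sketched as follows. This is a minimal, dependency-free illustration rather than the model's own code: it works on flat lists of floats, assumes the KL term is summed over latent elements (the card does not state the reduction), and treats the LPIPS, GAN, and ID losses as precomputed scalars. The helper names `reparameterize`, `gaussian_kl`, and `total_loss` are hypothetical.

```python
import math
import random

# Weights read off the total-loss equation above.
W_RECON, W_LPIPS, W_GAN, W_ID, W_KL = 1.0, 1.0, 0.5, 0.1, 1e-6


def reparameterize(mu, logvar, rng=random):
    """Sample z = mu + sigma * eps, with sigma = exp(0.5 * logvar) and eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def gaussian_kl(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent elements (assumed reduction).

    The PyTorch equivalent is -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()).
    """
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def total_loss(l_recon, l_lpips, l_gan, l_id, l_kl):
    """Weighted sum of the individual loss terms (each assumed already computed)."""
    return (W_RECON * l_recon + W_LPIPS * l_lpips + W_GAN * l_gan
            + W_ID * l_id + W_KL * l_kl)


# Toy usage with a 4-element latent (hypothetical values).
mu, logvar = [0.1, -0.2, 0.0, 0.3], [0.0, -0.5, 0.2, 0.0]
z = reparameterize(mu, logvar, rng=random.Random(0))
kl = gaussian_kl(mu, logvar)
print(len(z), kl, total_loss(0.05, 0.06, 0.7, 0.2, kl))
```

With a KL weight of 1e-6, the KL term acts as a light regularizer: it keeps the latent close to a standard normal without dominating the reconstruction, perceptual, and adversarial terms.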

## Reconstructions

| Input                      | Output                      |
| -------------------------- | --------------------------- |
| ![input](./input_grid.png) | ![output](./recon_grid.png) |

## Training Config

| Hyperparameter   | Value                              |
| ---------------- | ---------------------------------- |
| Dataset          | CC3M (850k images)                 |
| Image Resolution | 256 x 256                          |
| Batch Size       | 16                                 |
| Optimizer        | AdamW                              |
| Learning Rate    | 5e-5                               |
| Precision        | bf16 (mixed precision)             |
| Total Steps      | 210,000                            |
| GAN Start Step   | 50,000                             |
| KL Annealing     | Yes (10% of training)              |
| Augmentations    | Crop, flip, jitter, blur, rotation |

Training used a cosine learning-rate schedule, gradient clipping, and automatic mixed precision (`torch.cuda.amp`).
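The step-dependent settings in the table could be scheduled roughly as below. This sketch rests on several assumptions the card does not confirm: the KL weight ramps linearly from 0 to its final value of 1e-6 over the first 10% of steps, the GAN weight is 0 before step 50,000 and 0.5 afterwards, and the learning rate follows a plain cosine decay from 5e-5 to 0 with no warmup. The function names are hypothetical.

```python
import math

TOTAL_STEPS = 210_000
GAN_START_STEP = 50_000
KL_ANNEAL_STEPS = TOTAL_STEPS // 10  # 10% of training = 21,000 steps
BASE_LR = 5e-5
FINAL_KL_WEIGHT = 1e-6  # assumed to be the post-annealing KL weight
GAN_WEIGHT = 0.5


def kl_weight(step):
    """Linear KL warm-up from 0 to FINAL_KL_WEIGHT (the linear shape is an assumption)."""
    return FINAL_KL_WEIGHT * min(1.0, step / KL_ANNEAL_STEPS)


def gan_weight(step):
    """Adversarial term switched off until GAN_START_STEP."""
    return GAN_WEIGHT if step >= GAN_START_STEP else 0.0


def cosine_lr(step, min_lr=0.0):
    """Cosine decay from BASE_LR to min_lr over TOTAL_STEPS (no warmup assumed)."""
    progress = min(step, TOTAL_STEPS) / TOTAL_STEPS
    return min_lr + 0.5 * (BASE_LR - min_lr) * (1.0 + math.cos(math.pi * progress))


for step in (0, 21_000, 50_000, 105_000, 210_000):
    print(step, cosine_lr(step), kl_weight(step), gan_weight(step))
```

In PyTorch, `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=210_000)` traces the same curve when stepped once per iteration, `torch.nn.utils.clip_grad_norm_` applies gradient clipping, and `torch.autocast(device_type="cuda", dtype=torch.bfloat16)` runs the forward pass in bf16.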

## Note

The model class will be made available once the original repository is open-sourced (work in progress).