shekkari21 commited on
Commit
3c45764
·
1 Parent(s): a0333da

Committing all the super-resolution files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .DS_Store +0 -0
  2. DEPLOYMENT_GUIDE.md +161 -0
  3. FUNCTION_MAPPING.md +142 -0
  4. README copy.md +235 -0
  5. SPACE_README.md +44 -0
  6. __pycache__/app.cpython-311.pyc +0 -0
  7. app.py +240 -0
  8. ldm/.DS_Store +0 -0
  9. ldm/__init__.py +2 -0
  10. ldm/__pycache__/__init__.cpython-311.pyc +0 -0
  11. ldm/__pycache__/__init__.cpython-312.pyc +0 -0
  12. ldm/__pycache__/util.cpython-311.pyc +0 -0
  13. ldm/__pycache__/util.cpython-312.pyc +0 -0
  14. ldm/__pycache__/util.cpython-38.pyc +0 -0
  15. ldm/models/__init__.py +2 -0
  16. ldm/models/__pycache__/__init__.cpython-311.pyc +0 -0
  17. ldm/models/__pycache__/__init__.cpython-312.pyc +0 -0
  18. ldm/models/__pycache__/autoencoder.cpython-311.pyc +0 -0
  19. ldm/models/__pycache__/autoencoder.cpython-312.pyc +0 -0
  20. ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  21. ldm/models/autoencoder.py +145 -0
  22. ldm/modules/.DS_Store +0 -0
  23. ldm/modules/__init__.py +2 -0
  24. ldm/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  25. ldm/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  26. ldm/modules/__pycache__/attention.cpython-311.pyc +0 -0
  27. ldm/modules/__pycache__/attention.cpython-312.pyc +0 -0
  28. ldm/modules/__pycache__/ema.cpython-311.pyc +0 -0
  29. ldm/modules/__pycache__/ema.cpython-312.pyc +0 -0
  30. ldm/modules/__pycache__/ema.cpython-38.pyc +0 -0
  31. ldm/modules/attention.py +341 -0
  32. ldm/modules/diffusionmodules/__init__.py +0 -0
  33. ldm/modules/diffusionmodules/__pycache__/__init__.cpython-311.pyc +0 -0
  34. ldm/modules/diffusionmodules/__pycache__/__init__.cpython-312.pyc +0 -0
  35. ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc +0 -0
  36. ldm/modules/diffusionmodules/__pycache__/model.cpython-311.pyc +0 -0
  37. ldm/modules/diffusionmodules/__pycache__/model.cpython-312.pyc +0 -0
  38. ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc +0 -0
  39. ldm/modules/diffusionmodules/__pycache__/util.cpython-311.pyc +0 -0
  40. ldm/modules/diffusionmodules/__pycache__/util.cpython-312.pyc +0 -0
  41. ldm/modules/diffusionmodules/model.py +860 -0
  42. ldm/modules/diffusionmodules/model_back.py +815 -0
  43. ldm/modules/diffusionmodules/openaimodel.py +788 -0
  44. ldm/modules/diffusionmodules/upscaling.py +81 -0
  45. ldm/modules/diffusionmodules/util.py +270 -0
  46. ldm/modules/distributions/__init__.py +0 -0
  47. ldm/modules/distributions/__pycache__/__init__.cpython-311.pyc +0 -0
  48. ldm/modules/distributions/__pycache__/__init__.cpython-312.pyc +0 -0
  49. ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc +0 -0
  50. ldm/modules/distributions/__pycache__/distributions.cpython-311.pyc +0 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
DEPLOYMENT_GUIDE.md ADDED
@@ -0,0 +1,161 @@
1
+ # Hugging Face Space Deployment Guide
2
+
3
+ This guide will help you deploy your ResShift Super-Resolution model to Hugging Face Spaces.
4
+
5
+ ## Prerequisites
6
+
7
+ 1. Hugging Face account (sign up at https://huggingface.co)
8
+ 2. Git installed on your machine
9
+ 3. Your trained model checkpoint
10
+
11
+ ## Step 1: Create a New Space
12
+
13
+ 1. Go to https://huggingface.co/spaces
14
+ 2. Click **"Create new Space"**
15
+ 3. Fill in the details:
16
+ - **Space name**: e.g., `resshift-super-resolution`
17
+ - **SDK**: Select **"Gradio"**
18
+ - **Hardware**: Choose **"GPU"** (recommended for faster inference)
19
+ - **Visibility**: Public or Private
20
+ 4. Click **"Create Space"**
21
+
22
+ ## Step 2: Clone the Space Repository
23
+
24
+ After creating the space, Hugging Face will provide you with a Git URL. Clone it:
25
+
26
+ ```bash
27
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
28
+ cd YOUR_SPACE_NAME
29
+ ```
30
+
31
+ ## Step 3: Copy Required Files
32
+
33
+ Copy the following files from your project to the Space repository:
34
+
35
+ ### Essential Files:
36
+ ```bash
37
+ # From your DiffusionSR directory
38
+ cp app.py YOUR_SPACE_NAME/
39
+ cp requirements.txt YOUR_SPACE_NAME/
40
+ cp SPACE_README.md YOUR_SPACE_NAME/README.md
41
+
42
+ # Copy source code
43
+ cp -r src/ YOUR_SPACE_NAME/
44
+
45
+ # Copy model checkpoint
46
+ mkdir -p YOUR_SPACE_NAME/checkpoints/ckpts
47
+ cp checkpoints/ckpts/model_3200.pth YOUR_SPACE_NAME/checkpoints/ckpts/
48
+
49
+ # Copy VQGAN weights
50
+ mkdir -p YOUR_SPACE_NAME/pretrained_weights
51
+ cp pretrained_weights/autoencoder_vq_f4.pth YOUR_SPACE_NAME/pretrained_weights/
52
+ ```
53
+
54
+ ### Important Notes:
55
+ - **Model Size**: Checkpoints can be large (200-500MB). Hugging Face Spaces supports files up to 10GB.
56
+ - **Git LFS**: For large files, you may need Git LFS:
57
+ ```bash
58
+ git lfs install
59
+ git lfs track "*.pth"
60
+ git add .gitattributes
61
+ ```
62
+
63
+ ## Step 4: Update app.py (if needed)
64
+
65
+ If your checkpoint path is different, update `app.py`:
66
+
67
+ ```python
68
+ # In app.py, line ~33, update the checkpoint path:
69
+ checkpoint_path = "checkpoints/ckpts/model_3200.pth" # Change to your checkpoint name
70
+ ```
71
+
72
+ ## Step 5: Commit and Push
73
+
74
+ ```bash
75
+ cd YOUR_SPACE_NAME
76
+ git add .
77
+ git commit -m "Initial commit: ResShift Super-Resolution app"
78
+ git push
79
+ ```
80
+
81
+ ## Step 6: Wait for Build
82
+
83
+ Hugging Face will automatically:
84
+ 1. Install dependencies from `requirements.txt`
85
+ 2. Run `app.py`
86
+ 3. Make your app available at: `https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME`
87
+
88
+ The build process usually takes 5-10 minutes.
89
+
90
+ ## Step 7: Test Your App
91
+
92
+ Once the build completes:
93
+ 1. Visit your Space URL
94
+ 2. Upload a test image
95
+ 3. Verify the super-resolution works correctly
96
+
97
+ ## Troubleshooting
98
+
99
+ ### Build Fails
100
+ - Check the **Logs** tab in your Space for error messages
101
+ - Verify all dependencies are in `requirements.txt`
102
+ - Ensure file paths are correct
103
+
104
+ ### Model Not Loading
105
+ - Check that checkpoint path in `app.py` matches your file structure
106
+ - Verify checkpoint file was uploaded correctly
107
+ - Check logs for specific error messages
108
+
109
+ ### Out of Memory
110
+ - Reduce batch size in inference
111
+ - Use CPU instead of GPU (slower but uses less memory)
112
+ - Consider using a smaller model checkpoint
113
+
114
+ ### Slow Inference
115
+ - Enable GPU in Space settings
116
+ - Reduce number of diffusion steps (modify `T` in config)
117
+ - Use AMP (automatic mixed precision); see the sketch below
118
+
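+ A minimal sketch of wrapping the model call in automatic mixed precision, assuming the `model`, `x_t_scaled`, `t`, and `lr_latent` names used in `app.py` and a CUDA device:
+
+ ```python
+ import torch
+
+ # Run the UNet forward pass under autocast; numerically sensitive ops
+ # automatically stay in float32.
+ with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
+     x0_pred = model(x_t_scaled, t, lq=lr_latent)
+ ```
+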
119
+ ## Alternative: Upload via Web Interface
120
+
121
+ If you prefer not to use Git:
122
+
123
+ 1. Go to your Space page
124
+ 2. Click **"Files and versions"** tab
125
+ 3. Click **"Add file"** → **"Upload files"**
126
+ 4. Upload all required files
127
+ 5. The Space will rebuild automatically
128
+
129
+ ## Updating Your Space
130
+
131
+ To update your Space with new changes:
132
+
133
+ ```bash
134
+ cd YOUR_SPACE_NAME
135
+ # Make your changes
136
+ git add .
137
+ git commit -m "Update: description of changes"
138
+ git push
139
+ ```
140
+
141
+ ## Sharing Your Space
142
+
143
+ Once deployed, you can:
144
+ - Share the Space URL with others
145
+ - Embed it in websites using iframe
146
+ - Use it via API (if enabled); see the sketch below
147
+
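+ If the API is enabled, a minimal sketch of calling the Space from Python with `gradio_client` (the Space id and `api_name` below are placeholders; check the Space's "Use via API" panel for the exact endpoint and input format):
+
+ ```python
+ from gradio_client import Client
+
+ # Placeholder Space id and endpoint name; newer client versions may require
+ # gradio_client.handle_file(...) for image inputs.
+ client = Client("YOUR_USERNAME/YOUR_SPACE_NAME")
+ result = client.predict("path/to/low_res.png", api_name="/predict")
+ print(result)  # typically a local path to the super-resolved image
+ ```
+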
148
+ ## Next Steps
149
+
150
+ 1. **Add Examples**: Add example images to showcase your model
151
+ 2. **Improve UI**: Customize the Gradio interface
152
+ 3. **Add Documentation**: Update README with more details
153
+ 4. **Monitor Usage**: Check Space metrics to see usage
154
+
155
+ ## Support
156
+
157
+ If you encounter issues:
158
+ - Check Hugging Face Spaces documentation: https://huggingface.co/docs/hub/spaces
159
+ - Review Space logs for error messages
160
+ - Ask for help in Hugging Face forums
161
+
FUNCTION_MAPPING.md ADDED
@@ -0,0 +1,142 @@
1
+ # Function Mapping: Original Implementation → Our Implementation
2
+
3
+ This document maps functions from the original ResShift implementation to our corresponding functions and explains what each function does.
4
+
5
+ ## Core Diffusion Functions
6
+
7
+ ### Forward Process (Noise Addition)
8
+
9
+ | Original Function | Our Implementation | Description |
10
+ |------------------|-------------------|-------------|
11
+ | `GaussianDiffusion.q_sample(x_start, y, t, noise=None)` | `trainer.py: training_step()` (line 434) | **Forward diffusion process**: Adds noise to HR image according to ResShift schedule. Formula: `x_t = x_0 + η_t * (y - x_0) + κ * √η_t * ε` |
12
+ | `GaussianDiffusion.q_mean_variance(x_start, y, t)` | Not directly used | Computes mean and variance of forward process `q(x_t \| x_0, y)` |
13
+ | `GaussianDiffusion.q_posterior_mean_variance(x_start, x_t, t)` | `trainer.py: validation()` (line 845) | Computes posterior mean and variance `q(x_{t-1} \| x_t, x_0)` for backward sampling. Used in equation (7) |
14
+
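+ The forward formula in the table above can be sketched in a few lines of PyTorch (illustrative only; `eta_schedule` and `kappa` are assumed to hold the η values and κ used by the project):
+
+ ```python
+ import torch
+
+ def q_sample(x0, y, t, eta_schedule, kappa):
+     """Forward step: shift the HR latent x0 toward the LR latent y and add noise.
+
+     x0, y: (B, C, H, W) latents; t: (B,) long timesteps;
+     eta_schedule: (T,) tensor of eta values; kappa: scalar noise level.
+     """
+     eta_t = eta_schedule[t].view(-1, 1, 1, 1)   # per-sample eta, broadcastable
+     noise = torch.randn_like(x0)
+     return x0 + eta_t * (y - x0) + kappa * torch.sqrt(eta_t) * noise
+ ```
+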
15
+ ### Backward Process (Sampling)
16
+
17
+ | Original Function | Our Implementation | Description |
18
+ |------------------|-------------------|-------------|
19
+ | `GaussianDiffusion.p_mean_variance(model, x_t, y, t, ...)` | `trainer.py: validation()` (lines 844-848)<br>`inference.py: inference_single_image()` (lines 251-255)<br>`app.py: super_resolve()` (lines 154-158) | **Computes backward step parameters**: Calculates mean `μ_θ` and variance `Σ_θ` for equation (7). Mean: `μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t)`<br>Variance: `Σ_θ = κ² * (η_{t-1}/η_t) * α_t` |
20
+ | `GaussianDiffusion.p_sample(model, x, y, t, ...)` | `trainer.py: validation()` (lines 850-853)<br>`inference.py: inference_single_image()` (lines 257-260)<br>`app.py: super_resolve()` (lines 160-163) | **Single backward sampling step**: Samples `x_{t-1}` from `p(x_{t-1} \| x_t, y)` using equation (7). Formula: `x_{t-1} = μ_θ + √Σ_θ * ε` (with nonzero_mask for t > 0) |
21
+ | `GaussianDiffusion.p_sample_loop(y, model, ...)` | `trainer.py: validation()` (lines 822-856)<br>`inference.py: inference_single_image()` (lines 229-263)<br>`app.py: super_resolve()` (lines 133-165) | **Full sampling loop**: Iterates from t = T-1 down to t = 0, calling `p_sample` at each step. Returns final denoised sample |
22
+ | `GaussianDiffusion.p_sample_loop_progressive(y, model, ...)` | Same as above (we don't use progressive) | Same as `p_sample_loop` but yields intermediate samples at each timestep |
23
+ | `GaussianDiffusion.prior_sample(y, noise=None)` | `trainer.py: validation()` (lines 813-815)<br>`inference.py: inference_single_image()` (lines 223-226)<br>`app.py: super_resolve()` (lines 128-130) | **Initializes x_T for sampling**: Generates starting point from prior distribution. Formula: `x_T = y + κ * √η_T * noise` (starts from LR + noise) |
24
+
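+ A compact sketch of `prior_sample` followed by the equation (7) update described in the table above, mirroring the loops in `validation()`, `inference_single_image()`, and `super_resolve()` (input scaling omitted; names are illustrative):
+
+ ```python
+ import torch
+
+ @torch.no_grad()
+ def p_sample_loop(model, y, eta, kappa, T):
+     """y: LR latent (B, C, H, W); eta: (T,) schedule; returns the denoised latent."""
+     # prior_sample: x_T = y + kappa * sqrt(eta_T) * noise
+     x_t = y + kappa * torch.sqrt(eta[T - 1]) * torch.randn_like(y)
+     for s in range(T - 1, -1, -1):
+         t = torch.full((y.shape[0],), s, device=y.device, dtype=torch.long)
+         x0_pred = model(x_t, t, lq=y)            # f_theta(x_t, y_0, t)
+         if s == 0:
+             return x0_pred
+         alpha = eta[s] - eta[s - 1]              # alpha_t = eta_t - eta_{t-1}
+         mean = (eta[s - 1] / eta[s]) * x_t + (alpha / eta[s]) * x0_pred
+         var = kappa**2 * (eta[s - 1] / eta[s]) * alpha
+         x_t = mean + torch.sqrt(var) * torch.randn_like(x_t)
+ ```
+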
25
+ ### Input Scaling
26
+
27
+ | Original Function | Our Implementation | Description |
28
+ |------------------|-------------------|-------------|
29
+ | `GaussianDiffusion._scale_input(inputs, t)` | `trainer.py: _scale_input()` (lines 367-389)<br>`inference.py: _scale_input()` (lines 151-173)<br>`app.py: _scale_input()` (lines 50-72) | **Normalizes input variance**: Scales input `x_t` to normalize variance across timesteps for training stability. Formula: `x_scaled = x_t / std` where `std = √(η_t * κ² + 1)` for latent space |
30
+
31
+ ### Training Loss
32
+
33
+ | Original Function | Our Implementation | Description |
34
+ |------------------|-------------------|-------------|
35
+ | `GaussianDiffusion.training_losses(model, x_start, y, t, ...)` | `trainer.py: training_step()` (lines 392-493) | **Computes training loss**: Encodes HR/LR to latent, adds noise via `q_sample`, predicts x0, computes MSE loss. Our implementation: predicts x0 directly (ModelMeanType.START_X) |
36
+
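+ A minimal sketch of the loss described above (noise the HR latent via `q_sample`, predict x0, take the MSE); `autoencoder`, `eta_schedule`, and `kappa` are assumed names, and input scaling is omitted:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def training_loss(model, autoencoder, hr, lr_up, eta_schedule, kappa, T):
+     """hr, lr_up: pixel-space batches (LR pre-upsampled to the HR size)."""
+     with torch.no_grad():
+         x0 = autoencoder.encode(hr)              # HR latent (target)
+         y = autoencoder.encode(lr_up)            # LR latent (conditioning)
+     t = torch.randint(0, T, (x0.shape[0],), device=x0.device)
+     eta_t = eta_schedule[t].view(-1, 1, 1, 1)
+     noise = torch.randn_like(x0)
+     x_t = x0 + eta_t * (y - x0) + kappa * torch.sqrt(eta_t) * noise   # q_sample
+     x0_pred = model(x_t, t, lq=y)                # predict x0 directly (START_X)
+     return F.mse_loss(x0_pred, x0)
+ ```
+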
37
+ ### Autoencoder Functions
38
+
39
+ | Original Function | Our Implementation | Description |
40
+ |------------------|-------------------|-------------|
41
+ | `GaussianDiffusion.encode_first_stage(y, first_stage_model, up_sample=False)` | `autoencoder.py: VQGANWrapper.encode()`<br>`data.py: SRDatasetOnTheFly.__getitem__()` (line 160) | **Encodes image to latent**: Encodes pixel-space image to latent space using VQGAN. If `up_sample=True`, upsamples LR before encoding |
42
+ | `GaussianDiffusion.decode_first_stage(z_sample, first_stage_model, ...)` | `autoencoder.py: VQGANWrapper.decode()`<br>`trainer.py: validation()` (line 861)<br>`inference.py: inference_single_image()` (line 269) | **Decodes latent to image**: Decodes latent-space tensor back to pixel space using VQGAN decoder |
43
+
44
+ ## Trainer Functions
45
+
46
+ ### Original Trainer (`original_trainer.py`)
47
+
48
+ | Original Function | Our Implementation | Description |
49
+ |------------------|-------------------|-------------|
50
+ | `Trainer.__init__(configs, ...)` | `trainer.py: Trainer.__init__()` (lines 36-75) | **Initializes trainer**: Sets up device, checkpoint directory, noise schedule, loss function, WandB |
51
+ | `Trainer.setup_seed(seed=None)` | `trainer.py: setup_seed()` (lines 76-95) | **Sets random seeds**: Ensures reproducibility by setting seeds for random, numpy, torch, and CUDA |
52
+ | `Trainer.build_model()` | `trainer.py: build_model()` (lines 212-267) | **Builds model and autoencoder**: Initializes FullUNET model, optionally compiles it, loads VQGAN autoencoder, initializes LPIPS metric |
53
+ | `Trainer.setup_optimization()` | `trainer.py: setup_optimization()` (lines 141-211) | **Sets up optimizer and scheduler**: Initializes AdamW optimizer, AMP scaler (if enabled), CosineAnnealingLR scheduler (if enabled) |
54
+ | `Trainer.build_dataloader()` | `trainer.py: build_dataloader()` (lines 283-344) | **Builds data loaders**: Creates train and validation DataLoaders, wraps train loader to cycle infinitely |
55
+ | `Trainer.training_losses()` | `trainer.py: training_step()` (lines 392-493) | **Training step**: Implements micro-batching, adds noise, forward pass, loss computation, backward step, gradient clipping, optimizer step |
56
+ | `Trainer.validation(phase='val')` | `trainer.py: validation()` (lines 748-963) | **Validation loop**: Runs full diffusion sampling loop, decodes results, computes PSNR/SSIM/LPIPS metrics, logs to WandB |
57
+ | `Trainer.adjust_lr(current_iters=None)` | `trainer.py: adjust_lr()` (lines 495-519) | **Learning rate scheduling**: Implements linear warmup, then cosine annealing (if enabled) |
58
+ | `Trainer.save_ckpt()` | `trainer.py: save_ckpt()` (lines 520-562) | **Saves checkpoint**: Saves model, optimizer, AMP scaler, LR scheduler states, current iteration |
59
+ | `Trainer.resume_from_ckpt(ckpt_path)` | `trainer.py: resume_from_ckpt()` (lines 563-670) | **Resumes from checkpoint**: Loads model, optimizer, scaler, scheduler states, restores iteration count and LR schedule |
60
+ | `Trainer.log_step_train(...)` | `trainer.py: log_step_train()` (lines 671-747) | **Logs training metrics**: Logs loss, learning rate, images (HR, LR, noisy input, prediction) to WandB at specified frequencies |
61
+ | `Trainer.reload_ema_model()` | `trainer.py: validation()` (line 754) | **Loads EMA model**: Uses EMA model for validation if `use_ema_val=True` |
62
+
63
+ ## Inference Functions
64
+
65
+ ### Original Sampler (`original_sampler.py`)
66
+
67
+ | Original Function | Our Implementation | Description |
68
+ |------------------|-------------------|-------------|
69
+ | `ResShiftSampler.sample_func(y0, noise_repeat=False, mask=False)` | `inference.py: inference_single_image()` (lines 175-274)<br>`app.py: super_resolve()` (lines 93-165) | **Single image inference**: Encodes LR to latent, runs full diffusion sampling loop, decodes to pixel space |
70
+ | `ResShiftSampler.inference(in_path, out_path, ...)` | `inference.py: main()` (lines 385-509) | **Batch inference**: Processes single image or directory of images, handles chopping for large images |
71
+ | `ResShiftSampler.build_model()` | `inference.py: load_model()` (lines 322-384) | **Loads model and autoencoder**: Loads checkpoint, handles compiled model checkpoints (strips `_orig_mod.` prefix), loads EMA if specified |
72
+
73
+ ## Key Differences and Notes
74
+
75
+ ### 1. **Model Prediction Type**
76
+ - **Original**: Supports multiple prediction types (START_X, EPSILON, RESIDUAL, EPSILON_SCALE)
77
+ - **Our Implementation**: Only uses START_X (predicts x0 directly), matching the ResShift paper
78
+
79
+ ### 2. **Sampling Initialization**
80
+ - **Original**: Uses `prior_sample(y, noise)` → `x_T = y + κ * √η_T * noise` (starts from LR + noise)
81
+ - **Our Implementation**: Same approach in inference and validation (fixed in validation to match original)
82
+
83
+ ### 3. **Backward Equation (Equation 7)**
84
+ - **Original**: Uses `p_mean_variance()` → computes `μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred` and `Σ_θ = κ² * (η_{t-1}/η_t) * α_t`
85
+ - **Our Implementation**: Same equation, implemented directly in sampling loops
86
+
87
+ ### 4. **Input Scaling**
88
+ - **Original**: `_scale_input()` normalizes variance: `x_scaled = x_t / √(η_t * κ² + 1)` for latent space
89
+ - **Our Implementation**: Same formula, applied in training and inference
90
+
91
+ ### 5. **Training Loss**
92
+ - **Original**: Supports weighted MSE based on posterior variance
93
+ - **Our Implementation**: Uses simple MSE loss (predicts x0, compares with HR latent)
94
+
95
+ ### 6. **Validation**
96
+ - **Original**: Uses `p_sample_loop_progressive()` with EMA model, computes PSNR/LPIPS
97
+ - **Our Implementation**: Same approach, also computes SSIM, uses EMA if `use_ema_val=True`
98
+
99
+ ### 7. **Checkpoint Handling**
100
+ - **Original**: Standard checkpoint loading
101
+ - **Our Implementation**: Handles compiled model checkpoints (strips `_orig_mod.` prefix) for `torch.compile()` compatibility
102
+
103
+ ## Function Call Flow
104
+
105
+ ### Training Flow
106
+ ```
107
+ train.py: train()
108
+ → Trainer.__init__()
109
+ → Trainer.build_model()
110
+ → Trainer.setup_optimization()
111
+ → Trainer.build_dataloader()
112
+ → Loop:
113
+ → Trainer.training_step() # q_sample + forward + loss + backward
114
+ → Trainer.adjust_lr()
115
+ → Trainer.validation() # p_sample_loop
116
+ → Trainer.log_step_train()
117
+ → Trainer.save_ckpt()
118
+ ```
119
+
120
+ ### Inference Flow
121
+ ```
122
+ inference.py: main()
123
+ → load_model() # Load checkpoint
124
+ → get_vqgan() # Load autoencoder
125
+ → inference_single_image()
126
+ → autoencoder.encode() # LR → latent
127
+ → prior_sample() # Initialize x_T
128
+ → Loop: p_sample() # Denoise T-1 → 0
129
+ → autoencoder.decode() # Latent → pixel
130
+ ```
131
+
132
+ ## Summary
133
+
134
+ Our implementation closely follows the original ResShift codebase, with the following key mappings:
135
+ - **Forward process**: `q_sample` → `training_step()` noise addition
136
+ - **Backward process**: `p_sample` → sampling loops in `validation()` and `inference_single_image()`
137
+ - **Training**: `training_losses` → `training_step()`
138
+ - **Autoencoder**: `encode_first_stage`/`decode_first_stage` → `VQGANWrapper.encode()`/`decode()`
139
+ - **Input scaling**: `_scale_input` → `_scale_input()` (same name, same logic)
140
+
141
+ All core diffusion equations (forward process, backward equation 7, input scaling) match the original implementation.
142
+
README copy.md ADDED
@@ -0,0 +1,235 @@
1
+ # DiffusionSR
2
+
3
+ A **from-scratch implementation** of the [ResShift](https://arxiv.org/abs/2307.12348) paper: an efficient diffusion-based super-resolution model that uses a U-Net architecture with Swin Transformer blocks to enhance low-resolution images. This implementation combines the power of diffusion models with transformer-based attention mechanisms for high-quality image super-resolution.
4
+
5
+ ## Overview
6
+
7
+ This project is a complete from-scratch implementation of ResShift, a diffusion model for single image super-resolution (SISR) that efficiently reduces the number of diffusion steps required by shifting the residual between high-resolution and low-resolution images. The model architecture consists of:
8
+
9
+ - **Encoder**: 4-stage encoder with residual blocks and time embeddings
10
+ - **Bottleneck**: Swin Transformer blocks for global feature modeling
11
+ - **Decoder**: 4-stage decoder with skip connections from the encoder
12
+ - **Noise Schedule**: ResShift schedule (15 timesteps) for the diffusion process
13
+
14
+ ## Features
15
+
16
+ - **ResShift Implementation**: Complete from-scratch implementation of the ResShift paper
17
+ - **Efficient Diffusion**: Residual shifting mechanism reduces required diffusion steps
18
+ - **U-Net Architecture**: Encoder-decoder structure with skip connections
19
+ - **Swin Transformer**: Window-based attention mechanism in the bottleneck
20
+ - **Time Conditioning**: Sinusoidal time embeddings for diffusion timesteps
21
+ - **DIV2K Dataset**: Trained on DIV2K high-quality image dataset
22
+ - **Comprehensive Evaluation**: Metrics include PSNR, SSIM, and LPIPS
23
+
24
+ ## Requirements
25
+
26
+ - Python >= 3.11
27
+ - PyTorch >= 2.9.1
28
+ - [uv](https://github.com/astral-sh/uv) (Python package manager)
29
+
30
+ ## Installation
31
+
32
+ ### 1. Clone the Repository
33
+
34
+ ```bash
35
+ git clone <repository-url>
36
+ cd DiffusionSR
37
+ ```
38
+
39
+ ### 2. Install uv (if not already installed)
40
+
41
+ ```bash
42
+ # On macOS and Linux
43
+ curl -LsSf https://astral.sh/uv/install.sh | sh
44
+
45
+ # Or using pip
46
+ pip install uv
47
+ ```
48
+
49
+ ### 3. Create Virtual Environment and Install Dependencies
50
+
51
+ ```bash
52
+ # Create virtual environment and install dependencies
53
+ uv venv
54
+
55
+ # Activate the virtual environment
56
+ # On macOS/Linux:
57
+ source .venv/bin/activate
58
+
59
+ # On Windows:
60
+ # .venv\Scripts\activate
61
+
62
+ # Install project dependencies
63
+ uv pip install -e .
64
+ ```
65
+
66
+ Alternatively, you can use uv's sync command:
67
+
68
+ ```bash
69
+ uv sync
70
+ ```
71
+
72
+ ## Dataset Setup
73
+
74
+ The model expects the DIV2K dataset in the following structure:
75
+
76
+ ```
77
+ data/
78
+ ├── DIV2K_train_HR/ # High-resolution training images
79
+ └── DIV2K_train_LR_bicubic/
80
+ └── X4/ # Low-resolution images (4x downsampled)
81
+ ```
82
+
83
+ ### Download DIV2K Dataset
84
+
85
+ 1. Download the DIV2K dataset from the [official website](https://data.vision.ee.ethz.ch/cvl/DIV2K/)
86
+ 2. Extract the files to the `data/` directory
87
+ 3. Ensure the directory structure matches the above
88
+
89
+ **Note**: Update the paths in `src/data.py` (lines 75-76) to match your dataset location:
90
+
91
+ ```python
92
+ train_dataset = SRDataset(
93
+ dir_HR = 'path/to/DIV2K_train_HR',
94
+ dir_LR = 'path/to/DIV2K_train_LR_bicubic/X4',
95
+ scale=4,
96
+ patch_size=256
97
+ )
98
+ ```
99
+
100
+ ## Usage
101
+
102
+ ### Training
103
+
104
+ To train the model, run:
105
+
106
+ ```bash
107
+ python src/train.py
108
+ ```
109
+
110
+ The training script will:
111
+ - Load the dataset using the `SRDataset` class
112
+ - Initialize the `FullUNET` model
113
+ - Train using the ResShift noise schedule
114
+ - Save training progress and loss values
115
+
116
+ ### Training Configuration
117
+
118
+ Current training parameters (in `src/train.py`):
119
+ - **Batch size**: 4
120
+ - **Learning rate**: 1e-4
121
+ - **Optimizer**: Adam (betas: 0.9, 0.999)
122
+ - **Loss function**: MSE Loss
123
+ - **Gradient clipping**: 1.0
124
+ - **Training steps**: 150
125
+ - **Scale factor**: 4x
126
+ - **Patch size**: 256x256
127
+
128
+ You can modify these parameters directly in `src/train.py` to suit your needs.
129
+
130
+ ### Evaluation
131
+
132
+ The model performance is evaluated using the following metrics (a usage sketch follows the list):
133
+
134
+ - **PSNR (Peak Signal-to-Noise Ratio)**: Measures the ratio between the maximum possible power of a signal and the power of corrupting noise. Higher PSNR values indicate better image quality reconstruction.
135
+
136
+ - **SSIM (Structural Similarity Index Measure)**: Assesses the similarity between two images based on luminance, contrast, and structure. SSIM values range from -1 to 1, with higher values (closer to 1) indicating greater similarity to the ground truth.
137
+
138
+ - **LPIPS (Learned Perceptual Image Patch Similarity)**: Evaluates perceptual similarity between images using deep network features. Lower LPIPS values indicate images that are more perceptually similar to the reference image.
139
+
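+ A minimal usage sketch for these metrics, assuming `sr` and `hr` are HxWx3 float arrays in [0, 1] and the `scikit-image` and `lpips` packages are installed:
+
+ ```python
+ import numpy as np
+ import torch
+ import lpips
+ from skimage.metrics import peak_signal_noise_ratio, structural_similarity
+
+ def evaluate(sr: np.ndarray, hr: np.ndarray) -> dict:
+     psnr = peak_signal_noise_ratio(hr, sr, data_range=1.0)
+     ssim = structural_similarity(hr, sr, channel_axis=-1, data_range=1.0)
+     # LPIPS expects (N, 3, H, W) tensors scaled to [-1, 1]
+     to_t = lambda a: torch.from_numpy(a).permute(2, 0, 1).unsqueeze(0).float() * 2 - 1
+     lpips_fn = lpips.LPIPS(net='alex')
+     return {"psnr": psnr, "ssim": ssim, "lpips": lpips_fn(to_t(sr), to_t(hr)).item()}
+ ```
+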
140
+ To run evaluation (once implemented), use:
141
+
142
+ ```bash
143
+ python src/test.py
144
+ ```
145
+
146
+ ## Project Structure
147
+
148
+ ```
149
+ DiffusionSR/
150
+ ├── data/ # Dataset directory (not tracked in git)
151
+ │ ├── DIV2K_train_HR/
152
+ │ └── DIV2K_train_LR_bicubic/
153
+ ├── src/
154
+ │ ├── config.py # Configuration file
155
+ │ ├── data.py # Dataset class and data loading
156
+ │ ├── model.py # U-Net model architecture
157
+ │ ├── noiseControl.py # ResShift noise schedule
158
+ │ ├── train.py # Training script
159
+ │ └── test.py # Testing script (to be implemented)
160
+ ├── pyproject.toml # Project dependencies and metadata
161
+ ├── uv.lock # Locked dependency versions
162
+ └── README.md # This file
163
+ ```
164
+
165
+ ## Model Architecture
166
+
167
+ ### Encoder
168
+ - **Initial Conv**: 3 → 64 channels
169
+ - **Stage 1**: 64 → 128 channels, 256×256 → 128×128
170
+ - **Stage 2**: 128 → 256 channels, 128×128 → 64×64
171
+ - **Stage 3**: 256 → 512 channels, 64×64 → 32×32
172
+ - **Stage 4**: 512 channels (no downsampling)
173
+
174
+ ### Bottleneck
175
+ - Residual blocks with Swin Transformer blocks
176
+ - Window size: 7×7
177
+ - Shifted window attention for global context
178
+
179
+ ### Decoder
180
+ - **Stage 1**: 512 → 256 channels, 32×32 → 64×64
181
+ - **Stage 2**: 256 → 128 channels, 64×64 → 128×128
182
+ - **Stage 3**: 128 → 64 channels, 128×128 → 256×256
183
+ - **Stage 4**: 64 → 64 channels
184
+ - **Final Conv**: 64 → 3 channels (RGB output)
185
+
186
+ ## Key Components
187
+
188
+ ### ResShift Noise Schedule
189
+ The model implements the ResShift noise schedule as described in the original paper, defined in `src/noiseControl.py` (an illustrative sketch follows the list):
190
+ - 15 timesteps (0-14)
191
+ - Parameters: `eta1=0.001`, `etaT=0.999`, `p=0.8`
192
+ - Efficiently shifts the residual between HR and LR images during the diffusion process
193
+
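+ The exact schedule is defined in `src/noiseControl.py`; purely as an illustration of the parameters above, a monotone schedule shaped by `p` could look like the following sketch (an assumption, not necessarily the project's formula):
+
+ ```python
+ import torch
+
+ def resshift_schedule_sketch(T=15, eta1=0.001, etaT=0.999, p=0.8):
+     """Illustrative only: eta_1..eta_T increasing, with step spacing shaped by p."""
+     s = torch.linspace(0, 1, T)                          # normalized timestep
+     sqrt_eta = eta1**0.5 + (s**p) * (etaT**0.5 - eta1**0.5)
+     return sqrt_eta**2                                   # eta_t, shape (T,)
+ ```
+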
194
+ ### Time Embeddings
195
+ Sinusoidal embeddings are used to condition the model on diffusion timesteps, similar to positional encodings in transformers.
196
+
197
+ ### Data Augmentation
198
+ The dataset applies the following augmentations (a sketch follows the list):
199
+ - Random cropping (aligned between HR and LR)
200
+ - Random horizontal/vertical flips
201
+ - Random 180° rotation
202
+
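+ A minimal sketch of an aligned crop and flip for an HR/LR pair (illustrative only; the project's pipeline lives in `src/data.py`):
+
+ ```python
+ import random
+ from PIL import Image
+
+ def paired_crop_flip(hr: Image.Image, lr: Image.Image, patch=256, scale=4):
+     """Crop the same region from HR and LR (LR coordinates = HR coordinates / scale)."""
+     lp = patch // scale
+     x, y = random.randint(0, lr.width - lp), random.randint(0, lr.height - lp)
+     lr = lr.crop((x, y, x + lp, y + lp))
+     hr = hr.crop((x * scale, y * scale, x * scale + patch, y * scale + patch))
+     if random.random() < 0.5:                            # aligned horizontal flip
+         hr, lr = hr.transpose(Image.FLIP_LEFT_RIGHT), lr.transpose(Image.FLIP_LEFT_RIGHT)
+     if random.random() < 0.5:                            # aligned 180° rotation
+         hr, lr = hr.rotate(180), lr.rotate(180)
+     return hr, lr
+ ```
+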
203
+ ## Development
204
+
205
+ ### Adding New Features
206
+
207
+ 1. Model modifications: Edit `src/model.py`
208
+ 2. Training changes: Modify `src/train.py`
209
+ 3. Data pipeline: Update `src/data.py`
210
+ 4. Configuration: Add settings to `src/config.py`
211
+
212
+ ## License
213
+
214
+ [Add your license here]
215
+
216
+ ## Citation
217
+
218
+ If you use this code in your research, please cite the original ResShift paper:
219
+
220
+ ```bibtex
221
+ @article{yue2023resshift,
222
+ title={ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting},
223
+ author={Yue, Zongsheng and Wang, Jianyi and Loy, Chen Change},
224
+ journal={arXiv preprint arXiv:2307.12348},
225
+ year={2023}
226
+ }
227
+ ```
228
+
229
+ ## Acknowledgments
230
+
231
+ - **ResShift Authors**: Zongsheng Yue, Jianyi Wang, and Chen Change Loy for their foundational work on efficient diffusion-based super-resolution
232
+ - DIV2K dataset providers
233
+ - PyTorch community
234
+ - Swin Transformer architecture inspiration
235
+
SPACE_README.md ADDED
@@ -0,0 +1,44 @@
1
+ ---
2
+ title: ResShift Super-Resolution
3
+ emoji: 🖼️
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.0.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # ResShift Super-Resolution
14
+
15
+ Super-resolution using the ResShift diffusion model. Upload a low-resolution image to get an enhanced, super-resolved version.
16
+
17
+ ## Features
18
+
19
+ - 4x super-resolution using diffusion model
20
+ - Works in latent space for efficient processing
21
+ - Full diffusion sampling loop (15 steps)
22
+ - Real-time inference with Gradio interface
23
+
24
+ ## Usage
25
+
26
+ 1. Upload a low-resolution image
27
+ 2. Click "Super-Resolve" or wait for automatic processing
28
+ 3. Download the super-resolved output
29
+
30
+ ## Model
31
+
32
+ The model is trained on the DIV2K dataset and uses a VQGAN for latent-space encoding and decoding.
33
+
34
+ ## Technical Details
35
+
36
+ - **Architecture**: U-Net with Swin Transformer blocks
37
+ - **Latent Space**: 64x64 (encoded from 256x256 pixel space)
38
+ - **Diffusion Steps**: 15 timesteps
39
+ - **Scale Factor**: 4x
40
+
41
+ ## Citation
42
+
43
+ If you use this model, please cite the ResShift paper.
44
+
__pycache__/app.cpython-311.pyc ADDED
Binary file (9.6 kB).
 
app.py ADDED
@@ -0,0 +1,240 @@
1
+ """
2
+ Gradio app for ResShift Super-Resolution
3
+ Hosted on Hugging Face Spaces
4
+ """
5
+ import gradio as gr
6
+ import torch
7
+ from PIL import Image
8
+ import torchvision.transforms.functional as TF
9
+ from pathlib import Path
10
+ import sys
11
+
12
+ # Add src to path
13
+ sys.path.insert(0, str(Path(__file__).parent / "src"))
14
+
15
+ from model import FullUNET
16
+ from autoencoder import get_vqgan
17
+ from noiseControl import resshift_schedule
18
+ from config import device, T, k, normalize_input, latent_flag, gt_size
19
+
20
+ # Global variables for loaded models
21
+ model = None
22
+ autoencoder = None
23
+ eta_schedule = None
24
+
25
+
26
+ def load_models():
27
+ """Load models on startup."""
28
+ global model, autoencoder, eta_schedule
29
+
30
+ print("Loading models...")
31
+
32
+ # Load model checkpoint
33
+ checkpoint_path = "checkpoints/ckpts/model_3200.pth" # Update with your checkpoint path
34
+ if not Path(checkpoint_path).exists():
35
+ # Try to find any checkpoint
36
+ ckpt_dir = Path("checkpoints/ckpts")
37
+ if ckpt_dir.exists():
38
+ checkpoints = list(ckpt_dir.glob("model_*.pth"))
39
+ if checkpoints:
40
+ checkpoint_path = str(checkpoints[-1]) # Use latest
41
+ print(f"Using checkpoint: {checkpoint_path}")
42
+ else:
43
+ raise FileNotFoundError("No model checkpoint found!")
44
+ else:
45
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
46
+
47
+ model = FullUNET()
48
+ model = model.to(device)
49
+
50
+ ckpt = torch.load(checkpoint_path, map_location=device)
51
+ if 'state_dict' in ckpt:
52
+ state_dict = ckpt['state_dict']
53
+ else:
54
+ state_dict = ckpt
55
+
56
+ # Handle compiled model checkpoints
57
+ if any(key.startswith('_orig_mod.') for key in state_dict.keys()):
58
+ new_state_dict = {}
59
+ for key, val in state_dict.items():
60
+ if key.startswith('_orig_mod.'):
61
+ new_state_dict[key[10:]] = val
62
+ else:
63
+ new_state_dict[key] = val
64
+ state_dict = new_state_dict
65
+
66
+ model.load_state_dict(state_dict)
67
+ model.eval()
68
+ print("✓ Model loaded")
69
+
70
+ # Load VQGAN autoencoder
71
+ autoencoder = get_vqgan()
72
+ print("✓ VQGAN autoencoder loaded")
73
+
74
+ # Initialize noise schedule
75
+ eta_schedule = resshift_schedule().to(device)
76
+ eta_schedule = eta_schedule[:, None, None, None]
77
+ print("✓ Noise schedule initialized")
78
+
79
+ return "Models loaded successfully!"
80
+
81
+
82
+ def _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag):
83
+ """Scale input based on timestep."""
84
+ if normalize_input and latent_flag:
85
+ eta_t = eta_schedule[t]
86
+ std = torch.sqrt(eta_t * k**2 + 1)
87
+ x_t_scaled = x_t / std
88
+ else:
89
+ x_t_scaled = x_t
90
+ return x_t_scaled
91
+
92
+
93
+ def super_resolve(input_image):
94
+ """
95
+ Perform super-resolution on input image.
96
+
97
+ Args:
98
+ input_image: PIL Image or numpy array
99
+
100
+ Returns:
101
+ PIL Image of super-resolved output
102
+ """
103
+ if input_image is None:
104
+ return None
105
+
106
+ if model is None or autoencoder is None:
107
+ return None
108
+
109
+ try:
110
+ # Convert to PIL Image if needed
111
+ if isinstance(input_image, Image.Image):
112
+ img = input_image
113
+ else:
114
+ img = Image.fromarray(input_image)
115
+
116
+ # Resize to target size (256x256)
117
+ img = img.resize((gt_size, gt_size), Image.BICUBIC)
118
+
119
+ # Convert to tensor
120
+ img_tensor = TF.to_tensor(img).unsqueeze(0).to(device) # (1, 3, 256, 256)
121
+
122
+ # Run inference
123
+ with torch.no_grad():
124
+ # Encode to latent space
125
+ lr_latent = autoencoder.encode(img_tensor) # (1, 3, 64, 64)
126
+
127
+ # Initialize x_t at maximum timestep
128
+ epsilon_init = torch.randn_like(lr_latent)
129
+ eta_max = eta_schedule[T - 1]
130
+ x_t = lr_latent + k * torch.sqrt(eta_max) * epsilon_init
131
+
132
+ # Full diffusion sampling loop
133
+ for t_step in range(T - 1, -1, -1):
134
+ t = torch.full((lr_latent.shape[0],), t_step, device=device, dtype=torch.long)
135
+
136
+ # Scale input
137
+ x_t_scaled = _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag)
138
+
139
+ # Predict x0
140
+ x0_pred = model(x_t_scaled, t, lq=lr_latent)
141
+
142
+ # Compute x_{t-1} using equation (7)
143
+ if t_step > 0:
144
+ # Equation (7) from ResShift paper:
145
+ # μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t)
146
+ # Σ_θ = κ² * (η_{t-1}/η_t) * α_t
147
+ # x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
148
+ eta_t = eta_schedule[t_step]
149
+ eta_t_minus_1 = eta_schedule[t_step - 1]
150
+
151
+ # Compute alpha_t = η_t - η_{t-1}
152
+ alpha_t = eta_t - eta_t_minus_1
153
+
154
+ # Compute mean: μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred
155
+ mean = (eta_t_minus_1 / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
156
+
157
+ # Compute variance: Σ_θ = κ² * (η_{t-1}/η_t) * α_t
158
+ variance = k**2 * (eta_t_minus_1 / eta_t) * alpha_t
159
+
160
+ # Sample: x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
161
+ noise = torch.randn_like(x_t)
162
+ nonzero_mask = torch.tensor(1.0 if t_step > 0 else 0.0, device=x_t.device).view(-1, *([1] * (len(x_t.shape) - 1)))
163
+ x_t = mean + nonzero_mask * torch.sqrt(variance) * noise
164
+ else:
165
+ x_t = x0_pred
166
+
167
+ # Decode back to pixel space
168
+ sr_latent = x_t
169
+ sr_image = autoencoder.decode(sr_latent) # (1, 3, 256, 256)
170
+ sr_image = sr_image.clamp(0, 1)
171
+
172
+ # Convert to PIL Image
173
+ sr_pil = TF.to_pil_image(sr_image.squeeze(0).cpu())
174
+
175
+ return sr_pil
176
+
177
+ except Exception as e:
178
+ print(f"Error during inference: {str(e)}")
179
+ import traceback
180
+ traceback.print_exc()
181
+ return None
182
+
183
+
184
+ # Create Gradio interface
185
+ with gr.Blocks(title="ResShift Super-Resolution") as demo:
186
+ gr.Markdown(
187
+ """
188
+ # ResShift Super-Resolution
189
+
190
+ Upload a low-resolution image to get a super-resolved version using the ResShift diffusion model.
191
+
192
+ **Note**: The model is trained for 4x super-resolution and operates in latent space. The input is resized to 256x256 before processing, so the output is also 256x256 pixels, but with enhanced detail.
193
+ """
194
+ )
195
+
196
+ with gr.Row():
197
+ with gr.Column():
198
+ input_image = gr.Image(
199
+ label="Input Image (Low Resolution)",
200
+ type="pil",
201
+ height=300
202
+ )
203
+ submit_btn = gr.Button("Super-Resolve", variant="primary")
204
+
205
+ with gr.Column():
206
+ output_image = gr.Image(
207
+ label="Super-Resolved Output",
208
+ type="pil",
209
+ height=300
210
+ )
211
+
212
+ status = gr.Textbox(label="Status", value="Loading models...", interactive=False)
213
+
214
+ # Load models on startup
215
+ demo.load(
216
+ fn=load_models,
217
+ outputs=status,
218
+ show_progress=True
219
+ )
220
+
221
+ # Process on button click
222
+ submit_btn.click(
223
+ fn=super_resolve,
224
+ inputs=input_image,
225
+ outputs=output_image,
226
+ show_progress=True
227
+ )
228
+
229
+ # Also process on image upload
230
+ input_image.change(
231
+ fn=super_resolve,
232
+ inputs=input_image,
233
+ outputs=output_image,
234
+ show_progress=True
235
+ )
236
+
237
+
238
+ if __name__ == "__main__":
239
+ demo.launch(share=True)
240
+
ldm/.DS_Store ADDED
Binary file (6.15 kB).
 
ldm/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # ldm package
2
+
ldm/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (165 Bytes).
 
ldm/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (161 Bytes).
 
ldm/__pycache__/util.cpython-311.pyc ADDED
Binary file (11.7 kB).
 
ldm/__pycache__/util.cpython-312.pyc ADDED
Binary file (10.2 kB).
 
ldm/__pycache__/util.cpython-38.pyc ADDED
Binary file (6.16 kB).
 
ldm/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # ldm.models package
2
+
ldm/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (172 Bytes).
 
ldm/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (168 Bytes).
 
ldm/models/__pycache__/autoencoder.cpython-311.pyc ADDED
Binary file (8.52 kB).
 
ldm/models/__pycache__/autoencoder.cpython-312.pyc ADDED
Binary file (7.31 kB).
 
ldm/models/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (4.9 kB).
 
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,145 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from functools import partial
6
+ from contextlib import contextmanager
7
+
8
+ import loralib as lora
9
+
10
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
11
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
12
+ from ldm.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
13
+
14
+ from ldm.util import instantiate_from_config
15
+ from ldm.modules.ema import LitEma
16
+
17
+ class VQModelTorch(nn.Module):
18
+ def __init__(self,
19
+ ddconfig,
20
+ n_embed,
21
+ embed_dim,
22
+ remap=None,
23
+ rank=8, # rank for lora
24
+ lora_alpha=1.0,
25
+ lora_tune_decoder=False,
26
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
27
+ ):
28
+ super().__init__()
29
+ if lora_tune_decoder:
30
+ conv_layer = partial(lora.Conv2d, r=rank, lora_alpha=lora_alpha)
31
+ else:
32
+ conv_layer = nn.Conv2d
33
+
34
+ self.encoder = Encoder(**ddconfig)
35
+ self.decoder = Decoder(rank=rank, lora_alpha=lora_alpha, lora_tune=lora_tune_decoder, **ddconfig)
36
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
37
+ remap=remap, sane_index_shape=sane_index_shape)
38
+ self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
39
+ self.post_quant_conv = conv_layer(embed_dim, ddconfig["z_channels"], 1)
40
+
41
+ def encode(self, x):
42
+ h = self.encoder(x)
43
+ h = self.quant_conv(h)
44
+ return h
45
+
46
+ def decode(self, h, force_not_quantize=False):
47
+ if not force_not_quantize:
48
+ quant, emb_loss, info = self.quantize(h)
49
+ else:
50
+ quant = h
51
+ quant = self.post_quant_conv(quant)
52
+ dec = self.decoder(quant)
53
+ return dec
54
+
55
+ def decode_code(self, code_b):
56
+ quant_b = self.quantize.embed_code(code_b)
57
+ dec = self.decode(quant_b, force_not_quantize=True)
58
+ return dec
59
+
60
+ def forward(self, input, force_not_quantize=False):
61
+ h = self.encode(input)
62
+ dec = self.decode(h, force_not_quantize)
63
+ return dec
64
+
65
+ class AutoencoderKLTorch(torch.nn.Module):
66
+ def __init__(self,
67
+ ddconfig,
68
+ embed_dim,
69
+ ):
70
+ super().__init__()
71
+ self.encoder = Encoder(**ddconfig)
72
+ self.decoder = Decoder(**ddconfig)
73
+ assert ddconfig["double_z"]
74
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
75
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
76
+ self.embed_dim = embed_dim
77
+
78
+ def encode(self, x, sample_posterior=True, return_moments=False):
79
+ h = self.encoder(x)
80
+ moments = self.quant_conv(h)
81
+ posterior = DiagonalGaussianDistribution(moments)
82
+ if sample_posterior:
83
+ z = posterior.sample()
84
+ else:
85
+ z = posterior.mode()
86
+ if return_moments:
87
+ return z, moments
88
+ else:
89
+ return z
90
+
91
+ def decode(self, z):
92
+ z = self.post_quant_conv(z)
93
+ dec = self.decoder(z)
94
+ return dec
95
+
96
+ def forward(self, input, sample_posterior=True):
97
+ z = self.encode(input, sample_posterior, return_moments=False)
98
+ dec = self.decode(z)
99
+ return dec
100
+
101
+ class EncoderKLTorch(torch.nn.Module):
102
+ def __init__(self,
103
+ ddconfig,
104
+ embed_dim,
105
+ ):
106
+ super().__init__()
107
+ self.encoder = Encoder(**ddconfig)
108
+ assert ddconfig["double_z"]
109
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
110
+ self.embed_dim = embed_dim
111
+
112
+ def encode(self, x, sample_posterior=True, return_moments=False):
113
+ h = self.encoder(x)
114
+ moments = self.quant_conv(h)
115
+ posterior = DiagonalGaussianDistribution(moments)
116
+ if sample_posterior:
117
+ z = posterior.sample()
118
+ else:
119
+ z = posterior.mode()
120
+ if return_moments:
121
+ return z, moments
122
+ else:
123
+ return z
124
+ def forward(self, x, sample_posterior=True, return_moments=False):
125
+ return self.encode(x, sample_posterior, return_moments)
126
+
127
+ class IdentityFirstStage(torch.nn.Module):
128
+ def __init__(self, *args, vq_interface=False, **kwargs):
129
+ self.vq_interface = vq_interface
130
+ super().__init__()
131
+
132
+ def encode(self, x, *args, **kwargs):
133
+ return x
134
+
135
+ def decode(self, x, *args, **kwargs):
136
+ return x
137
+
138
+ def quantize(self, x, *args, **kwargs):
139
+ if self.vq_interface:
140
+ return x, None, [None, None, None]
141
+ return x
142
+
143
+ def forward(self, x, *args, **kwargs):
144
+ return x
145
+
ldm/modules/.DS_Store ADDED
Binary file (6.15 kB).
 
ldm/modules/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # ldm.modules package
2
+
ldm/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (173 Bytes).
 
ldm/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (169 Bytes).
 
ldm/modules/__pycache__/attention.cpython-311.pyc ADDED
Binary file (20.3 kB).
 
ldm/modules/__pycache__/attention.cpython-312.pyc ADDED
Binary file (17.7 kB).
 
ldm/modules/__pycache__/ema.cpython-311.pyc ADDED
Binary file (5.83 kB).
 
ldm/modules/__pycache__/ema.cpython-312.pyc ADDED
Binary file (5.19 kB).
 
ldm/modules/__pycache__/ema.cpython-38.pyc ADDED
Binary file (3.19 kB).
 
ldm/modules/attention.py ADDED
@@ -0,0 +1,341 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+ from typing import Optional, Any
8
+
9
+ from ldm.modules.diffusionmodules.util import checkpoint
10
+
11
+
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+ XFORMERS_IS_AVAILBLE = True
16
+ except:
17
+ XFORMERS_IS_AVAILBLE = False
18
+
19
+ # CrossAttn precision handling
20
+ import os
21
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
22
+
23
+ def exists(val):
24
+ return val is not None
25
+
26
+
27
+ def uniq(arr):
28
+ return{el: True for el in arr}.keys()
29
+
30
+
31
+ def default(val, d):
32
+ if exists(val):
33
+ return val
34
+ return d() if isfunction(d) else d
35
+
36
+
37
+ def max_neg_value(t):
38
+ return -torch.finfo(t.dtype).max
39
+
40
+
41
+ def init_(tensor):
42
+ dim = tensor.shape[-1]
43
+ std = 1 / math.sqrt(dim)
44
+ tensor.uniform_(-std, std)
45
+ return tensor
46
+
47
+
48
+ # feedforward
49
+ class GEGLU(nn.Module):
50
+ def __init__(self, dim_in, dim_out):
51
+ super().__init__()
52
+ self.proj = nn.Linear(dim_in, dim_out * 2)
53
+
54
+ def forward(self, x):
55
+ x, gate = self.proj(x).chunk(2, dim=-1)
56
+ return x * F.gelu(gate)
57
+
58
+
59
+ class FeedForward(nn.Module):
60
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
61
+ super().__init__()
62
+ inner_dim = int(dim * mult)
63
+ dim_out = default(dim_out, dim)
64
+ project_in = nn.Sequential(
65
+ nn.Linear(dim, inner_dim),
66
+ nn.GELU()
67
+ ) if not glu else GEGLU(dim, inner_dim)
68
+
69
+ self.net = nn.Sequential(
70
+ project_in,
71
+ nn.Dropout(dropout),
72
+ nn.Linear(inner_dim, dim_out)
73
+ )
74
+
75
+ def forward(self, x):
76
+ return self.net(x)
77
+
78
+
79
+ def zero_module(module):
80
+ """
81
+ Zero out the parameters of a module and return it.
82
+ """
83
+ for p in module.parameters():
84
+ p.detach().zero_()
85
+ return module
86
+
87
+
88
+ def Normalize(in_channels):
89
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
90
+
91
+
92
+ class SpatialSelfAttention(nn.Module):
93
+ def __init__(self, in_channels):
94
+ super().__init__()
95
+ self.in_channels = in_channels
96
+
97
+ self.norm = Normalize(in_channels)
98
+ self.q = torch.nn.Conv2d(in_channels,
99
+ in_channels,
100
+ kernel_size=1,
101
+ stride=1,
102
+ padding=0)
103
+ self.k = torch.nn.Conv2d(in_channels,
104
+ in_channels,
105
+ kernel_size=1,
106
+ stride=1,
107
+ padding=0)
108
+ self.v = torch.nn.Conv2d(in_channels,
109
+ in_channels,
110
+ kernel_size=1,
111
+ stride=1,
112
+ padding=0)
113
+ self.proj_out = torch.nn.Conv2d(in_channels,
114
+ in_channels,
115
+ kernel_size=1,
116
+ stride=1,
117
+ padding=0)
118
+
119
+ def forward(self, x):
120
+ h_ = x
121
+ h_ = self.norm(h_)
122
+ q = self.q(h_)
123
+ k = self.k(h_)
124
+ v = self.v(h_)
125
+
126
+ # compute attention
127
+ b,c,h,w = q.shape
128
+ q = rearrange(q, 'b c h w -> b (h w) c')
129
+ k = rearrange(k, 'b c h w -> b c (h w)')
130
+ w_ = torch.einsum('bij,bjk->bik', q, k)
131
+
132
+ w_ = w_ * (int(c)**(-0.5))
133
+ w_ = torch.nn.functional.softmax(w_, dim=2)
134
+
135
+ # attend to values
136
+ v = rearrange(v, 'b c h w -> b c (h w)')
137
+ w_ = rearrange(w_, 'b i j -> b j i')
138
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
139
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
140
+ h_ = self.proj_out(h_)
141
+
142
+ return x+h_
143
+
144
+
145
+ class CrossAttention(nn.Module):
146
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
147
+ super().__init__()
148
+ inner_dim = dim_head * heads
149
+ context_dim = default(context_dim, query_dim)
150
+
151
+ self.scale = dim_head ** -0.5
152
+ self.heads = heads
153
+
154
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
155
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
156
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
157
+
158
+ self.to_out = nn.Sequential(
159
+ nn.Linear(inner_dim, query_dim),
160
+ nn.Dropout(dropout)
161
+ )
162
+
163
+ def forward(self, x, context=None, mask=None):
164
+ h = self.heads
165
+
166
+ q = self.to_q(x)
167
+ context = default(context, x)
168
+ k = self.to_k(context)
169
+ v = self.to_v(context)
170
+
171
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
172
+
173
+ # force cast to fp32 to avoid overflowing
174
+ if _ATTN_PRECISION =="fp32":
175
+ with torch.autocast(enabled=False, device_type = 'cuda'):
176
+ q, k = q.float(), k.float()
177
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
178
+ else:
179
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
180
+
181
+ del q, k
182
+
183
+ if exists(mask):
184
+ mask = rearrange(mask, 'b ... -> b (...)')
185
+ max_neg_value = -torch.finfo(sim.dtype).max
186
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
187
+ sim.masked_fill_(~mask, max_neg_value)
188
+
189
+ # attention, what we cannot get enough of
190
+ sim = sim.softmax(dim=-1)
191
+
192
+ out = einsum('b i j, b j d -> b i d', sim, v)
193
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
194
+ return self.to_out(out)
195
+
196
+
197
+ class MemoryEfficientCrossAttention(nn.Module):
198
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
199
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
200
+ super().__init__()
201
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
202
+ f"{heads} heads.")
203
+ inner_dim = dim_head * heads
204
+ context_dim = default(context_dim, query_dim)
205
+
206
+ self.heads = heads
207
+ self.dim_head = dim_head
208
+
209
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
210
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
211
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
212
+
213
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
214
+ self.attention_op: Optional[Any] = None
215
+
216
+ def forward(self, x, context=None, mask=None):
217
+ q = self.to_q(x)
218
+ context = default(context, x)
219
+ k = self.to_k(context)
220
+ v = self.to_v(context)
221
+
222
+ b, _, _ = q.shape
223
+ q, k, v = map(
224
+ lambda t: t.unsqueeze(3)
225
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
226
+ .permute(0, 2, 1, 3)
227
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
228
+ .contiguous(),
229
+ (q, k, v),
230
+ )
231
+
232
+ # actually compute the attention, what we cannot get enough of
233
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
234
+
235
+ if exists(mask):
236
+ raise NotImplementedError
237
+ out = (
238
+ out.unsqueeze(0)
239
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
240
+ .permute(0, 2, 1, 3)
241
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
242
+ )
243
+ return self.to_out(out)
244
+
245
+
246
+ class BasicTransformerBlock(nn.Module):
247
+ ATTENTION_MODES = {
248
+ "softmax": CrossAttention, # vanilla attention
249
+ "softmax-xformers": MemoryEfficientCrossAttention
250
+ }
251
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
252
+ disable_self_attn=False):
253
+ super().__init__()
254
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
255
+ assert attn_mode in self.ATTENTION_MODES
256
+ attn_cls = self.ATTENTION_MODES[attn_mode]
257
+ self.disable_self_attn = disable_self_attn
258
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
259
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
260
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
261
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
262
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
263
+ self.norm1 = nn.LayerNorm(dim)
264
+ self.norm2 = nn.LayerNorm(dim)
265
+ self.norm3 = nn.LayerNorm(dim)
266
+ self.checkpoint = checkpoint
267
+
268
+ def forward(self, x, context=None):
269
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
270
+
271
+ def _forward(self, x, context=None):
272
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
273
+ x = self.attn2(self.norm2(x), context=context) + x
274
+ x = self.ff(self.norm3(x)) + x
275
+ return x
276
+
277
+
278
+ class SpatialTransformer(nn.Module):
279
+ """
280
+ Transformer block for image-like data.
281
+ First, project the input (aka embedding)
282
+ and reshape to b, t, d.
283
+ Then apply standard transformer action.
284
+ Finally, reshape to image
285
+ NEW: use_linear for more efficiency instead of the 1x1 convs
286
+ """
287
+ def __init__(self, in_channels, n_heads, d_head,
288
+ depth=1, dropout=0., context_dim=None,
289
+ disable_self_attn=False, use_linear=False,
290
+ use_checkpoint=True):
291
+ super().__init__()
292
+ if exists(context_dim) and not isinstance(context_dim, list):
293
+ context_dim = [context_dim]
294
+ self.in_channels = in_channels
295
+ inner_dim = n_heads * d_head
296
+ self.norm = Normalize(in_channels)
297
+ if not use_linear:
298
+ self.proj_in = nn.Conv2d(in_channels,
299
+ inner_dim,
300
+ kernel_size=1,
301
+ stride=1,
302
+ padding=0)
303
+ else:
304
+ self.proj_in = nn.Linear(in_channels, inner_dim)
305
+
306
+ self.transformer_blocks = nn.ModuleList(
307
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
308
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
309
+ for d in range(depth)]
310
+ )
311
+ if not use_linear:
312
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
313
+ in_channels,
314
+ kernel_size=1,
315
+ stride=1,
316
+ padding=0))
317
+ else:
318
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
319
+ self.use_linear = use_linear
320
+
321
+ def forward(self, x, context=None):
322
+ # note: if no context is given, cross-attention defaults to self-attention
323
+ if not isinstance(context, list):
324
+ context = [context]
325
+ b, c, h, w = x.shape
326
+ x_in = x
327
+ x = self.norm(x)
328
+ if not self.use_linear:
329
+ x = self.proj_in(x)
330
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
331
+ if self.use_linear:
332
+ x = self.proj_in(x)
333
+ for i, block in enumerate(self.transformer_blocks):
334
+ x = block(x, context=context[i])
335
+ if self.use_linear:
336
+ x = self.proj_out(x)
337
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
338
+ if not self.use_linear:
339
+ x = self.proj_out(x)
340
+ return x + x_in
341
+
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm/modules/diffusionmodules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (190 Bytes).
 
ldm/modules/diffusionmodules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (186 Bytes).
 
ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (165 Bytes).
 
ldm/modules/diffusionmodules/__pycache__/model.cpython-311.pyc ADDED
Binary file (45.3 kB).
 
ldm/modules/diffusionmodules/__pycache__/model.cpython-312.pyc ADDED
Binary file (40.7 kB). View file
 
ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc ADDED
Binary file (15.8 kB). View file
 
ldm/modules/diffusionmodules/__pycache__/util.cpython-311.pyc ADDED
Binary file (16.5 kB). View file
 
ldm/modules/diffusionmodules/__pycache__/util.cpython-312.pyc ADDED
Binary file (14.9 kB). View file
 
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,860 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from typing import Optional, Any
8
+
9
+ from ldm.modules.attention import MemoryEfficientCrossAttention
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+ XFORMERS_IS_AVAILBLE = True
15
+ except:
16
+ XFORMERS_IS_AVAILBLE = False
17
+ print("No module 'xformers'. Proceeding without it.")
18
+
19
+
20
+ def get_timestep_embedding(timesteps, embedding_dim):
21
+ """
22
+ Build sinusoidal timestep embeddings (adapted from Fairseq).
23
+ This matches the implementation in Denoising Diffusion Probabilistic Models
24
+ and in tensor2tensor, but differs slightly from the description in
25
+ Section 3.5 of "Attention Is All You Need".
26
+ Returns a tensor of shape (len(timesteps), embedding_dim).
27
+ """
28
+ assert len(timesteps.shape) == 1
29
+
30
+ half_dim = embedding_dim // 2
31
+ emb = math.log(10000) / (half_dim - 1)
32
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
33
+ emb = emb.to(device=timesteps.device)
34
+ emb = timesteps.float()[:, None] * emb[None, :]
35
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
36
+ if embedding_dim % 2 == 1: # zero pad
37
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
38
+ return emb
39
+
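get_timestep_embedding above builds the standard sinusoidal embedding: half of the channels are sines and half cosines at geometrically spaced frequencies, with a zero pad when embedding_dim is odd. A small sanity-check sketch (hypothetical values, assuming this commit's repo layout):

import torch
from ldm.modules.diffusionmodules.model import get_timestep_embedding

t = torch.tensor([0, 10, 999])                  # one scalar timestep per sample
print(get_timestep_embedding(t, 64).shape)      # torch.Size([3, 64]): 32 sin + 32 cos features
print(get_timestep_embedding(t, 63).shape)      # torch.Size([3, 63]): odd dim is zero-padded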
40
+
41
+ def nonlinearity(x):
42
+ # swish
43
+ return x*torch.sigmoid(x)
44
+
45
+
46
+ def Normalize(in_channels, num_groups=32):
47
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
48
+
49
+
50
+ class Upsample(nn.Module):
51
+ def __init__(self, in_channels, with_conv):
52
+ super().__init__()
53
+ self.with_conv = with_conv
54
+ if self.with_conv:
55
+ self.conv = torch.nn.Conv2d(in_channels,
56
+ in_channels,
57
+ kernel_size=3,
58
+ stride=1,
59
+ padding=1)
60
+
61
+ def forward(self, x):
62
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
63
+ if self.with_conv:
64
+ x = self.conv(x)
65
+ return x
66
+
67
+
68
+ class Downsample(nn.Module):
69
+ def __init__(self, in_channels, with_conv):
70
+ super().__init__()
71
+ self.with_conv = with_conv
72
+ if self.with_conv:
73
+ # no asymmetric padding in torch conv, must do it ourselves
74
+ self.conv = torch.nn.Conv2d(in_channels,
75
+ in_channels,
76
+ kernel_size=3,
77
+ stride=2,
78
+ padding=0)
79
+
80
+ def forward(self, x):
81
+ if self.with_conv:
82
+ pad = (0,1,0,1)
83
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
84
+ x = self.conv(x)
85
+ else:
86
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
87
+ return x
88
+
89
+
90
+ class ResnetBlock(nn.Module):
91
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
92
+ dropout, temb_channels=512):
93
+ super().__init__()
94
+ self.in_channels = in_channels
95
+ out_channels = in_channels if out_channels is None else out_channels
96
+ self.out_channels = out_channels
97
+ self.use_conv_shortcut = conv_shortcut
98
+
99
+ self.norm1 = Normalize(in_channels)
100
+ self.conv1 = torch.nn.Conv2d(in_channels,
101
+ out_channels,
102
+ kernel_size=3,
103
+ stride=1,
104
+ padding=1)
105
+ if temb_channels > 0:
106
+ self.temb_proj = torch.nn.Linear(temb_channels,
107
+ out_channels)
108
+ self.norm2 = Normalize(out_channels)
109
+ self.dropout = torch.nn.Dropout(dropout)
110
+ self.conv2 = torch.nn.Conv2d(out_channels,
111
+ out_channels,
112
+ kernel_size=3,
113
+ stride=1,
114
+ padding=1)
115
+ if self.in_channels != self.out_channels:
116
+ if self.use_conv_shortcut:
117
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
118
+ out_channels,
119
+ kernel_size=3,
120
+ stride=1,
121
+ padding=1)
122
+ else:
123
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
124
+ out_channels,
125
+ kernel_size=1,
126
+ stride=1,
127
+ padding=0)
128
+
129
+ def forward(self, x, temb):
130
+ h = x
131
+ h = self.norm1(h)
132
+ h = nonlinearity(h)
133
+ h = self.conv1(h)
134
+
135
+ if temb is not None:
136
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
137
+
138
+ h = self.norm2(h)
139
+ h = nonlinearity(h)
140
+ h = self.dropout(h)
141
+ h = self.conv2(h)
142
+
143
+ if self.in_channels != self.out_channels:
144
+ if self.use_conv_shortcut:
145
+ x = self.conv_shortcut(x)
146
+ else:
147
+ x = self.nin_shortcut(x)
148
+
149
+ return x+h
150
+
151
+
152
+ class AttnBlock(nn.Module):
153
+ def __init__(self, in_channels):
154
+ super().__init__()
155
+ self.in_channels = in_channels
156
+
157
+ self.norm = Normalize(in_channels)
158
+ self.q = torch.nn.Conv2d(in_channels,
159
+ in_channels,
160
+ kernel_size=1,
161
+ stride=1,
162
+ padding=0)
163
+ self.k = torch.nn.Conv2d(in_channels,
164
+ in_channels,
165
+ kernel_size=1,
166
+ stride=1,
167
+ padding=0)
168
+ self.v = torch.nn.Conv2d(in_channels,
169
+ in_channels,
170
+ kernel_size=1,
171
+ stride=1,
172
+ padding=0)
173
+ self.proj_out = torch.nn.Conv2d(in_channels,
174
+ in_channels,
175
+ kernel_size=1,
176
+ stride=1,
177
+ padding=0)
178
+
179
+ def forward(self, x):
180
+ h_ = x
181
+ h_ = self.norm(h_)
182
+ q = self.q(h_)
183
+ k = self.k(h_)
184
+ v = self.v(h_)
185
+
186
+ # compute attention
187
+ b,c,h,w = q.shape
188
+ q = q.reshape(b,c,h*w)
189
+ q = q.permute(0,2,1) # b,hw,c
190
+ k = k.reshape(b,c,h*w) # b,c,hw
191
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
192
+ w_ = w_ * (int(c)**(-0.5))
193
+ w_ = torch.nn.functional.softmax(w_, dim=2)
194
+
195
+ # attend to values
196
+ v = v.reshape(b,c,h*w)
197
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
198
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
199
+ h_ = h_.reshape(b,c,h,w)
200
+
201
+ h_ = self.proj_out(h_)
202
+
203
+ return x+h_
204
+
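The bmm/permute bookkeeping in AttnBlock.forward is just single-head scaled dot-product attention over the h*w spatial positions. The same computation written with einsum, as a reference sketch (not part of the commit):

import torch

def attn_reference(q, k, v):
    # q, k, v: (b, c, h, w) outputs of the 1x1 q/k/v convs above
    b, c, h, w = q.shape
    q, k, v = (t.reshape(b, c, h * w) for t in (q, k, v))
    # attn[b, i, j] = softmax_j( sum_c q[b, c, i] * k[b, c, j] / sqrt(c) )
    attn = torch.softmax(torch.einsum('bci,bcj->bij', q, k) * c ** -0.5, dim=-1)
    # out[b, c, i] = sum_j attn[b, i, j] * v[b, c, j]
    out = torch.einsum('bij,bcj->bci', attn, v)
    return out.reshape(b, c, h, w)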
205
+ class MemoryEfficientAttnBlock(nn.Module):
206
+ """
207
+ Uses xformers efficient implementation,
208
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
209
+ Note: this is a single-head self-attention operation
210
+ """
211
+ #
212
+ def __init__(self, in_channels):
213
+ super().__init__()
214
+ self.in_channels = in_channels
215
+
216
+ self.norm = Normalize(in_channels)
217
+ self.q = torch.nn.Conv2d(in_channels,
218
+ in_channels,
219
+ kernel_size=1,
220
+ stride=1,
221
+ padding=0)
222
+ self.k = torch.nn.Conv2d(in_channels,
223
+ in_channels,
224
+ kernel_size=1,
225
+ stride=1,
226
+ padding=0)
227
+ self.v = torch.nn.Conv2d(in_channels,
228
+ in_channels,
229
+ kernel_size=1,
230
+ stride=1,
231
+ padding=0)
232
+ self.proj_out = torch.nn.Conv2d(in_channels,
233
+ in_channels,
234
+ kernel_size=1,
235
+ stride=1,
236
+ padding=0)
237
+ self.attention_op: Optional[Any] = None
238
+
239
+ def forward(self, x):
240
+ h_ = x
241
+ h_ = self.norm(h_)
242
+ q = self.q(h_)
243
+ k = self.k(h_)
244
+ v = self.v(h_)
245
+
246
+ # compute attention
247
+ B, C, H, W = q.shape
248
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) # b x hw x c
249
+
250
+ q, k, v = map(
251
+ lambda t: t.unsqueeze(3) # b x hw x c x 1
252
+ .reshape(B, t.shape[1], 1, C) # b x hw x 1 x c
253
+ .permute(0, 2, 1, 3) # b x 1 x hw x c
254
+ .reshape(B * 1, t.shape[1], C) # b x hw x c
255
+ .contiguous(),
256
+ (q, k, v),
257
+ )
258
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
259
+
260
+ out = (
261
+ out.unsqueeze(0)
262
+ .reshape(B, 1, out.shape[1], C)
263
+ .permute(0, 2, 1, 3)
264
+ .reshape(B, out.shape[1], C)
265
+ )
266
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
267
+ out = self.proj_out(out)
268
+ return x+out
269
+
270
+
271
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
272
+ def forward(self, x, context=None, mask=None):
273
+ b, c, h, w = x.shape
274
+ x_in = x
274
+ x = rearrange(x, 'b c h w -> b (h w) c')
275
+ out = super().forward(x, context=context, mask=mask)
276
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
277
+ return x_in + out  # residual on the original (b, c, h, w) input, not the flattened x
278
+
279
+
280
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
281
+ assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
282
+ if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
283
+ attn_type = "vanilla-xformers"
284
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
285
+ if attn_type == "vanilla":
286
+ assert attn_kwargs is None
287
+ return AttnBlock(in_channels)
288
+ elif attn_type == "vanilla-xformers":
289
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
290
+ return MemoryEfficientAttnBlock(in_channels)
291
+ elif attn_type == "memory-efficient-cross-attn":
292
+ attn_kwargs["query_dim"] = in_channels
293
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
294
+ elif attn_type == "none":
295
+ return nn.Identity(in_channels)
296
+ else:
297
+ raise NotImplementedError()
298
+
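make_attn is a small factory: when xformers is importable it silently upgrades "vanilla" to the memory-efficient block, otherwise it falls back to the bmm-based AttnBlock, and "none" yields an identity. A hedged usage sketch (which class you get depends on whether xformers is installed):

from ldm.modules.diffusionmodules.model import make_attn

blk = make_attn(64, attn_type="vanilla")   # AttnBlock, or MemoryEfficientAttnBlock with xformers
noop = make_attn(64, attn_type="none")     # nn.Identity, attention disabled at this resolution
print(type(blk).__name__, type(noop).__name__)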
299
+
300
+ class Model(nn.Module):
301
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
302
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
303
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
304
+ super().__init__()
305
+ if use_linear_attn: attn_type = "linear"
306
+ self.ch = ch
307
+ self.temb_ch = self.ch*4
308
+ self.num_resolutions = len(ch_mult)
309
+ self.num_res_blocks = num_res_blocks
310
+ self.resolution = resolution
311
+ self.in_channels = in_channels
312
+
313
+ self.use_timestep = use_timestep
314
+ if self.use_timestep:
315
+ # timestep embedding
316
+ self.temb = nn.Module()
317
+ self.temb.dense = nn.ModuleList([
318
+ torch.nn.Linear(self.ch,
319
+ self.temb_ch),
320
+ torch.nn.Linear(self.temb_ch,
321
+ self.temb_ch),
322
+ ])
323
+
324
+ # downsampling
325
+ self.conv_in = torch.nn.Conv2d(in_channels,
326
+ self.ch,
327
+ kernel_size=3,
328
+ stride=1,
329
+ padding=1)
330
+
331
+ curr_res = resolution
332
+ in_ch_mult = (1,)+tuple(ch_mult)
333
+ self.down = nn.ModuleList()
334
+ for i_level in range(self.num_resolutions):
335
+ block = nn.ModuleList()
336
+ attn = nn.ModuleList()
337
+ block_in = ch*in_ch_mult[i_level]
338
+ block_out = ch*ch_mult[i_level]
339
+ for i_block in range(self.num_res_blocks):
340
+ block.append(ResnetBlock(in_channels=block_in,
341
+ out_channels=block_out,
342
+ temb_channels=self.temb_ch,
343
+ dropout=dropout))
344
+ block_in = block_out
345
+ if curr_res in attn_resolutions:
346
+ attn.append(make_attn(block_in, attn_type=attn_type))
347
+ down = nn.Module()
348
+ down.block = block
349
+ down.attn = attn
350
+ if i_level != self.num_resolutions-1:
351
+ down.downsample = Downsample(block_in, resamp_with_conv)
352
+ curr_res = curr_res // 2
353
+ self.down.append(down)
354
+
355
+ # middle
356
+ self.mid = nn.Module()
357
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
358
+ out_channels=block_in,
359
+ temb_channels=self.temb_ch,
360
+ dropout=dropout)
361
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
362
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
363
+ out_channels=block_in,
364
+ temb_channels=self.temb_ch,
365
+ dropout=dropout)
366
+
367
+ # upsampling
368
+ self.up = nn.ModuleList()
369
+ for i_level in reversed(range(self.num_resolutions)):
370
+ block = nn.ModuleList()
371
+ attn = nn.ModuleList()
372
+ block_out = ch*ch_mult[i_level]
373
+ skip_in = ch*ch_mult[i_level]
374
+ for i_block in range(self.num_res_blocks+1):
375
+ if i_block == self.num_res_blocks:
376
+ skip_in = ch*in_ch_mult[i_level]
377
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
378
+ out_channels=block_out,
379
+ temb_channels=self.temb_ch,
380
+ dropout=dropout))
381
+ block_in = block_out
382
+ if curr_res in attn_resolutions:
383
+ attn.append(make_attn(block_in, attn_type=attn_type))
384
+ up = nn.Module()
385
+ up.block = block
386
+ up.attn = attn
387
+ if i_level != 0:
388
+ up.upsample = Upsample(block_in, resamp_with_conv)
389
+ curr_res = curr_res * 2
390
+ self.up.insert(0, up) # prepend to get consistent order
391
+
392
+ # end
393
+ self.norm_out = Normalize(block_in)
394
+ self.conv_out = torch.nn.Conv2d(block_in,
395
+ out_ch,
396
+ kernel_size=3,
397
+ stride=1,
398
+ padding=1)
399
+
400
+ def forward(self, x, t=None, context=None):
401
+ #assert x.shape[2] == x.shape[3] == self.resolution
402
+ if context is not None:
403
+ # assume aligned context, cat along channel axis
404
+ x = torch.cat((x, context), dim=1)
405
+ if self.use_timestep:
406
+ # timestep embedding
407
+ assert t is not None
408
+ temb = get_timestep_embedding(t, self.ch)
409
+ temb = self.temb.dense[0](temb)
410
+ temb = nonlinearity(temb)
411
+ temb = self.temb.dense[1](temb)
412
+ else:
413
+ temb = None
414
+
415
+ # downsampling
416
+ hs = [self.conv_in(x)]
417
+ for i_level in range(self.num_resolutions):
418
+ for i_block in range(self.num_res_blocks):
419
+ h = self.down[i_level].block[i_block](hs[-1], temb)
420
+ if len(self.down[i_level].attn) > 0:
421
+ h = self.down[i_level].attn[i_block](h)
422
+ hs.append(h)
423
+ if i_level != self.num_resolutions-1:
424
+ hs.append(self.down[i_level].downsample(hs[-1]))
425
+
426
+ # middle
427
+ h = hs[-1]
428
+ h = self.mid.block_1(h, temb)
429
+ h = self.mid.attn_1(h)
430
+ h = self.mid.block_2(h, temb)
431
+
432
+ # upsampling
433
+ for i_level in reversed(range(self.num_resolutions)):
434
+ for i_block in range(self.num_res_blocks+1):
435
+ h = self.up[i_level].block[i_block](
436
+ torch.cat([h, hs.pop()], dim=1), temb)
437
+ if len(self.up[i_level].attn) > 0:
438
+ h = self.up[i_level].attn[i_block](h)
439
+ if i_level != 0:
440
+ h = self.up[i_level].upsample(h)
441
+
442
+ # end
443
+ h = self.norm_out(h)
444
+ h = nonlinearity(h)
445
+ h = self.conv_out(h)
446
+ return h
447
+
448
+ def get_last_layer(self):
449
+ return self.conv_out.weight
450
+
451
+
452
+ class Encoder(nn.Module):
453
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
454
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
455
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
456
+ **ignore_kwargs):
457
+ super().__init__()
458
+ if use_linear_attn: attn_type = "linear"
459
+ self.ch = ch
460
+ self.temb_ch = 0
461
+ self.num_resolutions = len(ch_mult)
462
+ self.resolution = resolution
463
+ self.in_channels = in_channels
464
+ if isinstance(num_res_blocks, int):
465
+ num_res_blocks = [num_res_blocks, ] * len(ch_mult)
466
+ else:
467
+ assert len(num_res_blocks) == len(ch_mult)
468
+ self.num_res_blocks = num_res_blocks
469
+
470
+ # downsampling
471
+ self.conv_in = torch.nn.Conv2d(in_channels,
472
+ self.ch,
473
+ kernel_size=3,
474
+ stride=1,
475
+ padding=1)
476
+
477
+ curr_res = resolution
478
+ in_ch_mult = (1,)+tuple(ch_mult)
479
+ self.in_ch_mult = in_ch_mult
480
+ self.down = nn.ModuleList()
481
+ for i_level in range(self.num_resolutions):
482
+ block = nn.ModuleList()
483
+ attn = nn.ModuleList()
484
+ block_in = ch*in_ch_mult[i_level]
485
+ block_out = ch*ch_mult[i_level]
486
+ for i_block in range(self.num_res_blocks[i_level]):
487
+ block.append(ResnetBlock(in_channels=block_in,
488
+ out_channels=block_out,
489
+ temb_channels=self.temb_ch,
490
+ dropout=dropout))
491
+ block_in = block_out
492
+ if curr_res in attn_resolutions:
493
+ attn.append(make_attn(block_in, attn_type=attn_type))
494
+ down = nn.Module()
495
+ down.block = block
496
+ down.attn = attn
497
+ if i_level != self.num_resolutions-1:
498
+ down.downsample = Downsample(block_in, resamp_with_conv)
499
+ curr_res = curr_res // 2
500
+ self.down.append(down)
501
+
502
+ # middle
503
+ self.mid = nn.Module()
504
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
505
+ out_channels=block_in,
506
+ temb_channels=self.temb_ch,
507
+ dropout=dropout)
508
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
509
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
510
+ out_channels=block_in,
511
+ temb_channels=self.temb_ch,
512
+ dropout=dropout)
513
+
514
+ # end
515
+ self.norm_out = Normalize(block_in)
516
+ self.conv_out = torch.nn.Conv2d(block_in,
517
+ 2*z_channels if double_z else z_channels,
518
+ kernel_size=3,
519
+ stride=1,
520
+ padding=1)
521
+
522
+ def forward(self, x):
523
+ # timestep embedding
524
+ temb = None
525
+
526
+ # downsampling
527
+ hs = [self.conv_in(x)]
528
+ for i_level in range(self.num_resolutions):
529
+ for i_block in range(self.num_res_blocks[i_level]):
530
+ h = self.down[i_level].block[i_block](hs[-1], temb)
531
+ if len(self.down[i_level].attn) > 0:
532
+ h = self.down[i_level].attn[i_block](h)
533
+ hs.append(h)
534
+ if i_level != self.num_resolutions-1:
535
+ hs.append(self.down[i_level].downsample(hs[-1]))
536
+
537
+ # middle
538
+ h = hs[-1]
539
+ h = self.mid.block_1(h, temb)
540
+ h = self.mid.attn_1(h)
541
+ h = self.mid.block_2(h, temb)
542
+
543
+ # end
544
+ h = self.norm_out(h)
545
+ h = nonlinearity(h)
546
+ h = self.conv_out(h)
547
+ return h
548
+
549
+
550
+ class Decoder(nn.Module):
551
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
552
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
553
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
554
+ attn_type="vanilla", **ignorekwargs):
555
+ super().__init__()
556
+ if use_linear_attn: attn_type = "linear"
557
+ self.ch = ch
558
+ self.temb_ch = 0
559
+ self.num_resolutions = len(ch_mult)
560
+ self.resolution = resolution
561
+ self.in_channels = in_channels
562
+ self.give_pre_end = give_pre_end
563
+ self.tanh_out = tanh_out
564
+ if isinstance(num_res_blocks, int):
565
+ num_res_blocks = [num_res_blocks, ] * len(ch_mult)
566
+ else:
567
+ assert len(num_res_blocks) == len(ch_mult)
568
+ self.num_res_blocks = num_res_blocks
569
+
570
+ # compute in_ch_mult, block_in and curr_res at lowest res
571
+ in_ch_mult = (1,)+tuple(ch_mult)
572
+ block_in = ch*ch_mult[self.num_resolutions-1]
573
+ curr_res = resolution // 2**(self.num_resolutions-1)
574
+ self.z_shape = (1,z_channels,curr_res,curr_res)
575
+ print("Working with z of shape {} = {} dimensions.".format(
576
+ self.z_shape, np.prod(self.z_shape)))
577
+
578
+ # z to block_in
579
+ self.conv_in = torch.nn.Conv2d(z_channels,
580
+ block_in,
581
+ kernel_size=3,
582
+ stride=1,
583
+ padding=1)
584
+
585
+ # middle
586
+ self.mid = nn.Module()
587
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
588
+ out_channels=block_in,
589
+ temb_channels=self.temb_ch,
590
+ dropout=dropout)
591
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
592
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
593
+ out_channels=block_in,
594
+ temb_channels=self.temb_ch,
595
+ dropout=dropout)
596
+
597
+ # upsampling
598
+ self.up = nn.ModuleList()
599
+ for i_level in reversed(range(self.num_resolutions)):
600
+ block = nn.ModuleList()
601
+ attn = nn.ModuleList()
602
+ block_out = ch*ch_mult[i_level]
603
+ for i_block in range(self.num_res_blocks[i_level]+1):
604
+ block.append(ResnetBlock(in_channels=block_in,
605
+ out_channels=block_out,
606
+ temb_channels=self.temb_ch,
607
+ dropout=dropout))
608
+ block_in = block_out
609
+ if curr_res in attn_resolutions:
610
+ attn.append(make_attn(block_in, attn_type=attn_type))
611
+ up = nn.Module()
612
+ up.block = block
613
+ up.attn = attn
614
+ if i_level != 0:
615
+ up.upsample = Upsample(block_in, resamp_with_conv)
616
+ curr_res = curr_res * 2
617
+ self.up.insert(0, up) # prepend to get consistent order
618
+
619
+ # end
620
+ self.norm_out = Normalize(block_in)
621
+ self.conv_out = torch.nn.Conv2d(block_in,
622
+ out_ch,
623
+ kernel_size=3,
624
+ stride=1,
625
+ padding=1)
626
+
627
+ def forward(self, z):
628
+ #assert z.shape[1:] == self.z_shape[1:]
629
+ self.last_z_shape = z.shape
630
+
631
+ # timestep embedding
632
+ temb = None
633
+
634
+ # z to block_in
635
+ h = self.conv_in(z)
636
+
637
+ # middle
638
+ h = self.mid.block_1(h, temb)
639
+ h = self.mid.attn_1(h)
640
+ h = self.mid.block_2(h, temb)
641
+
642
+ # upsampling
643
+ for i_level in reversed(range(self.num_resolutions)):
644
+ for i_block in range(self.num_res_blocks[i_level]+1):
645
+ h = self.up[i_level].block[i_block](h, temb)
646
+ if len(self.up[i_level].attn) > 0:
647
+ h = self.up[i_level].attn[i_block](h)
648
+ if i_level != 0:
649
+ h = self.up[i_level].upsample(h)
650
+
651
+ # end
652
+ if self.give_pre_end:
653
+ return h
654
+
655
+ h = self.norm_out(h)
656
+ h = nonlinearity(h)
657
+ h = self.conv_out(h)
658
+ if self.tanh_out:
659
+ h = torch.tanh(h)
660
+ return h
661
+
662
+
663
+ class SimpleDecoder(nn.Module):
664
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
665
+ super().__init__()
666
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
667
+ ResnetBlock(in_channels=in_channels,
668
+ out_channels=2 * in_channels,
669
+ temb_channels=0, dropout=0.0),
670
+ ResnetBlock(in_channels=2 * in_channels,
671
+ out_channels=4 * in_channels,
672
+ temb_channels=0, dropout=0.0),
673
+ ResnetBlock(in_channels=4 * in_channels,
674
+ out_channels=2 * in_channels,
675
+ temb_channels=0, dropout=0.0),
676
+ nn.Conv2d(2*in_channels, in_channels, 1),
677
+ Upsample(in_channels, with_conv=True)])
678
+ # end
679
+ self.norm_out = Normalize(in_channels)
680
+ self.conv_out = torch.nn.Conv2d(in_channels,
681
+ out_channels,
682
+ kernel_size=3,
683
+ stride=1,
684
+ padding=1)
685
+
686
+ def forward(self, x):
687
+ for i, layer in enumerate(self.model):
688
+ if i in [1,2,3]:
689
+ x = layer(x, None)
690
+ else:
691
+ x = layer(x)
692
+
693
+ h = self.norm_out(x)
694
+ h = nonlinearity(h)
695
+ x = self.conv_out(h)
696
+ return x
697
+
698
+
699
+ class UpsampleDecoder(nn.Module):
700
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
701
+ ch_mult=(2,2), dropout=0.0):
702
+ super().__init__()
703
+ # upsampling
704
+ self.temb_ch = 0
705
+ self.num_resolutions = len(ch_mult)
706
+ self.num_res_blocks = num_res_blocks
707
+ block_in = in_channels
708
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
709
+ self.res_blocks = nn.ModuleList()
710
+ self.upsample_blocks = nn.ModuleList()
711
+ for i_level in range(self.num_resolutions):
712
+ res_block = []
713
+ block_out = ch * ch_mult[i_level]
714
+ for i_block in range(self.num_res_blocks + 1):
715
+ res_block.append(ResnetBlock(in_channels=block_in,
716
+ out_channels=block_out,
717
+ temb_channels=self.temb_ch,
718
+ dropout=dropout))
719
+ block_in = block_out
720
+ self.res_blocks.append(nn.ModuleList(res_block))
721
+ if i_level != self.num_resolutions - 1:
722
+ self.upsample_blocks.append(Upsample(block_in, True))
723
+ curr_res = curr_res * 2
724
+
725
+ # end
726
+ self.norm_out = Normalize(block_in)
727
+ self.conv_out = torch.nn.Conv2d(block_in,
728
+ out_channels,
729
+ kernel_size=3,
730
+ stride=1,
731
+ padding=1)
732
+
733
+ def forward(self, x):
734
+ # upsampling
735
+ h = x
736
+ for k, i_level in enumerate(range(self.num_resolutions)):
737
+ for i_block in range(self.num_res_blocks + 1):
738
+ h = self.res_blocks[i_level][i_block](h, None)
739
+ if i_level != self.num_resolutions - 1:
740
+ h = self.upsample_blocks[k](h)
741
+ h = self.norm_out(h)
742
+ h = nonlinearity(h)
743
+ h = self.conv_out(h)
744
+ return h
745
+
746
+
747
+ class LatentRescaler(nn.Module):
748
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
749
+ super().__init__()
750
+ # residual block, interpolate, residual block
751
+ self.factor = factor
752
+ self.conv_in = nn.Conv2d(in_channels,
753
+ mid_channels,
754
+ kernel_size=3,
755
+ stride=1,
756
+ padding=1)
757
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
758
+ out_channels=mid_channels,
759
+ temb_channels=0,
760
+ dropout=0.0) for _ in range(depth)])
761
+ self.attn = AttnBlock(mid_channels)
762
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
763
+ out_channels=mid_channels,
764
+ temb_channels=0,
765
+ dropout=0.0) for _ in range(depth)])
766
+
767
+ self.conv_out = nn.Conv2d(mid_channels,
768
+ out_channels,
769
+ kernel_size=1,
770
+ )
771
+
772
+ def forward(self, x):
773
+ x = self.conv_in(x)
774
+ for block in self.res_block1:
775
+ x = block(x, None)
776
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
777
+ x = self.attn(x)
778
+ for block in self.res_block2:
779
+ x = block(x, None)
780
+ x = self.conv_out(x)
781
+ return x
782
+
783
+
784
+ class MergedRescaleEncoder(nn.Module):
785
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
786
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
787
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
788
+ super().__init__()
789
+ intermediate_chn = ch * ch_mult[-1]
790
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
791
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
792
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
793
+ out_ch=None)
794
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
795
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
796
+
797
+ def forward(self, x):
798
+ x = self.encoder(x)
799
+ x = self.rescaler(x)
800
+ return x
801
+
802
+
803
+ class MergedRescaleDecoder(nn.Module):
804
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
805
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
806
+ super().__init__()
807
+ tmp_chn = z_channels*ch_mult[-1]
808
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
809
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
810
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
811
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
812
+ out_channels=tmp_chn, depth=rescale_module_depth)
813
+
814
+ def forward(self, x):
815
+ x = self.rescaler(x)
816
+ x = self.decoder(x)
817
+ return x
818
+
819
+
820
+ class Upsampler(nn.Module):
821
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
822
+ super().__init__()
823
+ assert out_size >= in_size
824
+ num_blocks = int(np.log2(out_size//in_size))+1
825
+ factor_up = 1.+ (out_size % in_size)
826
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
827
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
828
+ out_channels=in_channels)
829
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
830
+ attn_resolutions=[], in_channels=None, ch=in_channels,
831
+ ch_mult=[ch_mult for _ in range(num_blocks)])
832
+
833
+ def forward(self, x):
834
+ x = self.rescaler(x)
835
+ x = self.decoder(x)
836
+ return x
837
+
838
+
839
+ class Resize(nn.Module):
840
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
841
+ super().__init__()
842
+ self.with_conv = learned
843
+ self.mode = mode
844
+ if self.with_conv:
845
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
846
+ raise NotImplementedError()
847
+ assert in_channels is not None
848
+ # no asymmetric padding in torch conv, must do it ourselves
849
+ self.conv = torch.nn.Conv2d(in_channels,
850
+ in_channels,
851
+ kernel_size=4,
852
+ stride=2,
853
+ padding=1)
854
+
855
+ def forward(self, x, scale_factor=1.0):
856
+ if scale_factor==1.0:
857
+ return x
858
+ else:
859
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
860
+ return x
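Taken together, the Encoder and Decoder above form the convolutional autoencoder of the latent-diffusion stack: the Encoder downsamples by 2**(len(ch_mult)-1) and emits 2*z_channels moments (mean and logvar) when double_z=True, and the Decoder upsamples a z_channels latent back to pixel space. A small shape sketch under assumed toy hyperparameters (not the checkpoint's actual config; the mid-block attention falls back to AttnBlock when xformers is unavailable):

import torch
from ldm.modules.diffusionmodules.model import Encoder, Decoder

cfg = dict(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
           attn_resolutions=[], dropout=0.0, in_channels=3,
           resolution=64, z_channels=4)
enc = Encoder(double_z=True, **cfg)
dec = Decoder(**cfg)

img = torch.randn(1, 3, 64, 64)
moments = enc(img)       # (1, 2*z_channels, 16, 16): mean and logvar stacked on channels
z = moments[:, :4]       # take the mean half as a stand-in latent
rec = dec(z)             # (1, 3, 64, 64)
print(moments.shape, rec.shape)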
ldm/modules/diffusionmodules/model_back.py ADDED
@@ -0,0 +1,815 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+
7
+ def get_timestep_embedding(timesteps, embedding_dim):
8
+ """
9
+ Build sinusoidal timestep embeddings (adapted from Fairseq).
10
+ This matches the implementation in Denoising Diffusion Probabilistic Models
11
+ and in tensor2tensor, but differs slightly from the description in
12
+ Section 3.5 of "Attention Is All You Need".
13
+ Returns a tensor of shape (len(timesteps), embedding_dim).
14
+ """
15
+ assert len(timesteps.shape) == 1
16
+
17
+ half_dim = embedding_dim // 2
18
+ emb = math.log(10000) / (half_dim - 1)
19
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
20
+ emb = emb.to(device=timesteps.device)
21
+ emb = timesteps.float()[:, None] * emb[None, :]
22
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
23
+ if embedding_dim % 2 == 1: # zero pad
24
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
25
+ return emb
26
+
27
+
28
+ def nonlinearity(x):
29
+ # swish
30
+ return x*torch.sigmoid(x)
31
+
32
+
33
+ def Normalize(in_channels):
34
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
35
+
36
+
37
+ class Upsample(nn.Module):
38
+ def __init__(self, in_channels, with_conv, padding_mode):
39
+ super().__init__()
40
+ self.with_conv = with_conv
41
+ if self.with_conv:
42
+ self.conv = torch.nn.Conv2d(in_channels,
43
+ in_channels,
44
+ kernel_size=3,
45
+ stride=1,
46
+ padding=1,
47
+ padding_mode=padding_mode)
48
+
49
+ def forward(self, x):
50
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
51
+ if self.with_conv:
52
+ x = self.conv(x)
53
+ return x
54
+
55
+
56
+ class Downsample(nn.Module):
57
+ def __init__(self, in_channels, with_conv):
58
+ super().__init__()
59
+ self.with_conv = with_conv
60
+ if self.with_conv:
61
+ # no asymmetric padding in torch conv, must do it ourselves
62
+ self.conv = torch.nn.Conv2d(in_channels,
63
+ in_channels,
64
+ kernel_size=3,
65
+ stride=2,
66
+ padding=0)
67
+
68
+ def forward(self, x):
69
+ if self.with_conv:
70
+ pad = (0,1,0,1)
71
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
72
+ x = self.conv(x)
73
+ else:
74
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
75
+ return x
76
+
77
+
78
+ class ResnetBlock(nn.Module):
79
+ def __init__(self, *, in_channels, padding_mode, out_channels=None, conv_shortcut=False,
80
+ dropout, temb_channels=512):
81
+ super().__init__()
82
+ self.in_channels = in_channels
83
+ out_channels = in_channels if out_channels is None else out_channels
84
+ self.out_channels = out_channels
85
+ self.use_conv_shortcut = conv_shortcut
86
+
87
+ self.norm1 = Normalize(in_channels)
88
+ self.conv1 = torch.nn.Conv2d(in_channels,
89
+ out_channels,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1,
93
+ padding_mode=padding_mode)
94
+ if temb_channels > 0:
95
+ self.temb_proj = torch.nn.Linear(temb_channels,
96
+ out_channels)
97
+ self.norm2 = Normalize(out_channels)
98
+ self.dropout = torch.nn.Dropout(dropout)
99
+ self.conv2 = torch.nn.Conv2d(out_channels,
100
+ out_channels,
101
+ kernel_size=3,
102
+ stride=1,
103
+ padding=1,
104
+ padding_mode=padding_mode)
105
+ if self.in_channels != self.out_channels:
106
+ if self.use_conv_shortcut:
107
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
108
+ out_channels,
109
+ kernel_size=3,
110
+ stride=1,
111
+ padding=1,
112
+ padding_mode=padding_mode)
113
+ else:
114
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
115
+ out_channels,
116
+ kernel_size=1,
117
+ stride=1,
118
+ padding=0)
119
+
120
+ def forward(self, x, temb):
121
+ h = x
122
+ h = self.norm1(h)
123
+ h = nonlinearity(h)
124
+ h = self.conv1(h)
125
+
126
+ if temb is not None:
127
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
128
+
129
+ h = self.norm2(h)
130
+ h = nonlinearity(h)
131
+ h = self.dropout(h)
132
+ h = self.conv2(h)
133
+
134
+ if self.in_channels != self.out_channels:
135
+ if self.use_conv_shortcut:
136
+ x = self.conv_shortcut(x)
137
+ else:
138
+ x = self.nin_shortcut(x)
139
+
140
+ return x+h
141
+
142
+
143
+ class AttnBlock(nn.Module):
144
+ def __init__(self, in_channels):
145
+ super().__init__()
146
+ self.in_channels = in_channels
147
+
148
+ self.norm = Normalize(in_channels)
149
+ self.q = torch.nn.Conv2d(in_channels,
150
+ in_channels,
151
+ kernel_size=1,
152
+ stride=1,
153
+ padding=0)
154
+ self.k = torch.nn.Conv2d(in_channels,
155
+ in_channels,
156
+ kernel_size=1,
157
+ stride=1,
158
+ padding=0)
159
+ self.v = torch.nn.Conv2d(in_channels,
160
+ in_channels,
161
+ kernel_size=1,
162
+ stride=1,
163
+ padding=0)
164
+ self.proj_out = torch.nn.Conv2d(in_channels,
165
+ in_channels,
166
+ kernel_size=1,
167
+ stride=1,
168
+ padding=0)
169
+
170
+
171
+ def forward(self, x):
172
+ h_ = x
173
+ h_ = self.norm(h_)
174
+ q = self.q(h_)
175
+ k = self.k(h_)
176
+ v = self.v(h_)
177
+
178
+ # compute attention
179
+ b,c,h,w = q.shape
180
+ q = q.reshape(b,c,h*w)
181
+ q = q.permute(0,2,1) # b,hw,c
182
+ k = k.reshape(b,c,h*w) # b,c,hw
183
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
184
+ w_ = w_ * (int(c)**(-0.5))
185
+ w_ = torch.nn.functional.softmax(w_, dim=2)
186
+
187
+ # attend to values
188
+ v = v.reshape(b,c,h*w)
189
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
190
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
191
+ h_ = h_.reshape(b,c,h,w)
192
+
193
+ h_ = self.proj_out(h_)
194
+
195
+ return x+h_
196
+
197
+
198
+ class Model(nn.Module):
199
+ def __init__(self, *, ch, out_ch, padding_mode, ch_mult=(1,2,4,8), num_res_blocks,
200
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
201
+ resolution, use_timestep=True):
202
+ super().__init__()
203
+ self.ch = ch
204
+ self.temb_ch = self.ch*4
205
+ self.num_resolutions = len(ch_mult)
206
+ self.num_res_blocks = num_res_blocks
207
+ self.resolution = resolution
208
+ self.in_channels = in_channels
209
+
210
+ self.use_timestep = use_timestep
211
+ if self.use_timestep:
212
+ # timestep embedding
213
+ self.temb = nn.Module()
214
+ self.temb.dense = nn.ModuleList([
215
+ torch.nn.Linear(self.ch,
216
+ self.temb_ch),
217
+ torch.nn.Linear(self.temb_ch,
218
+ self.temb_ch),
219
+ ])
220
+
221
+ # downsampling
222
+ self.conv_in = torch.nn.Conv2d(in_channels,
223
+ self.ch,
224
+ kernel_size=3,
225
+ stride=1,
226
+ padding=1,
227
+ padding_mode=padding_mode)
228
+
229
+ curr_res = resolution
230
+ in_ch_mult = (1,)+tuple(ch_mult)
231
+ self.down = nn.ModuleList()
232
+ for i_level in range(self.num_resolutions):
233
+ block = nn.ModuleList()
234
+ attn = nn.ModuleList()
235
+ block_in = ch*in_ch_mult[i_level]
236
+ block_out = ch*ch_mult[i_level]
237
+ for i_block in range(self.num_res_blocks):
238
+ block.append(ResnetBlock(in_channels=block_in,
239
+ padding_mode=padding_mode,
240
+ out_channels=block_out,
241
+ temb_channels=self.temb_ch,
242
+ dropout=dropout))
243
+ block_in = block_out
244
+ if curr_res in attn_resolutions:
245
+ attn.append(AttnBlock(block_in))
246
+ down = nn.Module()
247
+ down.block = block
248
+ down.attn = attn
249
+ if i_level != self.num_resolutions-1:
250
+ down.downsample = Downsample(block_in, resamp_with_conv)
251
+ curr_res = curr_res // 2
252
+ self.down.append(down)
253
+
254
+ # middle
255
+ self.mid = nn.Module()
256
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
257
+ padding_mode=padding_mode,
258
+ out_channels=block_in,
259
+ temb_channels=self.temb_ch,
260
+ dropout=dropout)
261
+ self.mid.attn_1 = AttnBlock(block_in)
262
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
263
+ out_channels=block_in,
264
+ padding_mode=padding_mode,
265
+ temb_channels=self.temb_ch,
266
+ dropout=dropout)
267
+
268
+ # upsampling
269
+ self.up = nn.ModuleList()
270
+ for i_level in reversed(range(self.num_resolutions)):
271
+ block = nn.ModuleList()
272
+ attn = nn.ModuleList()
273
+ block_out = ch*ch_mult[i_level]
274
+ skip_in = ch*ch_mult[i_level]
275
+ for i_block in range(self.num_res_blocks+1):
276
+ if i_block == self.num_res_blocks:
277
+ skip_in = ch*in_ch_mult[i_level]
278
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
279
+ padding_mode=padding_mode,
280
+ out_channels=block_out,
281
+ temb_channels=self.temb_ch,
282
+ dropout=dropout))
283
+ block_in = block_out
284
+ if curr_res in attn_resolutions:
285
+ attn.append(AttnBlock(block_in))
286
+ up = nn.Module()
287
+ up.block = block
288
+ up.attn = attn
289
+ if i_level != 0:
290
+ up.upsample = Upsample(block_in, resamp_with_conv, padding_mode)
291
+ curr_res = curr_res * 2
292
+ self.up.insert(0, up) # prepend to get consistent order
293
+
294
+ # end
295
+ self.norm_out = Normalize(block_in)
296
+ self.conv_out = torch.nn.Conv2d(block_in,
297
+ out_ch,
298
+ kernel_size=3,
299
+ stride=1,
300
+ padding=1,
301
+ padding_mode=padding_mode)
302
+
303
+
304
+ def forward(self, x, t=None):
305
+ #assert x.shape[2] == x.shape[3] == self.resolution
306
+
307
+ if self.use_timestep:
308
+ # timestep embedding
309
+ assert t is not None
310
+ temb = get_timestep_embedding(t, self.ch)
311
+ temb = self.temb.dense[0](temb)
312
+ temb = nonlinearity(temb)
313
+ temb = self.temb.dense[1](temb)
314
+ else:
315
+ temb = None
316
+
317
+ # downsampling
318
+ hs = [self.conv_in(x)]
319
+ for i_level in range(self.num_resolutions):
320
+ for i_block in range(self.num_res_blocks):
321
+ h = self.down[i_level].block[i_block](hs[-1], temb)
322
+ if len(self.down[i_level].attn) > 0:
323
+ h = self.down[i_level].attn[i_block](h)
324
+ hs.append(h)
325
+ if i_level != self.num_resolutions-1:
326
+ hs.append(self.down[i_level].downsample(hs[-1]))
327
+
328
+ # middle
329
+ h = hs[-1]
330
+ h = self.mid.block_1(h, temb)
331
+ h = self.mid.attn_1(h)
332
+ h = self.mid.block_2(h, temb)
333
+
334
+ # upsampling
335
+ for i_level in reversed(range(self.num_resolutions)):
336
+ for i_block in range(self.num_res_blocks+1):
337
+ h = self.up[i_level].block[i_block](
338
+ torch.cat([h, hs.pop()], dim=1), temb)
339
+ if len(self.up[i_level].attn) > 0:
340
+ h = self.up[i_level].attn[i_block](h)
341
+ if i_level != 0:
342
+ h = self.up[i_level].upsample(h)
343
+
344
+ # end
345
+ h = self.norm_out(h)
346
+ h = nonlinearity(h)
347
+ h = self.conv_out(h)
348
+ return h
349
+
350
+
351
+ class Encoder(nn.Module):
352
+ def __init__(self, *, ch, out_ch, padding_mode='zeros', ch_mult=(1,2,4,8), num_res_blocks,
353
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
354
+ resolution, z_channels, double_z=True, **ignore_kwargs):
355
+ super().__init__()
356
+ self.ch = ch
357
+ self.temb_ch = 0
358
+ self.num_resolutions = len(ch_mult)
359
+ self.resolution = resolution
360
+ self.in_channels = in_channels
361
+ if isinstance(num_res_blocks, int):
362
+ num_res_blocks = [num_res_blocks, ] * len(ch_mult)
363
+ else:
364
+ assert len(num_res_blocks) == len(ch_mult)
365
+ self.num_res_blocks = num_res_blocks
366
+
367
+ # downsampling
368
+ self.conv_in = torch.nn.Conv2d(in_channels,
369
+ self.ch,
370
+ kernel_size=3,
371
+ stride=1,
372
+ padding=1,
373
+ padding_mode=padding_mode)
374
+
375
+ curr_res = resolution
376
+ in_ch_mult = (1,)+tuple(ch_mult)
377
+ self.down = nn.ModuleList()
378
+ for i_level in range(self.num_resolutions):
379
+ block = nn.ModuleList()
380
+ attn = nn.ModuleList()
381
+ block_in = ch*in_ch_mult[i_level]
382
+ block_out = ch*ch_mult[i_level]
383
+ for i_block in range(self.num_res_blocks[i_level]):
384
+ block.append(ResnetBlock(in_channels=block_in,
385
+ padding_mode=padding_mode,
386
+ out_channels=block_out,
387
+ temb_channels=self.temb_ch,
388
+ dropout=dropout))
389
+ block_in = block_out
390
+ if curr_res in attn_resolutions:
391
+ attn.append(AttnBlock(block_in))
392
+ down = nn.Module()
393
+ down.block = block
394
+ down.attn = attn
395
+ if i_level != self.num_resolutions-1:
396
+ down.downsample = Downsample(block_in, resamp_with_conv)
397
+ curr_res = curr_res // 2
398
+ self.down.append(down)
399
+
400
+ # middle
401
+ self.mid = nn.Module()
402
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
403
+ padding_mode=padding_mode,
404
+ out_channels=block_in,
405
+ temb_channels=self.temb_ch,
406
+ dropout=dropout)
407
+ self.mid.attn_1 = AttnBlock(block_in)
408
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
409
+ padding_mode=padding_mode,
410
+ out_channels=block_in,
411
+ temb_channels=self.temb_ch,
412
+ dropout=dropout)
413
+
414
+ # end
415
+ self.norm_out = Normalize(block_in)
416
+ self.conv_out = torch.nn.Conv2d(block_in,
417
+ 2*z_channels if double_z else z_channels,
418
+ kernel_size=3,
419
+ stride=1,
420
+ padding=1,
421
+ padding_mode=padding_mode)
422
+
423
+
424
+ def forward(self, x):
425
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
426
+
427
+ # timestep embedding
428
+ temb = None
429
+
430
+ # downsampling
431
+ hs = [self.conv_in(x)]
432
+ for i_level in range(self.num_resolutions):
433
+ for i_block in range(self.num_res_blocks[i_level]):
434
+ h = self.down[i_level].block[i_block](hs[-1], temb)
435
+ if len(self.down[i_level].attn) > 0:
436
+ h = self.down[i_level].attn[i_block](h)
437
+ hs.append(h)
438
+ if i_level != self.num_resolutions-1:
439
+ hs.append(self.down[i_level].downsample(hs[-1]))
440
+
441
+ # middle
442
+ h = hs[-1]
443
+ h = self.mid.block_1(h, temb)
444
+ h = self.mid.attn_1(h)
445
+ h = self.mid.block_2(h, temb)
446
+
447
+ # end
448
+ h = self.norm_out(h)
449
+ h = nonlinearity(h)
450
+ h = self.conv_out(h)
451
+ return h
452
+
453
+
454
+ class Decoder(nn.Module):
455
+ def __init__(self, *, ch, out_ch, padding_mode='zeros', ch_mult=(1,2,4,8), num_res_blocks,
456
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
457
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
458
+ super().__init__()
459
+ self.ch = ch
460
+ self.temb_ch = 0
461
+ self.num_resolutions = len(ch_mult)
462
+ self.resolution = resolution
463
+ self.in_channels = in_channels
464
+ self.give_pre_end = give_pre_end
465
+ if isinstance(num_res_blocks, int):
466
+ num_res_blocks = [num_res_blocks, ] * len(ch_mult)
467
+ else:
468
+ assert len(num_res_blocks) == len(ch_mult)
469
+ self.num_res_blocks = num_res_blocks
470
+
471
+ # compute in_ch_mult, block_in and curr_res at lowest res
472
+ in_ch_mult = (1,)+tuple(ch_mult)
473
+ block_in = ch*ch_mult[self.num_resolutions-1]
474
+ curr_res = resolution // 2**(self.num_resolutions-1)
475
+ self.z_shape = (1,z_channels,curr_res,curr_res)
476
+ print("Working with z of shape {} = {} dimensions.".format(
477
+ self.z_shape, np.prod(self.z_shape)))
478
+
479
+ # z to block_in
480
+ self.conv_in = torch.nn.Conv2d(z_channels,
481
+ block_in,
482
+ kernel_size=3,
483
+ stride=1,
484
+ padding=1,
485
+ padding_mode=padding_mode)
486
+
487
+ # middle
488
+ self.mid = nn.Module()
489
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
490
+ padding_mode=padding_mode,
491
+ out_channels=block_in,
492
+ temb_channels=self.temb_ch,
493
+ dropout=dropout)
494
+ self.mid.attn_1 = AttnBlock(block_in)
495
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
496
+ padding_mode=padding_mode,
497
+ out_channels=block_in,
498
+ temb_channels=self.temb_ch,
499
+ dropout=dropout)
500
+
501
+ # upsampling
502
+ self.up = nn.ModuleList()
503
+ for i_level in reversed(range(self.num_resolutions)):
504
+ block = nn.ModuleList()
505
+ attn = nn.ModuleList()
506
+ block_out = ch*ch_mult[i_level]
507
+ for i_block in range(self.num_res_blocks[i_level]+1):
508
+ block.append(ResnetBlock(in_channels=block_in,
509
+ padding_mode=padding_mode,
510
+ out_channels=block_out,
511
+ temb_channels=self.temb_ch,
512
+ dropout=dropout))
513
+ block_in = block_out
514
+ if curr_res in attn_resolutions:
515
+ attn.append(AttnBlock(block_in))
516
+ up = nn.Module()
517
+ up.block = block
518
+ up.attn = attn
519
+ if i_level != 0:
520
+ up.upsample = Upsample(block_in, resamp_with_conv, padding_mode)
521
+ curr_res = curr_res * 2
522
+ self.up.insert(0, up) # prepend to get consistent order
523
+
524
+ # end
525
+ self.norm_out = Normalize(block_in)
526
+ self.conv_out = torch.nn.Conv2d(block_in,
527
+ out_ch,
528
+ kernel_size=3,
529
+ stride=1,
530
+ padding=1,
531
+ padding_mode=padding_mode)
532
+
533
+ def forward(self, z):
534
+ #assert z.shape[1:] == self.z_shape[1:]
535
+ self.last_z_shape = z.shape
536
+
537
+ # timestep embedding
538
+ temb = None
539
+
540
+ # z to block_in
541
+ h = self.conv_in(z)
542
+
543
+ # middle
544
+ h = self.mid.block_1(h, temb)
545
+ h = self.mid.attn_1(h)
546
+ h = self.mid.block_2(h, temb)
547
+
548
+ # upsampling
549
+ for i_level in reversed(range(self.num_resolutions)):
550
+ for i_block in range(self.num_res_blocks[i_level]+1):
551
+ h = self.up[i_level].block[i_block](h, temb)
552
+ if len(self.up[i_level].attn) > 0:
553
+ h = self.up[i_level].attn[i_block](h)
554
+ if i_level != 0:
555
+ h = self.up[i_level].upsample(h)
556
+
557
+ # end
558
+ if self.give_pre_end:
559
+ return h
560
+
561
+ h = self.norm_out(h)
562
+ h = nonlinearity(h)
563
+ h = self.conv_out(h)
564
+ return h
565
+
566
+
567
+ class VUNet(nn.Module):
568
+ def __init__(self, *, ch, out_ch, padding_mode, ch_mult=(1,2,4,8), num_res_blocks,
569
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
570
+ in_channels, c_channels,
571
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
572
+ super().__init__()
573
+ self.ch = ch
574
+ self.temb_ch = self.ch*4
575
+ self.num_resolutions = len(ch_mult)
576
+ self.num_res_blocks = num_res_blocks
577
+ self.resolution = resolution
578
+
579
+ self.use_timestep = use_timestep
580
+ if self.use_timestep:
581
+ # timestep embedding
582
+ self.temb = nn.Module()
583
+ self.temb.dense = nn.ModuleList([
584
+ torch.nn.Linear(self.ch,
585
+ self.temb_ch),
586
+ torch.nn.Linear(self.temb_ch,
587
+ self.temb_ch),
588
+ ])
589
+
590
+ # downsampling
591
+ self.conv_in = torch.nn.Conv2d(c_channels,
592
+ self.ch,
593
+ kernel_size=3,
594
+ stride=1,
595
+ padding=1,
596
+ padding_mode=padding_mode)
597
+
598
+ curr_res = resolution
599
+ in_ch_mult = (1,)+tuple(ch_mult)
600
+ self.down = nn.ModuleList()
601
+ for i_level in range(self.num_resolutions):
602
+ block = nn.ModuleList()
603
+ attn = nn.ModuleList()
604
+ block_in = ch*in_ch_mult[i_level]
605
+ block_out = ch*ch_mult[i_level]
606
+ for i_block in range(self.num_res_blocks):
607
+ block.append(ResnetBlock(in_channels=block_in,
608
+ out_channels=block_out,
609
+ padding_mode=padding_mode,
610
+ temb_channels=self.temb_ch,
611
+ dropout=dropout))
612
+ block_in = block_out
613
+ if curr_res in attn_resolutions:
614
+ attn.append(AttnBlock(block_in))
615
+ down = nn.Module()
616
+ down.block = block
617
+ down.attn = attn
618
+ if i_level != self.num_resolutions-1:
619
+ down.downsample = Downsample(block_in, resamp_with_conv)
620
+ curr_res = curr_res // 2
621
+ self.down.append(down)
622
+
623
+ self.z_in = torch.nn.Conv2d(z_channels,
624
+ block_in,
625
+ kernel_size=1,
626
+ stride=1,
627
+ padding=0)
628
+ # middle
629
+ self.mid = nn.Module()
630
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
631
+ out_channels=block_in,
632
+ padding_mode=padding_mode,
633
+ temb_channels=self.temb_ch,
634
+ dropout=dropout)
635
+ self.mid.attn_1 = AttnBlock(block_in)
636
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
637
+ out_channels=block_in,
638
+ padding_mode=padding_mode,
639
+ temb_channels=self.temb_ch,
640
+ dropout=dropout)
641
+
642
+ # upsampling
643
+ self.up = nn.ModuleList()
644
+ for i_level in reversed(range(self.num_resolutions)):
645
+ block = nn.ModuleList()
646
+ attn = nn.ModuleList()
647
+ block_out = ch*ch_mult[i_level]
648
+ skip_in = ch*ch_mult[i_level]
649
+ for i_block in range(self.num_res_blocks+1):
650
+ if i_block == self.num_res_blocks:
651
+ skip_in = ch*in_ch_mult[i_level]
652
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
653
+ out_channels=block_out,
654
+ padding_mode=padding_mode,
655
+ temb_channels=self.temb_ch,
656
+ dropout=dropout))
657
+ block_in = block_out
658
+ if curr_res in attn_resolutions:
659
+ attn.append(AttnBlock(block_in))
660
+ up = nn.Module()
661
+ up.block = block
662
+ up.attn = attn
663
+ if i_level != 0:
664
+ up.upsample = Upsample(block_in, resamp_with_conv, padding_mode)
665
+ curr_res = curr_res * 2
666
+ self.up.insert(0, up) # prepend to get consistent order
667
+
668
+ # end
669
+ self.norm_out = Normalize(block_in)
670
+ self.conv_out = torch.nn.Conv2d(block_in,
671
+ out_ch,
672
+ kernel_size=3,
673
+ stride=1,
674
+ padding=1,
675
+ padding_mode=padding_mode)
676
+
677
+
678
+ def forward(self, x, z, t=None):  # t is required only when use_timestep=True
679
+ #assert x.shape[2] == x.shape[3] == self.resolution
680
+
681
+ if self.use_timestep:
682
+ # timestep embedding
683
+ assert t is not None
684
+ temb = get_timestep_embedding(t, self.ch)
685
+ temb = self.temb.dense[0](temb)
686
+ temb = nonlinearity(temb)
687
+ temb = self.temb.dense[1](temb)
688
+ else:
689
+ temb = None
690
+
691
+ # downsampling
692
+ hs = [self.conv_in(x)]
693
+ for i_level in range(self.num_resolutions):
694
+ for i_block in range(self.num_res_blocks):
695
+ h = self.down[i_level].block[i_block](hs[-1], temb)
696
+ if len(self.down[i_level].attn) > 0:
697
+ h = self.down[i_level].attn[i_block](h)
698
+ hs.append(h)
699
+ if i_level != self.num_resolutions-1:
700
+ hs.append(self.down[i_level].downsample(hs[-1]))
701
+
702
+ # middle
703
+ h = hs[-1]
704
+ z = self.z_in(z)
705
+ h = torch.cat((h,z),dim=1)
706
+ h = self.mid.block_1(h, temb)
707
+ h = self.mid.attn_1(h)
708
+ h = self.mid.block_2(h, temb)
709
+
710
+ # upsampling
711
+ for i_level in reversed(range(self.num_resolutions)):
712
+ for i_block in range(self.num_res_blocks+1):
713
+ h = self.up[i_level].block[i_block](
714
+ torch.cat([h, hs.pop()], dim=1), temb)
715
+ if len(self.up[i_level].attn) > 0:
716
+ h = self.up[i_level].attn[i_block](h)
717
+ if i_level != 0:
718
+ h = self.up[i_level].upsample(h)
719
+
720
+ # end
721
+ h = self.norm_out(h)
722
+ h = nonlinearity(h)
723
+ h = self.conv_out(h)
724
+ return h
725
+
726
+
727
+ class SimpleDecoder(nn.Module):
728
+ def __init__(self, in_channels, out_channels, padding_mode, *args, **kwargs):
729
+ super().__init__()
730
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
731
+ ResnetBlock(in_channels=in_channels,
732
+ padding_mode=padding_mode,
733
+ out_channels=2 * in_channels,
734
+ temb_channels=0, dropout=0.0),
735
+ ResnetBlock(in_channels=2 * in_channels,
736
+ padding_mode=padding_mode,
737
+ out_channels=4 * in_channels,
738
+ temb_channels=0, dropout=0.0),
739
+ ResnetBlock(in_channels=4 * in_channels,
740
+ padding_mode=padding_mode,
741
+ out_channels=2 * in_channels,
742
+ temb_channels=0, dropout=0.0),
743
+ nn.Conv2d(2*in_channels, in_channels, 1),
744
+ Upsample(in_channels, with_conv=True, padding_mode=padding_mode)])
745
+ # end
746
+ self.norm_out = Normalize(in_channels)
747
+ self.conv_out = torch.nn.Conv2d(in_channels,
748
+ out_channels,
749
+ kernel_size=3,
750
+ stride=1,
751
+ padding=1,
752
+ padding_mode=padding_mode)
753
+
754
+ def forward(self, x):
755
+ for i, layer in enumerate(self.model):
756
+ if i in [1,2,3]:
757
+ x = layer(x, None)
758
+ else:
759
+ x = layer(x)
760
+
761
+ h = self.norm_out(x)
762
+ h = nonlinearity(h)
763
+ x = self.conv_out(h)
764
+ return x
765
+
766
+
767
+ class UpsampleDecoder(nn.Module):
768
+ def __init__(self, in_channels, out_channels, padding_mode, ch, num_res_blocks, resolution,
769
+ ch_mult=(2,2), dropout=0.0):
770
+ super().__init__()
771
+ # upsampling
772
+ self.temb_ch = 0
773
+ self.num_resolutions = len(ch_mult)
774
+ self.num_res_blocks = num_res_blocks
775
+ block_in = in_channels
776
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
777
+ self.res_blocks = nn.ModuleList()
778
+ self.upsample_blocks = nn.ModuleList()
779
+ for i_level in range(self.num_resolutions):
780
+ res_block = []
781
+ block_out = ch * ch_mult[i_level]
782
+ for i_block in range(self.num_res_blocks + 1):
783
+ res_block.append(ResnetBlock(in_channels=block_in,
784
+ out_channels=block_out,
785
+ padding_mode=padding_mode,
786
+ temb_channels=self.temb_ch,
787
+ dropout=dropout))
788
+ block_in = block_out
789
+ self.res_blocks.append(nn.ModuleList(res_block))
790
+ if i_level != self.num_resolutions - 1:
791
+ self.upsample_blocks.append(Upsample(block_in, True, padding_mode))
792
+ curr_res = curr_res * 2
793
+
794
+ # end
795
+ self.norm_out = Normalize(block_in)
796
+ self.conv_out = torch.nn.Conv2d(block_in,
797
+ out_channels,
798
+ kernel_size=3,
799
+ stride=1,
800
+ padding=1,
801
+ padding_mode=padding_mode)
802
+
803
+ def forward(self, x):
804
+ # upsampling
805
+ h = x
806
+ for k, i_level in enumerate(range(self.num_resolutions)):
807
+ for i_block in range(self.num_res_blocks + 1):
808
+ h = self.res_blocks[i_level][i_block](h, None)
809
+ if i_level != self.num_resolutions - 1:
810
+ h = self.upsample_blocks[k](h)
811
+ h = self.norm_out(h)
812
+ h = nonlinearity(h)
813
+ h = self.conv_out(h)
814
+ return h
815
+
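A minimal shape check for the decoders above, assuming `model_back.py` imports as `ldm.modules.diffusionmodules.model_back` and that its `ResnetBlock`/`Upsample` helpers accept the `padding_mode` argument used here; the channel counts are made up and do not come from a shipped config:

```python
import torch
from ldm.modules.diffusionmodules.model_back import UpsampleDecoder  # assumed import path

# With ch_mult=(2, 2) there is a single upsampling stage, so spatial size doubles.
decoder = UpsampleDecoder(in_channels=64, out_channels=3, padding_mode="zeros",
                          ch=32, num_res_blocks=2, resolution=64, ch_mult=(2, 2))
x = torch.randn(1, 64, 32, 32)      # latent-like feature map
with torch.no_grad():
    y = decoder(x)
print(y.shape)                      # expected: torch.Size([1, 3, 64, 64])
```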
ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,788 @@
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from ldm.modules.diffusionmodules.util import (
10
+ checkpoint,
11
+ conv_nd,
12
+ linear,
13
+ avg_pool_nd,
14
+ zero_module,
15
+ normalization,
16
+ timestep_embedding,
17
+ )
18
+ from ldm.modules.attention import SpatialTransformer
19
+ from ldm.util import exists
20
+
21
+
22
+ # dummy replace
23
+ def convert_module_to_f16(x):
24
+ pass
25
+
26
+ def convert_module_to_f32(x):
27
+ pass
28
+
29
+
30
+ ## go
31
+ class AttentionPool2d(nn.Module):
32
+ """
33
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ spacial_dim: int,
39
+ embed_dim: int,
40
+ num_heads_channels: int,
41
+ output_dim: int = None,
42
+ ):
43
+ super().__init__()
44
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
45
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
46
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
47
+ self.num_heads = embed_dim // num_heads_channels
48
+ self.attention = QKVAttention(self.num_heads)
49
+
50
+ def forward(self, x):
51
+ b, c, *_spatial = x.shape
52
+ x = x.reshape(b, c, -1) # NC(HW)
53
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
54
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
55
+ x = self.qkv_proj(x)
56
+ x = self.attention(x)
57
+ x = self.c_proj(x)
58
+ return x[:, :, 0]
59
+
60
+
61
+ class TimestepBlock(nn.Module):
62
+ """
63
+ Any module where forward() takes timestep embeddings as a second argument.
64
+ """
65
+
66
+ @abstractmethod
67
+ def forward(self, x, emb):
68
+ """
69
+ Apply the module to `x` given `emb` timestep embeddings.
70
+ """
71
+
72
+
73
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
74
+ """
75
+ A sequential module that passes timestep embeddings to the children that
76
+ support it as an extra input.
77
+ """
78
+
79
+ def forward(self, x, emb, context=None):
80
+ for layer in self:
81
+ if isinstance(layer, TimestepBlock):
82
+ x = layer(x, emb)
83
+ elif isinstance(layer, SpatialTransformer):
84
+ x = layer(x, context)
85
+ else:
86
+ x = layer(x)
87
+ return x
88
+
89
+
90
+ class Upsample(nn.Module):
91
+ """
92
+ An upsampling layer with an optional convolution.
93
+ :param channels: channels in the inputs and outputs.
94
+ :param use_conv: a bool determining if a convolution is applied.
95
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
96
+ upsampling occurs in the inner-two dimensions.
97
+ """
98
+
99
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
100
+ super().__init__()
101
+ self.channels = channels
102
+ self.out_channels = out_channels or channels
103
+ self.use_conv = use_conv
104
+ self.dims = dims
105
+ if use_conv:
106
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
107
+
108
+ def forward(self, x):
109
+ assert x.shape[1] == self.channels
110
+ if self.dims == 3:
111
+ x = F.interpolate(
112
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
113
+ )
114
+ else:
115
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
116
+ if self.use_conv:
117
+ x = self.conv(x)
118
+ return x
119
+
120
+ class TransposedUpsample(nn.Module):
121
+ 'Learned 2x upsampling without padding'
122
+ def __init__(self, channels, out_channels=None, ks=5):
123
+ super().__init__()
124
+ self.channels = channels
125
+ self.out_channels = out_channels or channels
126
+
127
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
128
+
129
+ def forward(self,x):
130
+ return self.up(x)
131
+
132
+
133
+ class Downsample(nn.Module):
134
+ """
135
+ A downsampling layer with an optional convolution.
136
+ :param channels: channels in the inputs and outputs.
137
+ :param use_conv: a bool determining if a convolution is applied.
138
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
139
+ downsampling occurs in the inner-two dimensions.
140
+ """
141
+
142
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
143
+ super().__init__()
144
+ self.channels = channels
145
+ self.out_channels = out_channels or channels
146
+ self.use_conv = use_conv
147
+ self.dims = dims
148
+ stride = 2 if dims != 3 else (1, 2, 2)
149
+ if use_conv:
150
+ self.op = conv_nd(
151
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
152
+ )
153
+ else:
154
+ assert self.channels == self.out_channels
155
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
156
+
157
+ def forward(self, x):
158
+ assert x.shape[1] == self.channels
159
+ return self.op(x)
160
+
161
+
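A tiny sanity check of the two resampling layers above (illustrative sizes only; the import path is assumed from the file location):

```python
import torch
from ldm.modules.diffusionmodules.openaimodel import Upsample, Downsample

x = torch.randn(1, 16, 32, 32)
up = Upsample(16, use_conv=True)      # nearest-neighbour 2x followed by a 3x3 conv
down = Downsample(16, use_conv=True)  # 3x3 conv with stride 2
print(up(x).shape)                    # torch.Size([1, 16, 64, 64])
print(down(x).shape)                  # torch.Size([1, 16, 16, 16])
```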
162
+ class ResBlock(TimestepBlock):
163
+ """
164
+ A residual block that can optionally change the number of channels.
165
+ :param channels: the number of input channels.
166
+ :param emb_channels: the number of timestep embedding channels.
167
+ :param dropout: the rate of dropout.
168
+ :param out_channels: if specified, the number of out channels.
169
+ :param use_conv: if True and out_channels is specified, use a spatial
170
+ convolution instead of a smaller 1x1 convolution to change the
171
+ channels in the skip connection.
172
+ :param dims: determines if the signal is 1D, 2D, or 3D.
173
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
174
+ :param up: if True, use this block for upsampling.
175
+ :param down: if True, use this block for downsampling.
176
+ """
177
+
178
+ def __init__(
179
+ self,
180
+ channels,
181
+ emb_channels,
182
+ dropout,
183
+ out_channels=None,
184
+ use_conv=False,
185
+ use_scale_shift_norm=False,
186
+ dims=2,
187
+ use_checkpoint=False,
188
+ up=False,
189
+ down=False,
190
+ ):
191
+ super().__init__()
192
+ self.channels = channels
193
+ self.emb_channels = emb_channels
194
+ self.dropout = dropout
195
+ self.out_channels = out_channels or channels
196
+ self.use_conv = use_conv
197
+ self.use_checkpoint = use_checkpoint
198
+ self.use_scale_shift_norm = use_scale_shift_norm
199
+
200
+ self.in_layers = nn.Sequential(
201
+ normalization(channels),
202
+ nn.SiLU(),
203
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
204
+ )
205
+
206
+ self.updown = up or down
207
+
208
+ if up:
209
+ self.h_upd = Upsample(channels, False, dims)
210
+ self.x_upd = Upsample(channels, False, dims)
211
+ elif down:
212
+ self.h_upd = Downsample(channels, False, dims)
213
+ self.x_upd = Downsample(channels, False, dims)
214
+ else:
215
+ self.h_upd = self.x_upd = nn.Identity()
216
+
217
+ self.emb_layers = nn.Sequential(
218
+ nn.SiLU(),
219
+ linear(
220
+ emb_channels,
221
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
222
+ ),
223
+ )
224
+ self.out_layers = nn.Sequential(
225
+ normalization(self.out_channels),
226
+ nn.SiLU(),
227
+ nn.Dropout(p=dropout),
228
+ zero_module(
229
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
230
+ ),
231
+ )
232
+
233
+ if self.out_channels == channels:
234
+ self.skip_connection = nn.Identity()
235
+ elif use_conv:
236
+ self.skip_connection = conv_nd(
237
+ dims, channels, self.out_channels, 3, padding=1
238
+ )
239
+ else:
240
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
241
+
242
+ def forward(self, x, emb):
243
+ """
244
+ Apply the block to a Tensor, conditioned on a timestep embedding.
245
+ :param x: an [N x C x ...] Tensor of features.
246
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
247
+ :return: an [N x C x ...] Tensor of outputs.
248
+ """
249
+ return checkpoint(
250
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
251
+ )
252
+
253
+
254
+ def _forward(self, x, emb):
255
+ if self.updown:
256
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
257
+ h = in_rest(x)
258
+ h = self.h_upd(h)
259
+ x = self.x_upd(x)
260
+ h = in_conv(h)
261
+ else:
262
+ h = self.in_layers(x)
263
+ emb_out = self.emb_layers(emb).type(h.dtype)
264
+ while len(emb_out.shape) < len(h.shape):
265
+ emb_out = emb_out[..., None]
266
+ if self.use_scale_shift_norm:
267
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
268
+ scale, shift = th.chunk(emb_out, 2, dim=1)
269
+ h = out_norm(h) * (1 + scale) + shift
270
+ h = out_rest(h)
271
+ else:
272
+ h = h + emb_out
273
+ h = self.out_layers(h)
274
+ return self.skip_connection(x) + h
275
+
276
+
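The scale-shift path above is the FiLM-style conditioning mentioned in the docstring: the timestep embedding is projected to a (scale, shift) pair and applied after normalization. A hedged shape check with arbitrary sizes:

```python
import torch
from ldm.modules.diffusionmodules.openaimodel import ResBlock

block = ResBlock(channels=64, emb_channels=256, dropout=0.0,
                 out_channels=128, use_scale_shift_norm=True)
x = torch.randn(2, 64, 32, 32)    # [N, C, H, W] features
emb = torch.randn(2, 256)         # [N, emb_channels] timestep embedding
out = block(x, emb)               # 1x1 skip convolution maps 64 -> 128 channels
print(out.shape)                  # expected: torch.Size([2, 128, 32, 32])
```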
277
+ class AttentionBlock(nn.Module):
278
+ """
279
+ An attention block that allows spatial positions to attend to each other.
280
+ Originally ported from here, but adapted to the N-d case.
281
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
282
+ """
283
+
284
+ def __init__(
285
+ self,
286
+ channels,
287
+ num_heads=1,
288
+ num_head_channels=-1,
289
+ use_checkpoint=False,
290
+ use_new_attention_order=False,
291
+ ):
292
+ super().__init__()
293
+ self.channels = channels
294
+ if num_head_channels == -1:
295
+ self.num_heads = num_heads
296
+ else:
297
+ assert (
298
+ channels % num_head_channels == 0
299
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
300
+ self.num_heads = channels // num_head_channels
301
+ self.use_checkpoint = use_checkpoint
302
+ self.norm = normalization(channels)
303
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
304
+ if use_new_attention_order:
305
+ # split qkv before split heads
306
+ self.attention = QKVAttention(self.num_heads)
307
+ else:
308
+ # split heads before split qkv
309
+ self.attention = QKVAttentionLegacy(self.num_heads)
310
+
311
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
312
+
313
+ def forward(self, x):
314
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
315
+ #return pt_checkpoint(self._forward, x) # pytorch
316
+
317
+ def _forward(self, x):
318
+ b, c, *spatial = x.shape
319
+ x = x.reshape(b, c, -1)
320
+ qkv = self.qkv(self.norm(x))
321
+ h = self.attention(qkv)
322
+ h = self.proj_out(h)
323
+ return (x + h).reshape(b, c, *spatial)
324
+
325
+
326
+ def count_flops_attn(model, _x, y):
327
+ """
328
+ A counter for the `thop` package to count the operations in an
329
+ attention operation.
330
+ Meant to be used like:
331
+ macs, params = thop.profile(
332
+ model,
333
+ inputs=(inputs, timestamps),
334
+ custom_ops={QKVAttention: QKVAttention.count_flops},
335
+ )
336
+ """
337
+ b, c, *spatial = y[0].shape
338
+ num_spatial = int(np.prod(spatial))
339
+ # We perform two matmuls with the same number of ops.
340
+ # The first computes the weight matrix, the second computes
341
+ # the combination of the value vectors.
342
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
343
+ model.total_ops += th.DoubleTensor([matmul_ops])
344
+
345
+
346
+ class QKVAttentionLegacy(nn.Module):
347
+ """
348
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping

349
+ """
350
+
351
+ def __init__(self, n_heads):
352
+ super().__init__()
353
+ self.n_heads = n_heads
354
+
355
+ def forward(self, qkv):
356
+ """
357
+ Apply QKV attention.
358
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
359
+ :return: an [N x (H * C) x T] tensor after attention.
360
+ """
361
+ bs, width, length = qkv.shape
362
+ assert width % (3 * self.n_heads) == 0
363
+ ch = width // (3 * self.n_heads)
364
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
365
+ scale = 1 / math.sqrt(math.sqrt(ch))
366
+ weight = th.einsum(
367
+ "bct,bcs->bts", q * scale, k * scale
368
+ ) # More stable with f16 than dividing afterwards
369
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
370
+ a = th.einsum("bts,bcs->bct", weight, v)
371
+ return a.reshape(bs, -1, length)
372
+
373
+ @staticmethod
374
+ def count_flops(model, _x, y):
375
+ return count_flops_attn(model, _x, y)
376
+
377
+
378
+ class QKVAttention(nn.Module):
379
+ """
380
+ A module which performs QKV attention and splits in a different order.
381
+ """
382
+
383
+ def __init__(self, n_heads):
384
+ super().__init__()
385
+ self.n_heads = n_heads
386
+
387
+ def forward(self, qkv):
388
+ """
389
+ Apply QKV attention.
390
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
391
+ :return: an [N x (H * C) x T] tensor after attention.
392
+ """
393
+ bs, width, length = qkv.shape
394
+ assert width % (3 * self.n_heads) == 0
395
+ ch = width // (3 * self.n_heads)
396
+ q, k, v = qkv.chunk(3, dim=1)
397
+ scale = 1 / math.sqrt(math.sqrt(ch))
398
+ weight = th.einsum(
399
+ "bct,bcs->bts",
400
+ (q * scale).view(bs * self.n_heads, ch, length),
401
+ (k * scale).view(bs * self.n_heads, ch, length),
402
+ ) # More stable with f16 than dividing afterwards
403
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
404
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
405
+ return a.reshape(bs, -1, length)
406
+
407
+ @staticmethod
408
+ def count_flops(model, _x, y):
409
+ return count_flops_attn(model, _x, y)
410
+
411
+
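Both attention variants consume a fused qkv tensor of shape [N, 3·H·C, T] and return [N, H·C, T]; they differ only in where the head split happens. Illustrative sizes:

```python
import torch
from ldm.modules.diffusionmodules.openaimodel import QKVAttention, QKVAttentionLegacy

bs, n_heads, ch, length = 2, 4, 16, 64
qkv = torch.randn(bs, 3 * n_heads * ch, length)

out_new = QKVAttention(n_heads)(qkv)           # splits qkv first, then heads
out_legacy = QKVAttentionLegacy(n_heads)(qkv)  # splits heads first, then qkv
print(out_new.shape, out_legacy.shape)         # both torch.Size([2, 64, 64])
```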
412
+ class UNetModel(nn.Module):
413
+ """
414
+ The full UNet model with attention and timestep embedding.
415
+ :param in_channels: channels in the input Tensor.
416
+ :param model_channels: base channel count for the model.
417
+ :param out_channels: channels in the output Tensor.
418
+ :param num_res_blocks: number of residual blocks per downsample.
419
+ :param attention_resolutions: a collection of downsample rates at which
420
+ attention will take place. May be a set, list, or tuple.
421
+ For example, if this contains 4, then at 4x downsampling, attention
422
+ will be used.
423
+ :param dropout: the dropout probability.
424
+ :param channel_mult: channel multiplier for each level of the UNet.
425
+ :param conv_resample: if True, use learned convolutions for upsampling and
426
+ downsampling.
427
+ :param dims: determines if the signal is 1D, 2D, or 3D.
428
+ :param num_classes: if specified (as an int), then this model will be
429
+ class-conditional with `num_classes` classes.
430
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
431
+ :param num_heads: the number of attention heads in each attention layer.
432
+ :param num_head_channels: if specified, ignore num_heads and instead use
433
+ a fixed channel width per attention head.
434
+ :param num_heads_upsample: works with num_heads to set a different number
435
+ of heads for upsampling. Deprecated.
436
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
437
+ :param resblock_updown: use residual blocks for up/downsampling.
438
+ :param use_new_attention_order: use a different attention pattern for potentially
439
+ increased efficiency.
440
+ """
441
+
442
+ def __init__(
443
+ self,
444
+ image_size,
445
+ in_channels,
446
+ model_channels,
447
+ out_channels,
448
+ num_res_blocks,
449
+ attention_resolutions,
450
+ dropout=0,
451
+ channel_mult=(1, 2, 4, 8),
452
+ conv_resample=True,
453
+ dims=2,
454
+ num_classes=None,
455
+ use_checkpoint=False,
456
+ use_fp16=False,
457
+ use_bf16=False,
458
+ num_heads=-1,
459
+ num_head_channels=-1,
460
+ num_heads_upsample=-1,
461
+ use_scale_shift_norm=False,
462
+ resblock_updown=False,
463
+ use_new_attention_order=False,
464
+ use_spatial_transformer=False, # custom transformer support
465
+ transformer_depth=1, # custom transformer support
466
+ context_dim=None, # custom transformer support
467
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
468
+ legacy=True,
469
+ disable_self_attentions=None,
470
+ num_attention_blocks=None,
471
+ disable_middle_self_attn=False,
472
+ use_linear_in_transformer=False,
473
+ ):
474
+ super().__init__()
475
+ if use_spatial_transformer:
476
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
477
+
478
+ if context_dim is not None:
479
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
480
+ from omegaconf.listconfig import ListConfig
481
+ if type(context_dim) == ListConfig:
482
+ context_dim = list(context_dim)
483
+
484
+ if num_heads_upsample == -1:
485
+ num_heads_upsample = num_heads
486
+
487
+ if num_heads == -1:
488
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
489
+
490
+ if num_head_channels == -1:
491
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
492
+
493
+ self.image_size = image_size
494
+ self.in_channels = in_channels
495
+ self.model_channels = model_channels
496
+ self.out_channels = out_channels
497
+ if isinstance(num_res_blocks, int):
498
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
499
+ else:
500
+ if len(num_res_blocks) != len(channel_mult):
501
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
502
+ "as a list/tuple (per-level) with the same length as channel_mult")
503
+ self.num_res_blocks = num_res_blocks
504
+ if disable_self_attentions is not None:
505
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
506
+ assert len(disable_self_attentions) == len(channel_mult)
507
+ if num_attention_blocks is not None:
508
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
509
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
510
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
511
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
512
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
513
+ f"attention will still not be set.")
514
+
515
+ self.attention_resolutions = attention_resolutions
516
+ self.dropout = dropout
517
+ self.channel_mult = channel_mult
518
+ self.conv_resample = conv_resample
519
+ self.num_classes = num_classes
520
+ self.use_checkpoint = use_checkpoint
521
+ self.dtype = th.float16 if use_fp16 else th.float32
522
+ self.dtype = th.bfloat16 if use_bf16 else self.dtype
523
+ self.num_heads = num_heads
524
+ self.num_head_channels = num_head_channels
525
+ self.num_heads_upsample = num_heads_upsample
526
+ self.predict_codebook_ids = n_embed is not None
527
+
528
+ time_embed_dim = model_channels * 4
529
+ self.time_embed = nn.Sequential(
530
+ linear(model_channels, time_embed_dim),
531
+ nn.SiLU(),
532
+ linear(time_embed_dim, time_embed_dim),
533
+ )
534
+
535
+ if self.num_classes is not None:
536
+ if isinstance(self.num_classes, int):
537
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
538
+ elif self.num_classes == "continuous":
539
+ print("setting up linear c_adm embedding layer")
540
+ self.label_emb = nn.Linear(1, time_embed_dim)
541
+ else:
542
+ raise ValueError()
543
+
544
+ self.input_blocks = nn.ModuleList(
545
+ [
546
+ TimestepEmbedSequential(
547
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
548
+ )
549
+ ]
550
+ )
551
+ self._feature_size = model_channels
552
+ input_block_chans = [model_channels]
553
+ ch = model_channels
554
+ ds = 1
555
+ for level, mult in enumerate(channel_mult):
556
+ for nr in range(self.num_res_blocks[level]):
557
+ layers = [
558
+ ResBlock(
559
+ ch,
560
+ time_embed_dim,
561
+ dropout,
562
+ out_channels=mult * model_channels,
563
+ dims=dims,
564
+ use_checkpoint=use_checkpoint,
565
+ use_scale_shift_norm=use_scale_shift_norm,
566
+ )
567
+ ]
568
+ ch = mult * model_channels
569
+ if ds in attention_resolutions:
570
+ if num_head_channels == -1:
571
+ dim_head = ch // num_heads
572
+ else:
573
+ num_heads = ch // num_head_channels
574
+ dim_head = num_head_channels
575
+ if legacy:
576
+ #num_heads = 1
577
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
578
+ if exists(disable_self_attentions):
579
+ disabled_sa = disable_self_attentions[level]
580
+ else:
581
+ disabled_sa = False
582
+
583
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
584
+ layers.append(
585
+ AttentionBlock(
586
+ ch,
587
+ use_checkpoint=use_checkpoint,
588
+ num_heads=num_heads,
589
+ num_head_channels=dim_head,
590
+ use_new_attention_order=use_new_attention_order,
591
+ ) if not use_spatial_transformer else SpatialTransformer(
592
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
593
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
594
+ use_checkpoint=use_checkpoint
595
+ )
596
+ )
597
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
598
+ self._feature_size += ch
599
+ input_block_chans.append(ch)
600
+ if level != len(channel_mult) - 1:
601
+ out_ch = ch
602
+ self.input_blocks.append(
603
+ TimestepEmbedSequential(
604
+ ResBlock(
605
+ ch,
606
+ time_embed_dim,
607
+ dropout,
608
+ out_channels=out_ch,
609
+ dims=dims,
610
+ use_checkpoint=use_checkpoint,
611
+ use_scale_shift_norm=use_scale_shift_norm,
612
+ down=True,
613
+ )
614
+ if resblock_updown
615
+ else Downsample(
616
+ ch, conv_resample, dims=dims, out_channels=out_ch
617
+ )
618
+ )
619
+ )
620
+ ch = out_ch
621
+ input_block_chans.append(ch)
622
+ ds *= 2
623
+ self._feature_size += ch
624
+
625
+ if num_head_channels == -1:
626
+ dim_head = ch // num_heads
627
+ else:
628
+ num_heads = ch // num_head_channels
629
+ dim_head = num_head_channels
630
+ if legacy:
631
+ #num_heads = 1
632
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
633
+ self.middle_block = TimestepEmbedSequential(
634
+ ResBlock(
635
+ ch,
636
+ time_embed_dim,
637
+ dropout,
638
+ dims=dims,
639
+ use_checkpoint=use_checkpoint,
640
+ use_scale_shift_norm=use_scale_shift_norm,
641
+ ),
642
+ AttentionBlock(
643
+ ch,
644
+ use_checkpoint=use_checkpoint,
645
+ num_heads=num_heads,
646
+ num_head_channels=dim_head,
647
+ use_new_attention_order=use_new_attention_order,
648
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
649
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
650
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
651
+ use_checkpoint=use_checkpoint
652
+ ),
653
+ ResBlock(
654
+ ch,
655
+ time_embed_dim,
656
+ dropout,
657
+ dims=dims,
658
+ use_checkpoint=use_checkpoint,
659
+ use_scale_shift_norm=use_scale_shift_norm,
660
+ ),
661
+ )
662
+ self._feature_size += ch
663
+
664
+ self.output_blocks = nn.ModuleList([])
665
+ for level, mult in list(enumerate(channel_mult))[::-1]:
666
+ for i in range(self.num_res_blocks[level] + 1):
667
+ ich = input_block_chans.pop()
668
+ layers = [
669
+ ResBlock(
670
+ ch + ich,
671
+ time_embed_dim,
672
+ dropout,
673
+ out_channels=model_channels * mult,
674
+ dims=dims,
675
+ use_checkpoint=use_checkpoint,
676
+ use_scale_shift_norm=use_scale_shift_norm,
677
+ )
678
+ ]
679
+ ch = model_channels * mult
680
+ if ds in attention_resolutions:
681
+ if num_head_channels == -1:
682
+ dim_head = ch // num_heads
683
+ else:
684
+ num_heads = ch // num_head_channels
685
+ dim_head = num_head_channels
686
+ if legacy:
687
+ #num_heads = 1
688
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
689
+ if exists(disable_self_attentions):
690
+ disabled_sa = disable_self_attentions[level]
691
+ else:
692
+ disabled_sa = False
693
+
694
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
695
+ layers.append(
696
+ AttentionBlock(
697
+ ch,
698
+ use_checkpoint=use_checkpoint,
699
+ num_heads=num_heads_upsample,
700
+ num_head_channels=dim_head,
701
+ use_new_attention_order=use_new_attention_order,
702
+ ) if not use_spatial_transformer else SpatialTransformer(
703
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
704
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
705
+ use_checkpoint=use_checkpoint
706
+ )
707
+ )
708
+ if level and i == self.num_res_blocks[level]:
709
+ out_ch = ch
710
+ layers.append(
711
+ ResBlock(
712
+ ch,
713
+ time_embed_dim,
714
+ dropout,
715
+ out_channels=out_ch,
716
+ dims=dims,
717
+ use_checkpoint=use_checkpoint,
718
+ use_scale_shift_norm=use_scale_shift_norm,
719
+ up=True,
720
+ )
721
+ if resblock_updown
722
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
723
+ )
724
+ ds //= 2
725
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
726
+ self._feature_size += ch
727
+
728
+ self.out = nn.Sequential(
729
+ normalization(ch),
730
+ nn.SiLU(),
731
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
732
+ )
733
+ if self.predict_codebook_ids:
734
+ self.id_predictor = nn.Sequential(
735
+ normalization(ch),
736
+ conv_nd(dims, model_channels, n_embed, 1),
737
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
738
+ )
739
+
740
+ def convert_to_fp16(self):
741
+ """
742
+ Convert the torso of the model to float16.
743
+ """
744
+ self.input_blocks.apply(convert_module_to_f16)
745
+ self.middle_block.apply(convert_module_to_f16)
746
+ self.output_blocks.apply(convert_module_to_f16)
747
+
748
+ def convert_to_fp32(self):
749
+ """
750
+ Convert the torso of the model to float32.
751
+ """
752
+ self.input_blocks.apply(convert_module_to_f32)
753
+ self.middle_block.apply(convert_module_to_f32)
754
+ self.output_blocks.apply(convert_module_to_f32)
755
+
756
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
757
+ """
758
+ Apply the model to an input batch.
759
+ :param x: an [N x C x ...] Tensor of inputs.
760
+ :param timesteps: a 1-D batch of timesteps.
761
+ :param context: conditioning plugged in via crossattn
762
+ :param y: an [N] Tensor of labels, if class-conditional.
763
+ :return: an [N x C x ...] Tensor of outputs.
764
+ """
765
+ assert (y is not None) == (
766
+ self.num_classes is not None
767
+ ), "must specify y if and only if the model is class-conditional"
768
+ hs = []
769
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
770
+ emb = self.time_embed(t_emb)
771
+
772
+ if self.num_classes is not None:
773
+ assert y.shape[0] == x.shape[0]
774
+ emb = emb + self.label_emb(y)
775
+
776
+ h = x.type(self.dtype)
777
+ for module in self.input_blocks:
778
+ h = module(h, emb, context)
779
+ hs.append(h)
780
+ h = self.middle_block(h, emb, context)
781
+ for module in self.output_blocks:
782
+ h = th.cat([h, hs.pop()], dim=1)
783
+ h = module(h, emb, context)
784
+ h = h.type(x.dtype)
785
+ if self.predict_codebook_ids:
786
+ return self.id_predictor(h)
787
+ else:
788
+ return self.out(h)
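A minimal, hypothetical instantiation of the UNet above, useful as a smoke test. The hyperparameters are illustrative and are not the ones used by the ResShift checkpoints shipped with this Space:

```python
import torch
from ldm.modules.diffusionmodules.openaimodel import UNetModel

unet = UNetModel(image_size=64, in_channels=3, model_channels=32, out_channels=3,
                 num_res_blocks=1, attention_resolutions=(4,),
                 channel_mult=(1, 2, 4), num_heads=4)
x = torch.randn(1, 3, 64, 64)
t = torch.tensor([10])            # one timestep index per batch element
with torch.no_grad():
    eps = unet(x, timesteps=t)    # output keeps the input's spatial size
print(eps.shape)                  # expected: torch.Size([1, 3, 64, 64])
```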
ldm/modules/diffusionmodules/upscaling.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from functools import partial
5
+
6
+ from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7
+ from ldm.util import default
8
+
9
+
10
+ class AbstractLowScaleModel(nn.Module):
11
+ # for concatenating a downsampled image to the latent representation
12
+ def __init__(self, noise_schedule_config=None):
13
+ super(AbstractLowScaleModel, self).__init__()
14
+ if noise_schedule_config is not None:
15
+ self.register_schedule(**noise_schedule_config)
16
+
17
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
18
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20
+ cosine_s=cosine_s)
21
+ alphas = 1. - betas
22
+ alphas_cumprod = np.cumprod(alphas, axis=0)
23
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24
+
25
+ timesteps, = betas.shape
26
+ self.num_timesteps = int(timesteps)
27
+ self.linear_start = linear_start
28
+ self.linear_end = linear_end
29
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30
+
31
+ to_torch = partial(torch.tensor, dtype=torch.float32)
32
+
33
+ self.register_buffer('betas', to_torch(betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43
+
44
+ def q_sample(self, x_start, t, noise=None):
45
+ noise = default(noise, lambda: torch.randn_like(x_start))
46
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48
+
49
+ def forward(self, x):
50
+ return x, None
51
+
52
+ def decode(self, x):
53
+ return x
54
+
55
+
56
+ class SimpleImageConcat(AbstractLowScaleModel):
57
+ # no noise level conditioning
58
+ def __init__(self):
59
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60
+ self.max_noise_level = 0
61
+
62
+ def forward(self, x):
63
+ # fix to constant noise level
64
+ return x, torch.zeros(x.shape[0], device=x.device).long()
65
+
66
+
67
+ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69
+ super().__init__(noise_schedule_config=noise_schedule_config)
70
+ self.max_noise_level = max_noise_level
71
+
72
+ def forward(self, x, noise_level=None):
73
+ if noise_level is None:
74
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75
+ else:
76
+ assert isinstance(noise_level, torch.Tensor)
77
+ z = self.q_sample(x, noise_level)
78
+ return z, noise_level
79
+
80
+
81
+
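A hedged example of the noise-augmentation wrapper above; the schedule values mirror the function defaults and are not taken from a shipped config:

```python
import torch
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"beta_schedule": "linear", "timesteps": 1000,
                           "linear_start": 1e-4, "linear_end": 2e-2},
    max_noise_level=350,
)
lr = torch.randn(4, 3, 64, 64)        # low-resolution conditioning image
z, noise_level = aug(lr)              # q_sample at a random level per batch item
print(z.shape, noise_level.shape)     # torch.Size([4, 3, 64, 64]) torch.Size([4])
```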
ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,270 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from ldm.util import instantiate_from_config
19
+
20
+
21
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22
+ if schedule == "linear":
23
+ betas = (
24
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25
+ )
26
+
27
+ elif schedule == "cosine":
28
+ timesteps = (
29
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30
+ )
31
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
32
+ alphas = torch.cos(alphas).pow(2)
33
+ alphas = alphas / alphas[0]
34
+ betas = 1 - alphas[1:] / alphas[:-1]
35
+ betas = np.clip(betas, a_min=0, a_max=0.999)
36
+
37
+ elif schedule == "sqrt_linear":
38
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39
+ elif schedule == "sqrt":
40
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41
+ else:
42
+ raise ValueError(f"schedule '{schedule}' unknown.")
43
+ return betas.numpy()
44
+
45
+
46
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47
+ if ddim_discr_method == 'uniform':
48
+ c = num_ddpm_timesteps // num_ddim_timesteps
49
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50
+ elif ddim_discr_method == 'quad':
51
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52
+ else:
53
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54
+
55
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
57
+ steps_out = ddim_timesteps + 1
58
+ if verbose:
59
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
60
+ return steps_out
61
+
62
+
63
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64
+ # select alphas for computing the variance schedule
65
+ alphas = alphacums[ddim_timesteps]
66
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67
+
68
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
69
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70
+ if verbose:
71
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72
+ print(f'For the chosen value of eta, which is {eta}, '
73
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74
+ return sigmas, alphas, alphas_prev
75
+
76
+
77
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78
+ """
79
+ Create a beta schedule that discretizes the given alpha_t_bar function,
80
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
81
+ :param num_diffusion_timesteps: the number of betas to produce.
82
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83
+ produces the cumulative product of (1-beta) up to that
84
+ part of the diffusion process.
85
+ :param max_beta: the maximum beta to use; use values lower than 1 to
86
+ prevent singularities.
87
+ """
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93
+ return np.array(betas)
94
+
95
+
96
+ def extract_into_tensor(a, t, x_shape):
97
+ b, *_ = t.shape
98
+ out = a.gather(-1, t)
99
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100
+
101
+
102
+ def checkpoint(func, inputs, params, flag):
103
+ """
104
+ Evaluate a function without caching intermediate activations, allowing for
105
+ reduced memory at the expense of extra compute in the backward pass.
106
+ :param func: the function to evaluate.
107
+ :param inputs: the argument sequence to pass to `func`.
108
+ :param params: a sequence of parameters `func` depends on but does not
109
+ explicitly take as arguments.
110
+ :param flag: if False, disable gradient checkpointing.
111
+ """
112
+ if flag:
113
+ args = tuple(inputs) + tuple(params)
114
+ return CheckpointFunction.apply(func, len(inputs), *args)
115
+ else:
116
+ return func(*inputs)
117
+
118
+
119
+ class CheckpointFunction(torch.autograd.Function):
120
+ @staticmethod
121
+ def forward(ctx, run_function, length, *args):
122
+ ctx.run_function = run_function
123
+ ctx.input_tensors = list(args[:length])
124
+ ctx.input_params = list(args[length:])
125
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
126
+ "dtype": torch.get_autocast_gpu_dtype(),
127
+ "cache_enabled": torch.is_autocast_cache_enabled()}
128
+ with torch.no_grad():
129
+ output_tensors = ctx.run_function(*ctx.input_tensors)
130
+ return output_tensors
131
+
132
+ @staticmethod
133
+ def backward(ctx, *output_grads):
134
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
135
+ with torch.enable_grad(), \
136
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
137
+ # Fixes a bug where the first op in run_function modifies the
138
+ # Tensor storage in place, which is not allowed for detach()'d
139
+ # Tensors.
140
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
141
+ output_tensors = ctx.run_function(*shallow_copies)
142
+ input_grads = torch.autograd.grad(
143
+ output_tensors,
144
+ ctx.input_tensors + ctx.input_params,
145
+ output_grads,
146
+ allow_unused=True,
147
+ )
148
+ del ctx.input_tensors
149
+ del ctx.input_params
150
+ del output_tensors
151
+ return (None, None) + input_grads
152
+
153
+
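A toy use of the `checkpoint` helper above: with the flag set to True, the wrapped closure is re-run during the backward pass instead of caching its activations (arbitrary sizes, CPU-only sketch):

```python
import torch
from ldm.modules.diffusionmodules.util import checkpoint

lin = torch.nn.Linear(32, 32)

def run(x):
    return lin(x).relu()

x = torch.randn(8, 32, requires_grad=True)
y = checkpoint(run, (x,), tuple(lin.parameters()), True)   # True -> recompute in backward
y.sum().backward()
print(x.grad.shape, lin.weight.grad.shape)                  # torch.Size([8, 32]) torch.Size([32, 32])
```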
154
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
155
+ """
156
+ Create sinusoidal timestep embeddings.
157
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
158
+ These may be fractional.
159
+ :param dim: the dimension of the output.
160
+ :param max_period: controls the minimum frequency of the embeddings.
161
+ :return: an [N x dim] Tensor of positional embeddings.
162
+ """
163
+ if not repeat_only:
164
+ half = dim // 2
165
+ freqs = torch.exp(
166
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
167
+ ).to(device=timesteps.device)
168
+ args = timesteps[:, None].float() * freqs[None]
169
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
170
+ if dim % 2:
171
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
172
+ else:
173
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
174
+ return embedding
175
+
176
+
177
+ def zero_module(module):
178
+ """
179
+ Zero out the parameters of a module and return it.
180
+ """
181
+ for p in module.parameters():
182
+ p.detach().zero_()
183
+ return module
184
+
185
+
186
+ def scale_module(module, scale):
187
+ """
188
+ Scale the parameters of a module and return it.
189
+ """
190
+ for p in module.parameters():
191
+ p.detach().mul_(scale)
192
+ return module
193
+
194
+
195
+ def mean_flat(tensor):
196
+ """
197
+ Take the mean over all non-batch dimensions.
198
+ """
199
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
200
+
201
+
202
+ def normalization(channels):
203
+ """
204
+ Make a standard normalization layer.
205
+ :param channels: number of input channels.
206
+ :return: an nn.Module for normalization.
207
+ """
208
+ return GroupNorm32(32, channels)
209
+
210
+
211
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
212
+ class SiLU(nn.Module):
213
+ def forward(self, x):
214
+ return x * torch.sigmoid(x)
215
+
216
+
217
+ class GroupNorm32(nn.GroupNorm):
218
+ def forward(self, x):
219
+ return super().forward(x.float()).type(x.dtype)
220
+
221
+ def conv_nd(dims, *args, **kwargs):
222
+ """
223
+ Create a 1D, 2D, or 3D convolution module.
224
+ """
225
+ if dims == 1:
226
+ return nn.Conv1d(*args, **kwargs)
227
+ elif dims == 2:
228
+ return nn.Conv2d(*args, **kwargs)
229
+ elif dims == 3:
230
+ return nn.Conv3d(*args, **kwargs)
231
+ raise ValueError(f"unsupported dimensions: {dims}")
232
+
233
+
234
+ def linear(*args, **kwargs):
235
+ """
236
+ Create a linear module.
237
+ """
238
+ return nn.Linear(*args, **kwargs)
239
+
240
+
241
+ def avg_pool_nd(dims, *args, **kwargs):
242
+ """
243
+ Create a 1D, 2D, or 3D average pooling module.
244
+ """
245
+ if dims == 1:
246
+ return nn.AvgPool1d(*args, **kwargs)
247
+ elif dims == 2:
248
+ return nn.AvgPool2d(*args, **kwargs)
249
+ elif dims == 3:
250
+ return nn.AvgPool3d(*args, **kwargs)
251
+ raise ValueError(f"unsupported dimensions: {dims}")
252
+
253
+
254
+ class HybridConditioner(nn.Module):
255
+
256
+ def __init__(self, c_concat_config, c_crossattn_config):
257
+ super().__init__()
258
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
259
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
260
+
261
+ def forward(self, c_concat, c_crossattn):
262
+ c_concat = self.concat_conditioner(c_concat)
263
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
264
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
265
+
266
+
267
+ def noise_like(shape, device, repeat=False):
268
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
269
+ noise = lambda: torch.randn(shape, device=device)
270
+ return repeat_noise() if repeat else noise()
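A quick illustration of two helpers defined above; the values are arbitrary and not tied to a particular checkpoint:

```python
import torch
from ldm.modules.diffusionmodules.util import make_beta_schedule, timestep_embedding

betas = make_beta_schedule("linear", n_timestep=1000)   # numpy array of shape (1000,)
print(float(betas[0]), float(betas[-1]))                # ~1e-4 ... ~2e-2

t = torch.tensor([0, 250, 999])                         # timestep indices
emb = timestep_embedding(t, dim=128)                    # sinusoidal features
print(emb.shape)                                        # torch.Size([3, 128])
```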
ldm/modules/distributions/__init__.py ADDED
File without changes
ldm/modules/distributions/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (187 Bytes). View file
 
ldm/modules/distributions/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (183 Bytes). View file
 
ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (162 Bytes). View file
 
ldm/modules/distributions/__pycache__/distributions.cpython-311.pyc ADDED
Binary file (6.24 kB). View file