Update README.md
README.md (changed), hunk `@@ -51,36 +51,46 @@` in the `## Usage` section: compared to the previous snippet, the code now downloads `model.safetensors` from the Hugging Face Hub via `hf_hub_download`, reads it with `safetensors`, converts the parameters to float32 before loading them into `model.dit`, and lists additional example prompts. The updated section:
## Usage

```python
import torch
from micro_diffusion.models.model import create_latent_diffusion
from huggingface_hub import hf_hub_download
from safetensors import safe_open

# Init model
params = {
    'latent_res': 64,
    'in_channels': 4,
    'pos_interp_scale': 2.0,
}
model = create_latent_diffusion(**params).to('cuda')

# Download weights from HF
model_dict_path = hf_hub_download(repo_id="giannisdaras/ambient-o", filename="model.safetensors")
model_dict = {}
with safe_open(model_dict_path, framework="pt", device="cpu") as f:
    for key in f.keys():
        model_dict[key] = f.get_tensor(key)

# Convert parameters to float32 + load
float_model_params = {
    k: v.to(torch.float32) for k, v in model_dict.items()
}
model.dit.load_state_dict(float_model_params)

# Eval mode
model = model.eval()

# Generate images
prompts = [
    "Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumet",
    "A illustration from a graphic novel. A bustling city street under the shine of a full moon.",
    "A giant cobra snake made from corn",
    "A fierce garden gnome warrior, clad in armor crafted from leaves and bark, brandishes a tiny sword.",
    "A capybara made of lego sitting in a realistic, natural field",
    "a close-up of a fire spitting dragon, cinematic shot.",
    "Panda mad scientist mixing sparkling chemicals, artstation"
]
images = model.generate(prompt=prompts, num_inference_steps=30, guidance_scale=5.0, seed=42)
```
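The snippet above does not show what `model.generate` returns. Assuming it yields a float tensor of shape `(B, C, H, W)` with values in `[0, 1]`, and that `torchvision` is installed, a minimal sketch for writing the samples to disk could look like this (the output directory name is arbitrary):

```python
from pathlib import Path

from torchvision.utils import save_image  # assumption: torchvision is available

out_dir = Path("samples")
out_dir.mkdir(exist_ok=True)

# Assumption: `images` is a float tensor of shape (B, C, H, W) with values in [0, 1],
# one image per prompt, in the same order as `prompts`.
for i, img in enumerate(images):
    save_image(img, out_dir / f"sample_{i:02d}.png")
```

If `generate` instead returns PIL images, each element already has a `.save()` method and the `torchvision` helper is not needed.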
## Citation